Source code for jax_dataloader.transform
"""Data transformation module for JAX applications."""
from typing import Any, Callable, Dict, List, Optional, Union
import jax.numpy as jnp
[docs]
class Transform:
    """Base class for data transformations.

    A ``Transform`` holds an ordered pipeline of callables. Calling it feeds
    the data through each callable in insertion order, passing each
    callable's output to the next one.
    """

    def __init__(self, transforms: Optional[List[Callable]] = None):
        """Create a transform pipeline.

        Args:
            transforms: Optional initial sequence of transform functions,
                applied in order. The sequence is copied, so later changes
                to the caller's list do not affect this pipeline.

        Raises:
            TypeError: If any element of ``transforms`` is not callable.
        """
        # Previously nothing initialized ``_transforms``: ``Transform().add``
        # and calling an empty ``Transform()`` raised AttributeError, and only
        # instances built by ``compose``/``chain`` (which set it by hand) worked.
        if transforms is None:
            transforms = ()
        self._transforms: List[Callable] = [
            self._validate(transform) for transform in transforms
        ]

    @staticmethod
    def _validate(transform: Any) -> Callable:
        """Return ``transform`` unchanged, raising TypeError if it is not callable."""
        # Fail at registration time rather than deep inside a later __call__.
        if not callable(transform):
            raise TypeError(
                f"transform must be callable, got {type(transform).__name__}"
            )
        return transform

    def add(self, transform: Callable):
        """Add a transform function to the end of this pipeline (in place).

        Args:
            transform: Transform function to add.

        Raises:
            TypeError: If ``transform`` is not callable.
        """
        self._transforms.append(self._validate(transform))

    def __call__(self, data: Any) -> Any:
        """Apply all transforms to the data, in insertion order.

        Args:
            data: Data to transform.

        Returns:
            Transformed data. With no transforms registered, ``data`` is
            returned unchanged.
        """
        for transform in self._transforms:
            data = transform(data)
        return data

    def apply(self, data: Any) -> Any:
        """Apply all transforms to the data (alias for calling the instance).

        Args:
            data: Data to transform.

        Returns:
            Transformed data.
        """
        return self(data)

    def compose(self, other: Union['Transform', Callable]) -> 'Transform':
        """Compose this transform with another transform or a plain callable.

        Neither ``self`` nor ``other`` is modified.

        Args:
            other: Transform whose steps run after this one's, or a single
                callable to run as the final step.

        Returns:
            New ``Transform`` running this pipeline's steps, then ``other``'s.

        Raises:
            TypeError: If ``other`` is neither a ``Transform`` nor callable.
        """
        if isinstance(other, Transform):
            # Flatten another pipeline's steps, as the original did.
            tail = list(other._transforms)
        else:
            tail = [self._validate(other)]
        return Transform(self._transforms + tail)

    def chain(self, *transforms: Callable) -> 'Transform':
        """Chain multiple transforms after this pipeline's steps.

        ``self`` is not modified.

        Args:
            *transforms: Transforms to append, in order.

        Returns:
            New ``Transform`` with all transforms chained.

        Raises:
            TypeError: If any of ``transforms`` is not callable.
        """
        return Transform(self._transforms + list(transforms))

    def __repr__(self) -> str:
        """Return a debug representation listing the pipeline's steps."""
        steps = ", ".join(
            getattr(step, "__name__", repr(step)) for step in self._transforms
        )
        return f"{type(self).__name__}([{steps}])"