# Source code for jax_dataloader.transform

"""Data transformation module for JAX applications."""

from typing import Any, Callable, Dict, List, Optional, Union
import jax.numpy as jnp

class Transform:
    """Base class for data transformations.

    A ``Transform`` keeps an ordered list of callables and applies them one
    after another, passing each callable the result of the previous one.
    """

    def __init__(self) -> None:
        """Initialize the transform with an empty list of callables."""
        self._transforms: List[Callable] = []

    def add(self, transform: Callable) -> None:
        """Add a transform function.

        The function is appended, so it runs after every transform that was
        added before it. This mutates the instance in place and returns
        ``None``.

        Args:
            transform: Transform function to add
        """
        self._transforms.append(transform)

    def __call__(self, data: Any) -> Any:
        """Apply all transforms to the data.

        Transforms run in the order they were added. With no transforms
        registered, ``data`` is returned unchanged.

        Args:
            data: Data to transform

        Returns:
            Transformed data
        """
        for transform in self._transforms:
            data = transform(data)
        return data

    def apply(self, data: Any) -> Any:
        """Apply all transforms to the data.

        Equivalent to calling the instance directly.

        Args:
            data: Data to transform

        Returns:
            Transformed data
        """
        return self(data)

    def compose(self, other: 'Transform') -> 'Transform':
        """Compose this transform with another transform.

        The result runs this instance's transforms first, then ``other``'s.
        Neither operand is modified.

        Args:
            other: Transform to compose with

        Returns:
            New composed transform
        """
        result = Transform()
        # List concatenation builds a fresh list, so later ``add`` calls on
        # the result do not leak back into ``self`` or ``other``.
        result._transforms = self._transforms + other._transforms
        return result

    def chain(self, *transforms: Callable) -> 'Transform':
        """Chain multiple transforms together.

        The result runs this instance's transforms first, then the given
        callables in argument order. This instance is not modified.

        Args:
            *transforms: Transforms to chain

        Returns:
            New transform with all transforms chained
        """
        result = Transform()
        # ``transforms`` arrives as a tuple; convert it so it can be
        # concatenated onto the existing list.
        result._transforms = self._transforms + list(transforms)
        return result