API Reference

Data Loaders

CSVLoader

class jax_dataloader.data.CSVLoader(data_path, target_column, feature_columns=None, chunk_size=None, dtype=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)[source]

Bases: BaseLoader

Loader for CSV data.

Examples

Basic CSV loading:

from jax_dataloader.data import CSVLoader

loader = CSVLoader(
    "data.csv",
    target_column="label",
    feature_columns=["feature1", "feature2"]
)

Advanced CSV loading with chunking:

import jax.numpy as jnp

loader = CSVLoader(
    "large_dataset.csv",
    chunk_size=10000,
    target_column="target",
    feature_columns=["feature1", "feature2"],
    dtype=jnp.float32
)

Methods

__init__(data_path, target_column, feature_columns=None, chunk_size=None, dtype=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)[source]

Initialize the CSV loader.

Parameters:
  • data_path (str) – Path to the CSV file

  • target_column (str) – Name of the target column

  • feature_columns (Optional[List[str]]) – List of feature column names

  • chunk_size (Optional[int]) – Size of chunks to load at once

  • dtype (Optional[dtype]) – Optional data type for arrays

  • batch_size (int) – Number of samples per batch

  • shuffle (bool) – Whether to shuffle the data

  • seed (Optional[int]) – Random seed for reproducibility

  • num_workers (int) – Number of worker processes

  • prefetch_factor (int) – Number of batches to prefetch

load(data_path)[source]

Load data from the CSV file.

Parameters:
  • data_path (str) – Path to the CSV file

Return type: Any

Returns: Loaded data

get_chunk(start, size)[source]

Get a chunk of data from the CSV file.

Parameters:
  • start (int) – The starting index of the chunk.

  • size (int) – The size of the chunk.

Return type: tuple

Returns: A tuple containing the chunk data and labels as NumPy arrays.

get_metadata()

Get metadata about the data loader.

Return type: Dict[str, Any]

Returns: Dictionary containing metadata about the data loader

preprocess(data)[source]

Preprocess the loaded data.

Parameters:
  • data (Any) – Data to preprocess

Return type: Any

Returns: Preprocessed data

__len__()[source]

Return the number of batches.

Return type: int

__iter__()[source]

Return an iterator over the data.

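Iterating over a loader yields one batch per step. A minimal sketch, assuming each batch is a (features, labels) pair of arrays:

loader = CSVLoader(
    "data.csv",
    target_column="label",
    batch_size=64
)
print(f"{len(loader)} batches per epoch")
for features, labels in loader:
    ...  # features has shape (batch_size, num_features)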

JSONLoader

class jax_dataloader.data.JSONLoader(data_path, data_key='data', label_key='labels', preprocess_fn=None, dtype=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)[source]

Bases: BaseLoader

Loader for JSON data.

Examples

Basic JSON loading:

from jax_dataloader.data import JSONLoader

loader = JSONLoader(
    "data.json",
    data_key="features",
    label_key="labels"
)

Advanced JSON loading with preprocessing:

import jax.numpy as jnp

loader = JSONLoader(
    "data.json",
    data_key="features",
    label_key="labels",
    preprocess_fn=lambda x: x / 255.0,
    dtype=jnp.float32
)

Methods

__init__(data_path, data_key='data', label_key='labels', preprocess_fn=None, dtype=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)[source]

Initialize the JSON loader.

Parameters:
  • data_path (str) – Path to the JSON file

  • data_key (str) – Key for data in JSON

  • label_key (str) – Key for labels in JSON

  • preprocess_fn (Optional[Callable]) – Optional preprocessing function

  • dtype (Optional[dtype]) – Optional data type for arrays

  • batch_size (int) – Number of samples per batch

  • shuffle (bool) – Whether to shuffle the data

  • seed (Optional[int]) – Random seed for reproducibility

  • num_workers (int) – Number of worker processes

  • prefetch_factor (int) – Number of batches to prefetch

load(data_path)[source]

Load data from the JSON file.

Parameters:
  • data_path (str) – Path to the JSON file

Return type: Any

Returns: Loaded data

preprocess(data)[source]

Preprocess the loaded data.

Parameters:
  • data (Any) – Data to preprocess

Return type: Any

Returns: Preprocessed data

get_metadata()

Get metadata about the data loader.

Return type: Dict[str, Any]

Returns: Dictionary containing metadata about the data loader

__len__()[source]

Return the number of batches.

Return type: int

__iter__()[source]

Return an iterator over the data.

ImageLoader

class jax_dataloader.data.ImageLoader(data_path, image_size=(224, 224), normalize=True, augment=False, augment_options=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)[source]

Bases: BaseLoader

Loader for image data.

Examples

Basic image loading:

from jax_dataloader.data import ImageLoader

loader = ImageLoader(
    "image_directory",
    image_size=(224, 224),
    normalize=True
)

Advanced image loading with augmentation:

loader = ImageLoader(
    "image_directory",
    image_size=(224, 224),
    normalize=True,
    augment=True,
    augment_options={
        "rotation": [-30, 30],
        "flip": True,
        "brightness": [0.8, 1.2]
    }
)

Methods

__init__(data_path, image_size=(224, 224), normalize=True, augment=False, augment_options=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)[source]

Initialize the image loader.

Parameters:
  • data_path (str) – Path to the image directory

  • image_size (Tuple[int, int]) – Target size for images

  • normalize (bool) – Whether to normalize pixel values

  • augment (bool) – Whether to apply data augmentation

  • augment_options (Optional[Dict[str, Any]]) – Options for data augmentation

  • batch_size (int) – Number of samples per batch

  • shuffle (bool) – Whether to shuffle the data

  • seed (Optional[int]) – Random seed for reproducibility

  • num_workers (int) – Number of worker processes

  • prefetch_factor (int) – Number of batches to prefetch

load(data_path)[source]

Load data from the image directory.

Parameters:
  • data_path (str) – Path to the image directory

Return type: Any

Returns: Loaded data

preprocess(data)[source]

Preprocess the loaded data.

Parameters:
  • data (Any) – Data to preprocess

Return type: Any

Returns: Preprocessed data

augment(data)[source]

Apply data augmentation to the image.

Parameters:
  • data (Any) – Image data to augment

Return type: Any

Returns: Augmented image data

get_metadata()

Get metadata about the data loader.

Return type: Dict[str, Any]

Returns: Dictionary containing metadata about the data loader

__len__()[source]

Return the number of batches.

Return type: int

__iter__()[source]

Return an iterator over the data.

BaseLoader

class jax_dataloader.data.BaseLoader(batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)[source]

Bases: ABC

Base class for all data loaders.

Methods

__init__(batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)[source]

Initialize the data loader.

Parameters:
  • batch_size (int) – Number of samples per batch

  • shuffle (bool) – Whether to shuffle the data

  • seed (Optional[int]) – Random seed for reproducibility

  • num_workers (int) – Number of worker processes

  • prefetch_factor (int) – Number of batches to prefetch

load(data_path)[source]

Load data from the specified path.

Parameters:
  • data_path (str) – Path to the data file

Return type: Any

Returns: Loaded data

preprocess(data)[source]

Preprocess the loaded data.

Parameters:
  • data (Any) – Data to preprocess

Return type: Any

Returns: Preprocessed data

get_metadata()[source]

Get metadata about the data loader.

Return type: Dict[str, Any]

Returns: Dictionary containing metadata about the data loader

abstractmethod __len__()[source]

Return the number of batches.

Return type: int

abstractmethod __iter__()[source]

Return an iterator over the data.

__next__()[source]

Get the next batch of data.

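Because BaseLoader is abstract, a custom loader implements __len__ and __iter__ (and typically load and preprocess). A minimal sketch over in-memory arrays, assuming BaseLoader.__init__ stores batch_size on the instance; the slicing logic is illustrative, not part of the library:

import numpy as np
from jax_dataloader.data import BaseLoader

class ArrayLoader(BaseLoader):
    def __init__(self, features, labels, **kwargs):
        super().__init__(**kwargs)
        self.features = features
        self.labels = labels

    def __len__(self):
        # Number of full batches (assumes self.batch_size is set by the base class)
        return len(self.features) // self.batch_size

    def __iter__(self):
        for i in range(len(self)):
            s = slice(i * self.batch_size, (i + 1) * self.batch_size)
            yield self.features[s], self.labels[s]

loader = ArrayLoader(np.ones((100, 3)), np.zeros(100), batch_size=10)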

Memory Management

MemoryManager

class jax_dataloader.memory.MemoryManager(max_memory=None)[source]

Bases: object

Manages memory allocation and deallocation.

Examples

Basic memory management:

from jax_dataloader.memory import MemoryManager

manager = MemoryManager(max_memory=1024**3)  # 1GB

Advanced memory management with monitoring:

manager = MemoryManager(max_memory=1024**3)
stats = manager.monitor(interval=1.0)
print(f"Memory usage: {stats['current_usage']}")

Methods

__init__(max_memory=None)[source]

Initialize the memory manager.

Parameters:
  • max_memory (Optional[float]) – Maximum memory to allocate in bytes

allocate(size)[source]

Allocate memory.

Parameters:
  • size (int) – Size to allocate in bytes

Return type: bool

Returns: True if allocation was successful

deallocate(size)[source]

Deallocate memory.

Parameters:
  • size (int) – Size to deallocate in bytes

free(size=None)[source]

Free memory.

Parameters:
  • size (Optional[int]) – Optional size to free in bytes

get_usage()[source]

Get current memory usage statistics.

Return type: Dict[str, Any]

Returns: Dictionary containing memory usage statistics

cleanup()[source]

Clean up memory and reset statistics.

monitor(interval=1.0)[source]

Monitor memory usage over time.

Parameters:
  • interval (float) – Time interval between updates in seconds

Return type: Dict[str, Any]

Returns: Dictionary containing runtime statistics

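A typical allocate/deallocate cycle looks like the following sketch; the 256 MB chunk size is illustrative:

from jax_dataloader.memory import MemoryManager

manager = MemoryManager(max_memory=1024**3)
chunk = 256 * 1024**2
if manager.allocate(chunk):
    try:
        ...  # work within the reserved budget
    finally:
        manager.deallocate(chunk)
else:
    print("allocation would exceed max_memory")
print(manager.get_usage())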

Cache

class jax_dataloader.memory.Cache(max_size, eviction_policy='lru', track_stats=True, max_age=None)[source]

Bases: object

Cache for storing data in memory.

Examples

Basic caching:

from jax_dataloader.memory import Cache

cache = Cache(
    max_size=1000,
    eviction_policy="lru"
)

Advanced caching with statistics:

cache = Cache(
    max_size=1000,
    eviction_policy="lru",
    track_stats=True,
    max_age=3600  # 1 hour
)

Methods

__init__(max_size, eviction_policy='lru', track_stats=True, max_age=None)[source]

Initialize the cache.

Parameters:
  • max_size (int) – Maximum size of the cache in bytes

  • eviction_policy (str) – Cache eviction policy (“lru” or “fifo”)

  • track_stats (bool) – Whether to track cache statistics

  • max_age (Optional[float]) – Maximum age of cached items in seconds

get(key)[source]

Get a value from the cache.

Parameters:
  • key (str) – Cache key

Return type: Optional[Any]

Returns: Cached value if found, None otherwise

put(key, value)[source]

Put a value in the cache.

Parameters:
  • key (str) – Cache key

  • value (Any) – Value to cache

clear()[source]

Clear the cache.

get_stats()[source]

Get cache statistics.

Return type: Dict[str, Any]

Returns: Dictionary containing cache statistics

evict(key)[source]

Evict a specific key from the cache.

Parameters:
  • key (str) – Key to evict

_estimate_size(value)[source]

Estimate the size of a value in bytes.

Parameters:
  • value (Any) – Value to estimate size of

Return type: int

Returns: Estimated size in bytes

_evict()[source]

Evict an item based on the eviction policy.
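
A typical get/put round trip, as a sketch; load_batch is a hypothetical stand-in for whatever produces the value on a miss:

from jax_dataloader.memory import Cache

cache = Cache(max_size=10 * 1024**2, eviction_policy="lru")
value = cache.get("batch_0")
if value is None:
    value = load_batch(0)  # hypothetical loader call
    cache.put("batch_0", value)
print(cache.get_stats())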

Progress Tracking

ProgressTracker

class jax_dataloader.progress.ProgressTracker(total, desc=None, unit='it', leave=True, update_interval=0.1, callbacks=None, show_eta=True)[source]

Bases: object

Tracks progress of data loading.

Examples

Basic progress tracking:

from jax_dataloader.progress import ProgressTracker

tracker = ProgressTracker(
    total=1000,
    update_interval=0.1
)

Advanced progress tracking with callbacks:

def on_update(progress):
    print(f"Progress: {progress:.1%}")

tracker = ProgressTracker(
    total=1000,
    update_interval=0.1,
    callbacks=[on_update],
    show_eta=True
)

Methods

__init__(total, desc=None, unit='it', leave=True, update_interval=0.1, callbacks=None, show_eta=True)[source]

Initialize the progress tracker.

Parameters:
  • total (int) – Total number of items

  • desc (Optional[str]) – Description of the progress

  • unit (str) – Unit of progress

  • leave (bool) – Whether to leave the progress bar

  • update_interval (float) – Time interval between updates in seconds

  • callbacks (Optional[List[Callable[[float], None]]]) – List of callback functions to call on updates

  • show_eta (bool) – Whether to show estimated time remaining

update(n=1)[source]

Update the progress.

Parameters:
  • n (int) – Number of items to advance the progress by

reset()[source]

Reset the progress tracker.

get_progress()[source]

Get the current progress as a fraction.

Return type: float

Returns: Progress as a fraction between 0 and 1

get_eta()[source]

Get the estimated time remaining.

Return type: float

Returns: Estimated time remaining in seconds

close()[source]

Close the progress bar.

__enter__()[source]

Context manager entry.

__exit__(exc_type, exc_val, exc_tb)[source]

Context manager exit.
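
Because the tracker implements the context-manager protocol, a with-block closes the progress bar automatically. A minimal sketch, assuming __enter__ returns the tracker itself:

from jax_dataloader.progress import ProgressTracker

with ProgressTracker(total=1000, desc="loading") as tracker:
    for _ in range(1000):
        ...  # process one item
        tracker.update(1)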

Data Augmentation

Transform

class jax_dataloader.transform.Transform[source]

Bases: object

Base class for data transformations.

Examples

Basic transformation:

from jax_dataloader.transform import Transform

transform = Transform()
transform.add(lambda x: x * 2)

Advanced transformation with chaining:

import jax.numpy as jnp

transform = Transform()
transform.add(lambda x: x * 2)
transform.add(lambda x: x + 1)
transform.add(lambda x: jnp.clip(x, 0, 1))

Methods

__init__()[source]

Initialize the transform.

add(transform)[source]

Add a transform function.

Parameters:
  • transform (Callable) – Transform function to add

apply(data)[source]

Apply all transforms to the data.

Parameters:
  • data (Any) – Data to transform

Return type: Any

Returns: Transformed data

compose(other)[source]

Compose this transform with another transform.

Parameters:
  • other (Transform) – Transform to compose with

Return type: Transform

Returns: New composed transform

chain(*transforms)[source]

Chain multiple transforms together.

Parameters:
  • *transforms (Callable) – Transforms to chain

Return type: Transform

Returns: New transform with all transforms chained

__call__(data)[source]

Apply all transforms to the data.

Parameters:
  • data (Any) – Data to transform

Return type: Any

Returns: Transformed data

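compose and chain return new Transform objects rather than mutating the originals, and __call__ applies the whole pipeline like a plain function. A short sketch, assuming compose applies this transform before the other:

import jax.numpy as jnp
from jax_dataloader.transform import Transform

scale = Transform()
scale.add(lambda x: x / 255.0)

clip = Transform()
clip.add(lambda x: jnp.clip(x, 0.0, 1.0))

pipeline = scale.compose(clip)          # new Transform; originals unchanged
batch = pipeline(jnp.arange(300.0))     # __call__ runs scale, then clip (assumed order)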

Exceptions

DataLoaderError

exception jax_dataloader.exceptions.DataLoaderError[source]

Bases: Exception

Base exception for data loader errors.

ConfigurationError

exception jax_dataloader.exceptions.ConfigurationError[source]

Bases: DataLoaderError

Exception raised for configuration errors.

MemoryError

exception jax_dataloader.exceptions.MemoryError[source]

Bases: DataLoaderError

Exception raised for memory-related errors.
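
Since both exceptions derive from DataLoaderError, a single except clause can catch either. A minimal sketch; exactly when a loader raises these is implementation-dependent:

from jax_dataloader.data import CSVLoader
from jax_dataloader.exceptions import ConfigurationError, DataLoaderError

try:
    loader = CSVLoader("data.csv", target_column="label")
except ConfigurationError as e:
    print(f"bad configuration: {e}")
except DataLoaderError as e:
    print(f"loading failed: {e}")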

Utility Functions

jax_dataloader.utils.format_size(size)[source]

Format a size in bytes to a human-readable string.

Parameters:
  • size (Union[int, float]) – Size in bytes

Return type: str

Returns: Human-readable size string
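
For example (the exact formatting of the output string is an assumption):

from jax_dataloader.utils import format_size

print(format_size(1536))      # e.g. "1.5 KB"
print(format_size(1024**3))   # e.g. "1.0 GB"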