API Reference
Core Classes
DataLoader
DataLoaderConfig
Data Loaders
CSVLoader
- class jax_dataloader.data.CSVLoader(data_path, target_column, feature_columns=None, chunk_size=None, dtype=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)
Bases: BaseLoader
Loader for CSV data.
Examples
Basic CSV loading:
loader = CSVLoader(
    "data.csv",
    target_column="label",
    feature_columns=["feature1", "feature2"],
)
Advanced CSV loading with chunking:
import jax.numpy as jnp
loader = CSVLoader(
    "large_dataset.csv",
    target_column="target",
    feature_columns=["feature1", "feature2"],
    chunk_size=10000,
    dtype=jnp.float32,
)
Methods
- __init__(data_path, target_column, feature_columns=None, chunk_size=None, dtype=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)
Initialize the CSV loader.
- Parameters:
data_path (str) – Path to the CSV file
target_column (str) – Name of the target column
feature_columns (Optional[List[str]]) – List of feature column names
chunk_size (Optional[int]) – Size of the chunks to load at once
dtype (Optional[dtype]) – Optional data type for the arrays
batch_size (int) – Number of samples per batch
shuffle (bool) – Whether to shuffle the data
seed (Optional[int]) – Random seed for reproducibility
num_workers (int) – Number of worker processes
prefetch_factor (int) – Number of batches to prefetch
- load(data_path)
Load data from the CSV file.
- Parameters:
data_path (str) – Path to the CSV file
- Return type:
Any
- Returns:
Loaded data
- get_chunk(start, size)
Get a chunk of data from the CSV file.
- Parameters:
start (int) – The starting index of the chunk.
size (int) – The size of the chunk.
- Returns:
A tuple containing the chunk data and labels as numpy arrays.
- Return type:
tuple
- get_metadata()
Get metadata about the data loader.
- Return type:
Dict[str, Any]
- Returns:
Dictionary containing metadata about the data loader
- preprocess(data)
Preprocess the loaded data.
- Parameters:
data (Any) – Data to preprocess
- Return type:
Any
- Returns:
Preprocessed data
- __len__()
Return the number of batches.
- Return type:
int
- __iter__()
Return an iterator over the data.
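The chunked loading that `chunk_size` and `get_chunk(start, size)` describe can be sketched with the standard library alone. This is an illustration under assumptions, not the loader's implementation: `read_csv_chunk` is a hypothetical helper, it assumes a header row, treats a chunk as a run of data rows, treats every feature column as numeric, and returns plain Python lists where the documented method returns NumPy arrays.

```python
import csv
import io
from itertools import islice
from typing import List, Optional, TextIO, Tuple


def read_csv_chunk(
    f: TextIO,
    target_column: str,
    start: int,
    size: int,
    feature_columns: Optional[List[str]] = None,
) -> Tuple[List[List[float]], List[float]]:
    """Hypothetical helper: return (features, labels) for data rows [start, start + size)."""
    reader = csv.DictReader(f)
    # Assumption: with no explicit feature list, use every column except the target.
    if feature_columns is None:
        feature_columns = [c for c in reader.fieldnames if c != target_column]
    features, labels = [], []
    # islice consumes and discards the first `start` rows, then stops after `size`
    # rows, so only the requested chunk is converted to floats.
    for row in islice(reader, start, start + size):
        features.append([float(row[c]) for c in feature_columns])
        labels.append(float(row[target_column]))
    return features, labels


# Usage with an in-memory file standing in for "data.csv".
sample = io.StringIO("feature1,feature2,label\n1,2,0\n3,4,1\n5,6,0\n")
X, y = read_csv_chunk(sample, "label", start=1, size=2)
print(X, y)  # [[3.0, 4.0], [5.0, 6.0]] [1.0, 0.0]
```

Reading through an iterator keeps memory proportional to one chunk, which is the usual reason to set `chunk_size` for large files.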
JSONLoader
- class jax_dataloader.data.JSONLoader(data_path, data_key='data', label_key='labels', preprocess_fn=None, dtype=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)
Bases: BaseLoader
Loader for JSON data.
Examples
Basic JSON loading:
loader = JSONLoader(
    "data.json",
    data_key="features",
    label_key="labels",
)
Advanced JSON loading with preprocessing:
import jax.numpy as jnp
loader = JSONLoader(
    "data.json",
    data_key="features",
    label_key="labels",
    preprocess_fn=lambda x: x / 255.0,
    dtype=jnp.float32,
)
Methods
- __init__(data_path, data_key='data', label_key='labels', preprocess_fn=None, dtype=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)
Initialize the JSON loader.
- Parameters:
data_path (str) – Path to the JSON file
data_key (str) – Key for the data in the JSON file
label_key (str) – Key for the labels in the JSON file
preprocess_fn (Optional[Callable]) – Optional preprocessing function
dtype (Optional[dtype]) – Optional data type for the arrays
batch_size (int) – Number of samples per batch
shuffle (bool) – Whether to shuffle the data
seed (Optional[int]) – Random seed for reproducibility
num_workers (int) – Number of worker processes
prefetch_factor (int) – Number of batches to prefetch
- load(data_path)
Load data from the JSON file.
- Parameters:
data_path (str) – Path to the JSON file
- Return type:
Any
- Returns:
Loaded data
- preprocess(data)
Preprocess the loaded data.
- Parameters:
data (Any) – Data to preprocess
- Return type:
Any
- Returns:
Preprocessed data
- get_metadata()
Get metadata about the data loader.
- Return type:
Dict[str, Any]
- Returns:
Dictionary containing metadata about the data loader
- __len__()
Return the number of batches.
- Return type:
int
- __iter__()
Return an iterator over the data.
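The interplay of `data_key`, `label_key`, `preprocess_fn`, `shuffle`, `seed`, and `batch_size` can be sketched as below, assuming the JSON file holds one object with parallel arrays under the two keys (the documented defaults are `"data"` and `"labels"`). `load_json_batches` is hypothetical and uses only the standard library; it applies `preprocess_fn` to each sample, whereas the real loader may apply it to the whole array.

```python
import json
import random
from typing import Any, Callable, Iterator, List, Optional, Tuple


def load_json_batches(
    text: str,
    data_key: str = "data",
    label_key: str = "labels",
    preprocess_fn: Optional[Callable[[Any], Any]] = None,
    batch_size: int = 32,
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> Iterator[Tuple[List[Any], List[Any]]]:
    """Hypothetical helper: yield (data, labels) batches from a JSON document."""
    obj = json.loads(text)
    data, labels = obj[data_key], obj[label_key]
    if len(data) != len(labels):
        raise ValueError("data and labels must have the same length")
    if preprocess_fn is not None:
        data = [preprocess_fn(x) for x in data]
    indices = list(range(len(data)))
    if shuffle:
        # A seeded Random instance makes the shuffle order reproducible.
        random.Random(seed).shuffle(indices)
    for i in range(0, len(indices), batch_size):
        idx = indices[i:i + batch_size]
        yield [data[j] for j in idx], [labels[j] for j in idx]


doc = '{"features": [0, 51, 102, 255], "labels": [0, 0, 1, 1]}'
batches = list(load_json_batches(doc, data_key="features",
                                 preprocess_fn=lambda x: x / 255.0,
                                 batch_size=3, shuffle=False))
print(batches)  # [([0.0, 0.2, 0.4], [0, 0, 1]), ([1.0], [1])]
```

Shuffling an index list rather than the data keeps each sample paired with its label.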
ImageLoader
- class jax_dataloader.data.ImageLoader(data_path, image_size=(224, 224), normalize=True, augment=False, augment_options=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)
Bases: BaseLoader
Loader for image data.
Examples
Basic image loading:
loader = ImageLoader(
    "image_directory",
    image_size=(224, 224),
    normalize=True,
)
Advanced image loading with augmentation:
loader = ImageLoader(
    "image_directory",
    image_size=(224, 224),
    normalize=True,
    augment=True,
    augment_options={
        "rotation": [-30, 30],
        "flip": True,
        "brightness": [0.8, 1.2],
    },
)
Methods
- __init__(data_path, image_size=(224, 224), normalize=True, augment=False, augment_options=None, batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)
Initialize the image loader.
- Parameters:
data_path (str) – Path to the image directory
image_size (Tuple[int, int]) – Target size for the images
normalize (bool) – Whether to normalize pixel values
augment (bool) – Whether to apply data augmentation
augment_options (Optional[Dict[str, Any]]) – Options for data augmentation
batch_size (int) – Number of samples per batch
shuffle (bool) – Whether to shuffle the data
seed (Optional[int]) – Random seed for reproducibility
num_workers (int) – Number of worker processes
prefetch_factor (int) – Number of batches to prefetch
- load(data_path)
Load data from the image directory.
- Parameters:
data_path (str) – Path to the image directory
- Return type:
Any
- Returns:
Loaded data
- preprocess(data)
Preprocess the loaded data.
- Parameters:
data (Any) – Data to preprocess
- Return type:
Any
- Returns:
Preprocessed data
- augment(data)
Apply data augmentation to the image.
- Parameters:
data (Any) – Image data to augment
- Return type:
Any
- Returns:
Augmented image data
- get_metadata()
Get metadata about the data loader.
- Return type:
Dict[str, Any]
- Returns:
Dictionary containing metadata about the data loader
- __len__()
Return the number of batches.
- Return type:
int
- __iter__()
Return an iterator over the data.
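The `augment_options` dictionary in the example pairs each augmentation with either a `[low, high]` range or a boolean. One plausible reading, sketched here purely as an assumption, is that a value is drawn uniformly from each range per image and that `"flip": True` enables a horizontal flip with probability 0.5. `sample_augmentation` and `apply_brightness_and_flip` are hypothetical; images are small grayscale lists with pixels in [0, 1], and rotation is sampled but not applied.

```python
import random
from typing import Any, Dict, List

Image = List[List[float]]  # grayscale image: rows of pixel values in [0, 1]


def sample_augmentation(options: Dict[str, Any], rng: random.Random) -> Dict[str, Any]:
    """Hypothetical: draw concrete augmentation parameters from range-style options."""
    params: Dict[str, Any] = {}
    if "rotation" in options:
        low, high = options["rotation"]
        params["rotation"] = rng.uniform(low, high)  # degrees
    if "brightness" in options:
        low, high = options["brightness"]
        params["brightness"] = rng.uniform(low, high)  # multiplicative factor
    # Assumed semantics: when enabled, flip is a fair coin toss per image.
    params["flip"] = bool(options.get("flip", False)) and rng.random() < 0.5
    return params


def apply_brightness_and_flip(image: Image, params: Dict[str, Any]) -> Image:
    """Apply the brightness factor and optional horizontal flip (rotation omitted)."""
    factor = params.get("brightness", 1.0)
    # Scale, then clip to [0, 1] so pixels stay in the normalized range.
    out = [[min(1.0, max(0.0, p * factor)) for p in row] for row in image]
    if params.get("flip"):
        out = [list(reversed(row)) for row in out]
    return out


options = {"rotation": [-30, 30], "flip": True, "brightness": [0.8, 1.2]}
params = sample_augmentation(options, random.Random(0))  # seeded for reproducibility
print(params)
print(apply_brightness_and_flip([[0.0, 0.5], [1.0, 0.25]], params))
```

Sampling parameters separately from applying them makes a seeded run reproducible and lets the same parameters be logged or reused.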
BaseLoader
- class jax_dataloader.data.BaseLoader(batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)
Bases: ABC
Base class for all data loaders.
Methods
- __init__(batch_size=32, shuffle=True, seed=None, num_workers=0, prefetch_factor=2)
Initialize the data loader.
- Parameters:
batch_size (int) – Number of samples per batch
shuffle (bool) – Whether to shuffle the data
seed (Optional[int]) – Random seed for reproducibility
num_workers (int) – Number of worker processes
prefetch_factor (int) – Number of batches to prefetch
- load(data_path)
Load data from the specified path.
- Parameters:
data_path (str) – Path to the data file
- Return type:
Any
- Returns:
Loaded data
- preprocess(data)
Preprocess the loaded data.
- Parameters:
data (Any) – Data to preprocess
- Return type:
Any
- Returns:
Preprocessed data
- get_metadata()
Get metadata about the data loader.
- Return type:
Dict[str, Any]
- Returns:
Dictionary containing metadata about the data loader
- abstractmethod __len__()
Return the number of batches.
- Return type:
int
- abstractmethod __iter__()
Return an iterator over the data.
- __next__()
Get the next batch of data.
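Because `__len__` and `__iter__` are abstract, a concrete loader must implement both. The sketch uses a stand-in `MiniBaseLoader` that mirrors only the documented constructor arguments (it is not the real `BaseLoader`) and a hypothetical `InMemoryLoader` subclass. It assumes the final partial batch is kept, so `n` samples give `ceil(n / batch_size)` batches.

```python
import math
import random
from abc import ABC, abstractmethod
from typing import Any, Iterator, List, Optional, Sequence, Tuple


class MiniBaseLoader(ABC):
    """Stand-in for BaseLoader: stores the documented constructor arguments."""

    def __init__(self, batch_size: int = 32, shuffle: bool = True,
                 seed: Optional[int] = None, num_workers: int = 0,
                 prefetch_factor: int = 2) -> None:
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.num_workers = num_workers          # stored but unused in this sketch
        self.prefetch_factor = prefetch_factor  # stored but unused in this sketch

    @abstractmethod
    def __len__(self) -> int: ...

    @abstractmethod
    def __iter__(self) -> Iterator[Any]: ...


class InMemoryLoader(MiniBaseLoader):
    """Hypothetical subclass that batches in-memory (data, labels) pairs."""

    def __init__(self, data: Sequence[Any], labels: Sequence[Any], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.data, self.labels = list(data), list(labels)

    def __len__(self) -> int:
        # Keeping the last partial batch gives ceil(n / batch_size) batches.
        return math.ceil(len(self.data) / self.batch_size)

    def __iter__(self) -> Iterator[Tuple[List[Any], List[Any]]]:
        order = list(range(len(self.data)))
        if self.shuffle:
            random.Random(self.seed).shuffle(order)  # same seed, same order
        for i in range(0, len(order), self.batch_size):
            idx = order[i:i + self.batch_size]
            yield [self.data[j] for j in idx], [self.labels[j] for j in idx]


loader = InMemoryLoader(list(range(10)), [j % 2 for j in range(10)],
                        batch_size=4, shuffle=False)
print(len(loader))                  # 3
print([len(x) for x, _ in loader])  # [4, 4, 2]
```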
Memory Management
MemoryManager
- class jax_dataloader.memory.MemoryManager(max_memory=None)
Bases: object
Manages memory allocation and deallocation.
Examples
Basic memory management:
manager = MemoryManager(max_memory=1024**3)  # 1 GiB
Advanced memory management with monitoring:
manager = MemoryManager(max_memory=1024**3)
stats = manager.monitor(interval=1.0)
print(f"Memory usage: {stats['current_usage']}")
Methods
- __init__(max_memory=None)
Initialize the memory manager.
- Parameters:
max_memory (Optional[float]) – Maximum memory to allocate, in bytes
- allocate(size)
Allocate memory.
- Parameters:
size (int) – Size to allocate, in bytes
- Return type:
bool
- Returns:
True if the allocation was successful
- free(size=None)
Free memory.
- Parameters:
size (Optional[int]) – Optional size to free, in bytes
- get_usage()
Get current memory usage statistics.
- Return type:
Dict[str, Any]
- Returns:
Dictionary containing memory usage statistics
- monitor(interval=1.0)
Monitor memory usage over time.
- Parameters:
interval (float) – Time interval between updates, in seconds
- Return type:
Dict[str, Any]
- Returns:
Dictionary containing runtime statistics
- deallocate(size)
Deallocate memory.
- Parameters:
size (int) – Size to deallocate, in bytes
- cleanup()
Clean up memory and reset statistics.
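Since `allocate(size)` reports success with a boolean, the manager behaves like a byte budget checked against `max_memory`. A minimal sketch of that bookkeeping follows, assuming `max_memory=None` means unlimited and `free()` with no argument releases everything. `BudgetTracker` is hypothetical: it only counts bytes and never allocates real memory.

```python
from typing import Any, Dict, Optional


class BudgetTracker:
    """Hypothetical byte-budget tracker mirroring allocate/free/get_usage."""

    def __init__(self, max_memory: Optional[float] = None) -> None:
        self.max_memory = max_memory  # None is assumed to mean "no limit"
        self.current = 0
        self.peak = 0

    def allocate(self, size: int) -> bool:
        # Refuse (rather than raise) when the request would exceed the budget.
        if self.max_memory is not None and self.current + size > self.max_memory:
            return False
        self.current += size
        self.peak = max(self.peak, self.current)
        return True

    def free(self, size: Optional[int] = None) -> None:
        # Assumed semantics: None releases everything; usage never drops below zero.
        self.current = 0 if size is None else max(0, self.current - size)

    def get_usage(self) -> Dict[str, Any]:
        return {"current_usage": self.current, "peak_usage": self.peak,
                "max_memory": self.max_memory}


manager = BudgetTracker(max_memory=1024**3)  # 1 GiB budget
print(manager.allocate(512 * 1024**2))  # True
print(manager.allocate(768 * 1024**2))  # False: 1.25 GiB would exceed the budget
manager.free()
print(manager.get_usage()["current_usage"])  # 0
```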
Cache
- class jax_dataloader.memory.Cache(max_size, eviction_policy='lru', track_stats=True, max_age=None)
Bases: object
Cache for storing data in memory.
Examples
Basic caching:
cache = Cache(
    max_size=1000,
    eviction_policy="lru",
)
Advanced caching with statistics:
cache = Cache(
    max_size=1000,
    eviction_policy="lru",
    track_stats=True,
    max_age=3600,  # 1 hour
)
Methods
- __init__(max_size, eviction_policy='lru', track_stats=True, max_age=None)
Initialize the cache.
- Parameters:
max_size (int) – Maximum size of the cache, in bytes
eviction_policy (str) – Cache eviction policy ("lru" or "fifo")
track_stats (bool) – Whether to track cache statistics
max_age (Optional[float]) – Maximum age of cached items, in seconds
- get(key)
Get a value from the cache.
- Parameters:
key (str) – Cache key
- Return type:
Optional[Any]
- Returns:
Cached value if found, None otherwise
- put(key, value)
Put a value in the cache.
- Parameters:
key (str) – Cache key
value (Any) – Value to cache
- get_stats()
Get cache statistics.
- Return type:
Dict[str, Any]
- Returns:
Dictionary containing cache statistics
- clear()
Clear the cache.
- evict(key)
Evict a specific key from the cache.
- Parameters:
key (str) – Key to evict
- _estimate_size(value)
Estimate the size of a value in bytes.
- Parameters:
value (Any) – Value whose size to estimate
- Return type:
int
- Returns:
Estimated size in bytes
- _evict()
Evict an item based on the eviction policy.
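The two documented eviction policies differ only in what counts as "oldest": with "lru" a read moves an entry to the back of the queue, while with "fifo" only insertion order matters. The sketch shows both, plus `max_age` expiry, using `collections.OrderedDict`. `SizedCache` is hypothetical, not the library's class; it approximates `_estimate_size` with `sys.getsizeof`, which measures only the outer object, and takes an injectable `clock` (an addition for testability).

```python
import sys
import time
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional


class SizedCache:
    """Hypothetical byte-bounded cache with "lru"/"fifo" eviction and max_age expiry."""

    def __init__(self, max_size: int, eviction_policy: str = "lru",
                 max_age: Optional[float] = None,
                 clock: Callable[[], float] = time.monotonic) -> None:
        if eviction_policy not in ("lru", "fifo"):
            raise ValueError("eviction_policy must be 'lru' or 'fifo'")
        self.max_size = max_size
        self.policy = eviction_policy
        self.max_age = max_age
        self.clock = clock
        self._items = OrderedDict()  # key -> (value, size_in_bytes, inserted_at)
        self._bytes = 0
        self.hits = 0
        self.misses = 0

    def _estimate_size(self, value: Any) -> int:
        return sys.getsizeof(value)  # shallow size: nested objects are not counted

    def evict(self, key: str) -> None:
        _, size, _ = self._items.pop(key)
        self._bytes -= size

    def get(self, key: str) -> Optional[Any]:
        entry = self._items.get(key)
        expired = (entry is not None and self.max_age is not None
                   and self.clock() - entry[2] > self.max_age)
        if expired:
            self.evict(key)
            entry = None
        if entry is None:
            self.misses += 1
            return None
        self.hits += 1
        if self.policy == "lru":
            self._items.move_to_end(key)  # mark as most recently used
        return entry[0]

    def put(self, key: str, value: Any) -> None:
        if key in self._items:
            self.evict(key)
        size = self._estimate_size(value)
        self._items[key] = (value, size, self.clock())
        self._bytes += size
        # The front of the OrderedDict is the least recently used entry ("lru")
        # or the oldest insertion ("fifo"); evict from there but keep the newest item.
        while self._bytes > self.max_size and len(self._items) > 1:
            self.evict(next(iter(self._items)))

    def get_stats(self) -> Dict[str, Any]:
        return {"hits": self.hits, "misses": self.misses,
                "size_bytes": self._bytes, "items": len(self._items)}


# Usage: a fake clock makes max_age expiry easy to demonstrate.
now = [0.0]
cache = SizedCache(max_size=10_000, max_age=3600, clock=lambda: now[0])
cache.put("batch_0", b"\x00" * 64)
now[0] = 3601.0
print(cache.get("batch_0"))  # None: the entry is older than one hour
```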
Progress Tracking
ProgressTracker
- class jax_dataloader.progress.ProgressTracker(total, desc=None, unit='it', leave=True, update_interval=0.1, callbacks=None, show_eta=True)
Bases: object
Tracks progress of data loading.
Examples
Basic progress tracking:
tracker = ProgressTracker(
    total=1000,
    update_interval=0.1,
)
Advanced progress tracking with callbacks:
def on_update(progress):
    print(f"Progress: {progress:.1%}")
tracker = ProgressTracker(
    total=1000,
    update_interval=0.1,
    callbacks=[on_update],
    show_eta=True,
)
Methods
- __init__(total, desc=None, unit='it', leave=True, update_interval=0.1, callbacks=None, show_eta=True)
Initialize the progress tracker.
- Parameters:
total (int) – Total number of items
desc (Optional[str]) – Description of the progress
unit (str) – Unit of progress
leave (bool) – Whether to leave the progress bar on screen after completion
update_interval (float) – Time interval between updates, in seconds
callbacks (Optional[List[Callable[[float], None]]]) – List of callback functions to call on updates
show_eta (bool) – Whether to show the estimated time remaining
- get_progress()
Get the current progress as a fraction.
- Return type:
float
- Returns:
Progress as a fraction between 0 and 1
- get_eta()
Get the estimated time remaining.
- Return type:
float
- Returns:
Estimated time remaining in seconds
- update(n=1)
Update the progress.
- Parameters:
n (int) – Number of items to add to the progress count
- reset()
Reset the progress tracker.
- close()
Close the progress bar.
- __enter__()
Context manager entry.
- __exit__(exc_type, exc_val, exc_tb)
Context manager exit.
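A common way to compute `get_eta()`, and the one assumed here, extrapolates the average rate so far: after `n` of `total` items in `t` seconds, the remaining time is `t * (total - n) / n`. The hypothetical `SimpleTracker` below implements that estimate, callbacks throttled by `update_interval` (always firing on completion), and context-manager support. The injectable `clock` is an addition for testability; this is not the library's tracker.

```python
import time
from typing import Callable, List, Optional


class SimpleTracker:
    """Hypothetical progress tracker: fraction done, linear ETA, throttled callbacks."""

    def __init__(self, total: int, update_interval: float = 0.1,
                 callbacks: Optional[List[Callable[[float], None]]] = None,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.total = total
        self.update_interval = update_interval
        self.callbacks = callbacks or []
        self.clock = clock
        self.reset()

    def reset(self) -> None:
        self.n = 0
        self.start = self.clock()
        self._last_emit = float("-inf")

    def update(self, n: int = 1) -> None:
        self.n = min(self.total, self.n + n)
        now = self.clock()
        # Notify at most once per update_interval, and always on completion.
        if now - self._last_emit >= self.update_interval or self.n == self.total:
            self._last_emit = now
            for cb in self.callbacks:
                cb(self.get_progress())

    def get_progress(self) -> float:
        return self.n / self.total if self.total else 1.0

    def get_eta(self) -> float:
        if self.n == 0:
            return float("inf")  # no rate estimate yet
        elapsed = self.clock() - self.start
        return elapsed * (self.total - self.n) / self.n

    def __enter__(self) -> "SimpleTracker":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.callbacks = []  # stand-in for closing the progress bar


now = [0.0]
seen: List[float] = []
with SimpleTracker(total=1000, callbacks=[seen.append], clock=lambda: now[0]) as tracker:
    now[0] = 5.0
    tracker.update(250)
    print(f"{tracker.get_progress():.1%}, ETA {tracker.get_eta():.1f}s")  # 25.0%, ETA 15.0s
```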
Data Augmentation
Transform
- class jax_dataloader.transform.Transform
Bases: object
Base class for data transformations.
Examples
Basic transformation:
transform = Transform()
transform.add(lambda x: x * 2)
Advanced transformation with chaining:
import jax.numpy as jnp
transform = Transform()
transform.add(lambda x: x * 2)
transform.add(lambda x: x + 1)
transform.add(lambda x: jnp.clip(x, 0, 1))
Methods
- add(transform)
Add a transform function.
- Parameters:
transform (Callable) – Transform function to add
- apply(data)
Apply all transforms to the data.
- Parameters:
data (Any) – Data to transform
- Return type:
Any
- Returns:
Transformed data
- compose(other)
Compose this transform with another transform.
- Parameters:
other (Transform) – Transform to compose with
- Return type:
Transform
- Returns:
New composed transform
- chain(*transforms)
Chain multiple transforms together.
- Parameters:
*transforms (Callable) – Transforms to chain
- Return type:
Transform
- Returns:
New transform with all transforms chained
- __init__()
Initialize the transform.
- __call__(data)
Apply all transforms to the data.
- Parameters:
data (Any) – Data to transform
- Return type:
Any
- Returns:
Transformed data
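If `apply` runs the added functions in insertion order (an assumption; the reference only says it applies all transforms), the chained example computes `clip(2 * x + 1, 0, 1)`. The hypothetical `Pipeline` below sketches those semantics, further assuming that `compose(other)` runs this transform's functions before `other`'s and that `compose` and `chain` return new objects without mutating their inputs. A pure-Python scalar clip stands in for `jnp.clip`.

```python
from typing import Any, Callable, List


class Pipeline:
    """Hypothetical stand-in for Transform: applies functions in insertion order."""

    def __init__(self) -> None:
        self.transforms: List[Callable[[Any], Any]] = []

    def add(self, transform: Callable[[Any], Any]) -> None:
        self.transforms.append(transform)

    def apply(self, data: Any) -> Any:
        for fn in self.transforms:
            data = fn(data)  # each function receives the previous one's output
        return data

    __call__ = apply  # calling the pipeline is the same as apply()

    def compose(self, other: "Pipeline") -> "Pipeline":
        # Assumed order: this pipeline first, then `other`; neither input is mutated.
        new = Pipeline()
        new.transforms = self.transforms + other.transforms
        return new

    def chain(self, *transforms: Callable[[Any], Any]) -> "Pipeline":
        new = Pipeline()
        new.transforms = self.transforms + list(transforms)
        return new


t = Pipeline()
t.add(lambda x: x * 2)
t.add(lambda x: x + 1)
t.add(lambda x: min(max(x, 0), 1))  # scalar clip standing in for jnp.clip(x, 0, 1)
print(t(-1), t(0.2), t(3))  # 0 1 1
```

Because order matters, `t.compose(u)` and `u.compose(t)` generally give different results.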
Exceptions
DataLoaderError
- exception jax_dataloader.exceptions.DataLoaderError
Bases: Exception
Base exception for data loader errors.
ConfigurationError
- exception jax_dataloader.exceptions.ConfigurationError
Bases: DataLoaderError
Exception raised for configuration errors.
MemoryError
- exception jax_dataloader.exceptions.MemoryError
Bases: DataLoaderError
Exception raised for memory-related errors.
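Because `ConfigurationError` and `MemoryError` both derive from `DataLoaderError`, one `except DataLoaderError` clause catches either. Per the documented bases, `jax_dataloader.exceptions.MemoryError` is not a subclass of Python's built-in `MemoryError`, and importing it unqualified shadows the built-in name. The sketch reproduces the hierarchy with local stand-in classes so it runs without the package; `check_batch_size` is a hypothetical validator.

```python
import builtins


class DataLoaderError(Exception):
    """Stand-in for jax_dataloader.exceptions.DataLoaderError."""


class ConfigurationError(DataLoaderError):
    """Stand-in for jax_dataloader.exceptions.ConfigurationError."""


class MemoryError(DataLoaderError):  # shadows builtins.MemoryError in this module
    """Stand-in for jax_dataloader.exceptions.MemoryError."""


def check_batch_size(batch_size: int) -> int:
    """Hypothetical validator that raises ConfigurationError on bad input."""
    if batch_size <= 0:
        raise ConfigurationError(f"batch_size must be positive, got {batch_size}")
    return batch_size


try:
    check_batch_size(0)
except DataLoaderError as err:  # would also catch the package's MemoryError
    print(type(err).__name__, "-", err)  # ConfigurationError - batch_size must be positive, got 0

# The package-style MemoryError is unrelated to the built-in one.
print(issubclass(MemoryError, builtins.MemoryError))  # False
```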
Utility Functions
- jax_dataloader.utils.format_size(size)
Format a size in bytes as a human-readable string.
- Parameters:
size (Union[int, float]) – Size in bytes
- Return type:
str
- Returns:
Human-readable size string
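One way to produce such a string is to divide by 1024 until the value drops below one unit step. The units and precision `format_size` actually uses are not documented here, so `human_size` is a hypothetical stand-in that assumes 1024-based multiples (labelled KB, MB, and so on, strictly KiB and MiB) and two decimal places.

```python
from typing import Union


def human_size(size: Union[int, float]) -> str:
    """Hypothetical: format a byte count using 1024-based units."""
    units = ["B", "KB", "MB", "GB", "TB", "PB"]
    value = float(size)
    # Step up one unit while the value is still at least 1024 of the current unit.
    for unit in units[:-1]:
        if abs(value) < 1024:
            return f"{value:.2f} {unit}"
        value /= 1024
    return f"{value:.2f} {units[-1]}"  # largest unit: no further scaling


print(human_size(512))      # 512.00 B
print(human_size(1536))     # 1.50 KB
print(human_size(1024**3))  # 1.00 GB
```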