Usage Guide
This guide covers the basic and advanced usage of JAX DataLoader.
Basic Usage
Loading Data
The simplest way to load data is to pass a DataLoaderConfig to the DataLoader class:
from jax_dataloader import DataLoader, DataLoaderConfig
# Create a configuration
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    batch_size=32,
    shuffle=True
)
# Create a data loader
loader = DataLoader(config)
# Iterate over batches
for batch_data, batch_labels in loader:
    # Process your batch here
    pass
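To make the roles of batch_size and shuffle in the loop above concrete, here is a minimal plain-Python sketch of shuffled batching. It is not the library's implementation; the iterate_batches helper and its seed parameter are hypothetical names used only for illustration.

```python
import random


def iterate_batches(features, labels, batch_size, shuffle=False, seed=0):
    """Yield (batch_features, batch_labels) pairs from two parallel lists."""
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    indices = list(range(len(features)))
    if shuffle:
        # A seeded generator keeps the shuffled order reproducible.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        # Index both lists with the same chunk so pairs stay aligned.
        yield [features[i] for i in chunk], [labels[i] for i in chunk]


features = [[float(i), float(i) * 2.0] for i in range(10)]
labels = [i % 2 for i in range(10)]
batches = list(iterate_batches(features, labels, batch_size=4, shuffle=True))
print([len(batch_features) for batch_features, _ in batches])  # [4, 4, 2]
```

The last batch is smaller because 10 examples do not divide evenly into batches of 4. Whether JAX DataLoader keeps or drops such a partial batch is not covered in this guide.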
Data Formats
JAX DataLoader supports CSV, JSON, and image data. Set loader_type and the format-specific options for each:
CSV Data
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    target_column="label",
    batch_size=32
)
JSON Data
config = DataLoaderConfig(
    data_path="data/train.json",
    loader_type="json",
    target_key="label",
    batch_size=32
)
Image Data
config = DataLoaderConfig(
    data_path="data/images/",
    loader_type="image",
    image_size=(224, 224),
    batch_size=32
)
Advanced Usage
Data Transformation
You can apply transformations to your data using the Transform class:
from jax_dataloader import Transform
# Create transformations
transform = Transform()
transform.add("random_flip", probability=0.5)
transform.add("random_rotation", max_angle=30)
# Add to configuration
config = DataLoaderConfig(
    data_path="data/images/",
    loader_type="image",
    transform=transform,
    batch_size=32
)
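As an illustration of what a probabilistic transformation such as random_flip does, the sketch below applies a horizontal flip to a small row-major image with a given probability. This is a plain-Python approximation under assumed semantics, not the Transform class itself; random_flip and apply_pipeline here are hypothetical helpers.

```python
import random


def random_flip(image, rng, probability=0.5):
    """Flip a row-major image horizontally with the given probability."""
    if rng.random() < probability:
        return [row[::-1] for row in image]
    return image


def apply_pipeline(image, steps, rng):
    """Apply each (function, keyword-arguments) step in order."""
    for func, kwargs in steps:
        image = func(image, rng, **kwargs)
    return image


image = [[1, 2, 3], [4, 5, 6]]
rng = random.Random(0)
always = apply_pipeline(image, [(random_flip, {"probability": 1.0})], rng)
never = apply_pipeline(image, [(random_flip, {"probability": 0.0})], rng)
print(always)  # [[3, 2, 1], [6, 5, 4]]
print(never)   # [[1, 2, 3], [4, 5, 6]]
```

Passing the random generator explicitly keeps augmentation reproducible, which helps when debugging a training run.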
Memory Management
Use MemoryManager to set a memory limit and monitor memory usage:
from jax_dataloader.memory import MemoryManager
# Create memory manager
memory_manager = MemoryManager(max_memory=1024**3)  # 1 GiB
# Monitor memory usage
stats = memory_manager.monitor(interval=1.0)
print(f"Memory usage: {stats['current_usage']}")
Caching
You can use caching to improve performance:
from jax_dataloader.memory import Cache
# Create cache
cache = Cache(
    max_size=1000,
    eviction_policy="lru"
)
# Configure caching in the data loader
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    batch_size=32,
    cache_size="1GB",
    cache_policy="lru"
)
loader = DataLoader(config)
num_epochs = 10  # Example value; set this for your training run
# Cache will automatically manage frequently accessed data
for epoch in range(num_epochs):
    for batch_data, batch_labels in loader:
        # Process your batch here
        pass
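The "lru" policy named above stands for least recently used: when the cache is full, the entry that has gone longest without being read or written is evicted first. The following is a minimal sketch of that policy built on the standard library's OrderedDict. It is not the library's Cache class; LRUCache and its methods are hypothetical.

```python
from collections import OrderedDict


class LRUCache:
    """A small least-recently-used cache with a fixed entry limit."""

    def __init__(self, max_size):
        if max_size <= 0:
            raise ValueError("max_size must be positive")
        self.max_size = max_size
        self._items = OrderedDict()

    def get(self, key, default=None):
        if key not in self._items:
            return default
        # Reading an entry marks it as the most recently used.
        self._items.move_to_end(key)
        return self._items[key]

    def put(self, key, value):
        self._items[key] = value
        self._items.move_to_end(key)
        if len(self._items) > self.max_size:
            # Drop the entry at the front, which is the least recently used.
            self._items.popitem(last=False)

    def __contains__(self, key):
        return key in self._items

    def __len__(self):
        return len(self._items)


cache = LRUCache(max_size=2)
cache.put("batch_0", [1, 2])
cache.put("batch_1", [3, 4])
cache.get("batch_0")          # batch_0 is now the most recently used
cache.put("batch_2", [5, 6])  # evicts batch_1
print("batch_1" in cache)     # False
```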
Progress Tracking
Track the progress of data loading:
from jax_dataloader.progress import ProgressTracker
# Create progress tracker
tracker = ProgressTracker(
    total=1000,
    update_interval=0.1
)
# Enable progress reporting in the data loader
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    batch_size=32,
    show_progress=True
)
loader = DataLoader(config)
# Get progress information
progress = loader.get_progress()
print(f"Progress: {progress['progress']:.2%}")
print(f"ETA: {progress['eta']:.2f} seconds")
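For intuition about the progress and eta values printed above, here is one common way to derive them from a count of completed items and the elapsed time. How the library computes them is not specified in this guide, so treat progress_report as a hypothetical helper that assumes a constant processing rate.

```python
def progress_report(completed, total, elapsed_seconds):
    """Return the completed fraction and a linear estimate of time remaining."""
    if total <= 0:
        raise ValueError("total must be positive")
    fraction = completed / total
    if completed == 0:
        # No rate can be estimated before the first item finishes.
        eta = float("inf")
    else:
        # Assume the remaining items take as long per item as those done so far.
        eta = elapsed_seconds * (total - completed) / completed
    return {"progress": fraction, "eta": eta}


report = progress_report(completed=250, total=1000, elapsed_seconds=5.0)
print(f"Progress: {report['progress']:.2%}")  # Progress: 25.00%
print(f"ETA: {report['eta']:.2f} seconds")    # ETA: 15.00 seconds
```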
Multi-GPU Support
For multi-GPU training, scale the batch size by the number of devices and parallelize the training step with jax.pmap:
import jax
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    batch_size=32 * jax.device_count(),  # Scale batch size by number of devices
    shuffle=True
)
loader = DataLoader(config)
# Your training function
@jax.pmap
def train_step(params, batch):
    # Your training logic here
    pass
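jax.pmap maps a function over the leading axis of its inputs, so a global batch generally has to be split into one shard per device before it reaches train_step. Whether the loader performs this split is not stated here. The sketch below shows the split in plain Python with a hypothetical shard_batch helper; with arrays, the equivalent step is a reshape to (num_devices, per_device_batch, ...).

```python
def shard_batch(batch, num_devices):
    """Split a list of examples into num_devices equal, contiguous shards."""
    if num_devices <= 0:
        raise ValueError("num_devices must be positive")
    if len(batch) % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    per_device = len(batch) // num_devices
    return [
        batch[i * per_device:(i + 1) * per_device]
        for i in range(num_devices)
    ]


global_batch = list(range(8))  # stands in for 8 examples
shards = shard_batch(global_batch, num_devices=4)
print(shards)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Scaling batch_size by jax.device_count(), as in the configuration above, is what keeps the global batch divisible by the device count.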
Best Practices
Batch Size Selection
- Start with a small batch size and increase it based on available memory.
- Use the calculate_batch_size utility function to choose a batch size.
Memory Management
- Monitor memory usage with MemoryManager.
- Use caching for frequently accessed data.
- Enable memory optimization features when needed.
Performance Optimization
- Use an appropriate number of workers.
- Enable prefetching for better performance.
- Use caching for repeated data access.
Error Handling
- Always check for data format compatibility.
- Handle memory errors gracefully.
- Use try-except blocks for data loading operations.
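The batch size advice above comes down to simple arithmetic: divide the memory you are willing to spend on a batch by the size of one example. The sketch below shows that calculation. It is not the library's calculate_batch_size function, whose signature is not given in this guide; estimate_batch_size, safety_factor, and max_batch_size are hypothetical names.

```python
def estimate_batch_size(memory_budget_bytes, bytes_per_example,
                        safety_factor=0.5, max_batch_size=1024):
    """Estimate how many examples fit in a fraction of the memory budget."""
    if bytes_per_example <= 0:
        raise ValueError("bytes_per_example must be positive")
    # Reserve part of the budget for parameters, activations, and overhead.
    usable_bytes = int(memory_budget_bytes * safety_factor)
    size = usable_bytes // bytes_per_example
    # Always return at least 1 and never exceed the configured ceiling.
    return max(1, min(size, max_batch_size))


# One 224x224 RGB image stored as float32 (4 bytes per value).
bytes_per_image = 224 * 224 * 3 * 4
batch_size = estimate_batch_size(1024**3, bytes_per_image)
print(batch_size)  # 891
```

The estimate is only a starting point: actual memory use during training also depends on the model, so it is still worth starting small and increasing gradually.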
For more examples and detailed information, see the Examples section.