Tutorials
This section provides step-by-step tutorials for common use cases with JAX DataLoader.
Getting Started with Image Classification
In this tutorial, we’ll create a complete image classification pipeline using JAX DataLoader.
First, let’s set up our environment:
import jax
import jax.numpy as jnp
from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.data import ImageLoader
Create a data loader for our image dataset:
# Create image loader
loader = ImageLoader(
    "path/to/image/dataset",
    image_size=(224, 224),
    normalize=True,
    augment=True  # Enable built-in augmentations
)
# Configure the dataloader
config = DataLoaderConfig(
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
# Create the dataloader
dataloader = DataLoader(
    loader=loader,
    config=config
)
Define a simple model:
def model(params, x):
    # Simple MLP; flatten each image from (224, 224, 3) to a vector first
    x = x.reshape(x.shape[0], -1)
    x = jax.nn.relu(jnp.dot(x, params['w1']) + params['b1'])
    x = jax.nn.relu(jnp.dot(x, params['w2']) + params['b2'])
    return jnp.dot(x, params['w3']) + params['b3']
Training loop:
# Initialize parameters (small random values keep the first activations stable)
params = {
    'w1': jax.random.normal(jax.random.PRNGKey(0), (224*224*3, 128)) * 0.01,
    'b1': jnp.zeros(128),
    'w2': jax.random.normal(jax.random.PRNGKey(1), (128, 64)) * 0.01,
    'b2': jnp.zeros(64),
    'w3': jax.random.normal(jax.random.PRNGKey(2), (64, 10)) * 0.01,
    'b3': jnp.zeros(10)
}
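
# Optional sanity check before training: confirm output shapes on a dummy
# batch (assumes 224x224 RGB input and the flattening model defined above)
dummy_batch = jnp.ones((32, 224, 224, 3))
print(model(params, dummy_batch).shape)  # Expected: (32, 10)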
# Training loop
num_epochs = 10
learning_rate = 1e-3

def loss_fn(p, x, y):
    return jnp.mean((model(p, x) - y) ** 2)

for epoch in range(num_epochs):
    for batch_data, batch_labels in dataloader:
        # Forward pass and gradients in one call
        loss, grads = jax.value_and_grad(loss_fn)(params, batch_data, batch_labels)
        # SGD update
        params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
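For better throughput, the update step can be compiled once with jax.jit. A minimal sketch, reusing loss_fn and learning_rate from above:

@jax.jit
def train_step(params, x, y):
    # Compiled forward/backward pass plus SGD update
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss

for epoch in range(num_epochs):
    for batch_data, batch_labels in dataloader:
        params, loss = train_step(params, batch_data, batch_labels)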
Large-Scale Data Processing
This tutorial demonstrates how to handle large datasets efficiently.
Set up memory-efficient data loading:
from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.data import CSVLoader
# Create CSV loader with chunking
loader = CSVLoader(
    "large_dataset.csv",
    chunk_size=10000,  # Process in chunks
    target_column="target"
)
# Configure for memory efficiency
config = DataLoaderConfig(
    batch_size=32,
    memory_fraction=0.8,
    auto_batch_size=True,
    cache_size=1000,
    num_workers=4
)
dataloader = DataLoader(
    loader=loader,
    config=config
)
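For intuition, chunked loading keeps only one slice of the file in memory at a time, much like pandas' chunksize iteration. A rough standalone sketch of the same idea (using pandas directly, not JAX DataLoader's API):

import pandas as pd
import jax.numpy as jnp

for chunk in pd.read_csv("large_dataset.csv", chunksize=10000):
    # Only this 10,000-row chunk is resident in memory
    features = jnp.asarray(chunk.drop(columns=["target"]).to_numpy())
    targets = jnp.asarray(chunk["target"].to_numpy())
    # ... slice each chunk into batches of 32 here ...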
Process data in batches:
# Enable memory optimization
dataloader.optimize_memory()

# Process data
for batch_data, batch_labels in dataloader:
    # Process batch
    process_batch(batch_data, batch_labels)
    # Monitor memory usage
    print(f"Memory usage: {dataloader.memory_manager.get_memory_usage()}")
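Here process_batch stands in for your own per-batch logic. As a hypothetical example, a streaming computation that never materializes the full dataset might look like:

running_sum = 0.0
running_count = 0

def process_batch(batch_data, batch_labels):
    # Hypothetical placeholder: accumulate a running mean across batches
    global running_sum, running_count
    running_sum += float(jnp.sum(batch_data))
    running_count += batch_data.size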
Multi-GPU Training
Learn how to distribute training across multiple GPUs.
Set up multi-GPU configuration:
import jax
import jax.numpy as jnp
from jax_dataloader import DataLoader, DataLoaderConfig

# Get available devices
devices = jax.devices()

# Create sample data
data = jnp.arange(10000)
labels = jnp.arange(10000)
# Configure for multi-GPU
config = DataLoaderConfig(
    batch_size=32,
    num_devices=len(devices),
    device_map="auto",
    pin_memory=True
)
dataloader = DataLoader(
    data=data,
    labels=labels,
    config=config
)
Implement distributed training:
# Training loop
for batch_data, batch_labels in dataloader:
    # batch_data and batch_labels are already on the correct devices
    data, device_id = batch_data
    # Make the matching device the default for this step
    with jax.default_device(devices[device_id]):
        # Your training code here
        train_step(data, batch_labels)
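A more idiomatic JAX pattern for data parallelism is jax.pmap, which replicates a function across all devices and runs it on per-device shards. A minimal, self-contained sketch (assuming the global batch size divides evenly by the device count):

import jax
import jax.numpy as jnp

n_dev = jax.device_count()

@jax.pmap
def parallel_loss(x, y):
    # Each device computes the loss on its own shard
    return jnp.mean((x - y) ** 2)

# Shard a global batch of 32 across devices: shape (n_dev, 32 // n_dev)
x = jnp.arange(32.0).reshape(n_dev, 32 // n_dev)
y = jnp.zeros((n_dev, 32 // n_dev))
print(parallel_loss(x, y))  # One loss value per device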
Custom Data Augmentation
Learn how to create custom data augmentation pipelines.
Define augmentation functions:
import jax.random as random
import jax.numpy as jnp

def custom_augment(batch, key):
    # Split key for multiple augmentations
    key1, key2, key3 = random.split(key, 3)

    # Add Gaussian noise
    noise = random.normal(key1, batch.shape) * 0.1
    augmented = batch + noise

    # Random quarter-turn rotation of the spatial axes (batch is NHWC)
    k = random.randint(key2, (), 0, 4)
    augmented = jnp.rot90(augmented, k=int(k), axes=(1, 2))

    # Random horizontal flip
    if random.uniform(key3) > 0.5:
        augmented = jnp.flip(augmented, axis=2)

    return augmented
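Note that the Python if above only works eagerly; under jax.jit the comparison becomes a traced value and cannot drive Python control flow. A branchless sketch that stays jit-compatible:

def jit_safe_augment(batch, key):
    key1, key2 = random.split(key)
    augmented = batch + random.normal(key1, batch.shape) * 0.1
    # Select flipped or unflipped output without Python branching
    do_flip = random.uniform(key2) > 0.5
    return jnp.where(do_flip, jnp.flip(augmented, axis=2), augmented)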
Apply custom augmentations:
from jax_dataloader import DataLoader, DataLoaderConfig
# Configure with custom augmentation
config = DataLoaderConfig(
    batch_size=32,
    transform=custom_augment,
    transform_key=random.PRNGKey(0)
)
dataloader = DataLoader(
    data=data,
    config=config
)
Use in training:
for batch_data, batch_labels in dataloader:
    # batch_data is already augmented
    train_step(batch_data, batch_labels)
Advanced Caching Strategies
Learn how to optimize data loading with advanced caching.
Set up caching:
from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.memory import Cache
# Create cache
cache = Cache(
    max_size=1000,         # Maximum number of batches to cache
    eviction_policy="lru"  # Least Recently Used
)
# Configure dataloader with cache
config = DataLoaderConfig(
    batch_size=32,
    cache=cache,
    cache_hits=True  # Track cache hits
)
dataloader = DataLoader(
    data=data,
    config=config
)
Monitor cache performance:
for batch_data, batch_labels in dataloader:
    # Process batch
    process_batch(batch_data, batch_labels)

# Print cache statistics after the run
print(f"Cache hits: {dataloader.cache.hits}")
print(f"Cache misses: {dataloader.cache.misses}")
print(f"Hit rate: {dataloader.cache.hit_rate}")
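With an LRU policy, the first pass over the data mostly populates the cache, so hit rates are more meaningful per epoch. Assuming the hits and misses counters above are cumulative, per-epoch rates can be derived like this:

prev_hits, prev_misses = 0, 0
for epoch in range(3):
    for batch_data, batch_labels in dataloader:
        process_batch(batch_data, batch_labels)
    # Per-epoch deltas from the cumulative counters
    epoch_hits = dataloader.cache.hits - prev_hits
    epoch_misses = dataloader.cache.misses - prev_misses
    prev_hits, prev_misses = dataloader.cache.hits, dataloader.cache.misses
    rate = epoch_hits / max(epoch_hits + epoch_misses, 1)
    print(f"Epoch {epoch}: hit rate = {rate:.2%}")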