Tutorials

This section provides step-by-step tutorials for common use cases with JAX DataLoader.

Getting Started with Image Classification

In this tutorial, we’ll create a complete image classification pipeline using JAX DataLoader.

  1. First, let’s set up our environment:

import jax
import jax.numpy as jnp
from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.data import ImageLoader

  2. Create a data loader for our image dataset:

# Create image loader
loader = ImageLoader(
    "path/to/image/dataset",
    image_size=(224, 224),
    normalize=True,
    augment=True  # Enable built-in augmentations
)

# Configure the dataloader
config = DataLoaderConfig(
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

# Create the dataloader
dataloader = DataLoader(
    loader=loader,
    config=config
)

  3. Define a simple model:

def model(params, x):
    # Simple fully connected network (the weight shapes below assume
    # 224x224 RGB inputs flattened to vectors)
    x = x.reshape((x.shape[0], -1))  # flatten each image
    x = jax.nn.relu(jnp.dot(x, params['w1']) + params['b1'])
    x = jax.nn.relu(jnp.dot(x, params['w2']) + params['b2'])
    return jnp.dot(x, params['w3']) + params['b3']

  4. Training loop:

# Initialize parameters
params = {
    'w1': jax.random.normal(jax.random.PRNGKey(0), (224*224*3, 128)),
    'b1': jnp.zeros(128),
    'w2': jax.random.normal(jax.random.PRNGKey(1), (128, 64)),
    'b2': jnp.zeros(64),
    'w3': jax.random.normal(jax.random.PRNGKey(2), (64, 10)),
    'b3': jnp.zeros(10)
}

# Hyperparameters (example values)
learning_rate = 1e-3
num_epochs = 10

# Training loop
for epoch in range(num_epochs):
    for batch_data, batch_labels in dataloader:
        # Loss as a function of the parameters
        def loss_fn(p):
            predictions = model(p, batch_data)
            targets = jax.nn.one_hot(batch_labels, 10)  # assumes integer class labels
            return jnp.mean((predictions - targets) ** 2)

        # Forward and backward pass in one call
        loss, grads = jax.value_and_grad(loss_fn)(params)

        # Update parameters with plain SGD
        params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
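
For larger models it usually pays to compile the update with jax.jit. A minimal sketch of a jit-compiled training step, reusing the model, params, learning_rate, and num_epochs defined above (the train_step name and structure are illustrative, not part of jax_dataloader):

@jax.jit
def train_step(params, x, y):
    def loss_fn(p):
        targets = jax.nn.one_hot(y, 10)  # assumes integer class labels
        return jnp.mean((model(p, x) - targets) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss

for epoch in range(num_epochs):
    for batch_data, batch_labels in dataloader:
        params, loss = train_step(params, batch_data, batch_labels)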

Large-Scale Data Processing

This tutorial demonstrates how to handle large datasets efficiently.

  1. Set up memory-efficient data loading:

from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.data import CSVLoader

# Create CSV loader with chunking
loader = CSVLoader(
    "large_dataset.csv",
    chunk_size=10000,  # Process in chunks
    target_column="target"
)

# Configure for memory efficiency
config = DataLoaderConfig(
    batch_size=32,
    memory_fraction=0.8,
    auto_batch_size=True,
    cache_size=1000,
    num_workers=4
)

dataloader = DataLoader(
    loader=loader,
    config=config
)

  2. Process data in batches:

# Enable memory optimization
dataloader.optimize_memory()

# Process data
for batch_data, batch_labels in dataloader:
    # Process batch
    process_batch(batch_data, batch_labels)

    # Monitor memory usage
    print(f"Memory usage: {dataloader.memory_manager.get_memory_usage()}")

Multi-GPU Training

Learn how to distribute training across multiple GPUs.

  1. Set up multi-GPU configuration:

import jax
import jax.numpy as jnp
from jax_dataloader import DataLoader, DataLoaderConfig

# Get available devices
devices = jax.devices()

# Create sample data
data = jnp.arange(10000)
labels = jnp.arange(10000)

# Configure for multi-GPU
config = DataLoaderConfig(
    batch_size=32,
    num_devices=len(devices),
    device_map="auto",
    pin_memory=True
)

dataloader = DataLoader(
    data=data,
    labels=labels,
    config=config
)

  2. Implement distributed training:

# Training loop
for batch_data, batch_labels in dataloader:
    # Each batch comes paired with the device it was assigned to
    data, device_id = batch_data

    # Run the training step on that device
    with jax.default_device(devices[device_id]):
        # Your training code here
        train_step(data, batch_labels)
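
If you prefer to drive data parallelism from JAX itself, jax.pmap is the standard approach. A minimal sketch, assuming the loader yields plain array batches (as in the earlier tutorials, without the device_id unpacking above) and that the batch size is divisible by the device count; the toy linear model, train_step, and LEARNING_RATE below are illustrative, not part of jax_dataloader:

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
LEARNING_RATE = 1e-3

def loss_fn(params, x, y):
    # Toy scalar linear model, just to keep the sketch self-contained
    pred = x * params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Average gradients across devices so every replica applies the same update
    grads = jax.lax.pmean(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

p_train_step = jax.pmap(train_step, axis_name="batch")

# Replicate the parameters onto every local device
params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
replicated = jax.device_put_replicated(params, jax.local_devices())

for batch_data, batch_labels in dataloader:
    # Split the global batch into one leading slice per device
    x = batch_data.reshape(n_dev, -1).astype(jnp.float32)
    y = batch_labels.reshape(n_dev, -1).astype(jnp.float32)
    replicated = p_train_step(replicated, x, y)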

Custom Data Augmentation

Learn how to create custom data augmentation pipelines.

  1. Define augmentation functions:

import jax.random as random
import jax.numpy as jnp

def custom_augment(batch, key):
    # Split key for multiple augmentations
    key1, key2, key3 = random.split(key, 3)

    # Add Gaussian noise
    noise = random.normal(key1, batch.shape) * 0.1
    augmented = batch + noise

    # Random 90-degree rotation: k ends up as -1, 0, or 1, and axes (1, 2)
    # rotate height/width while leaving the batch dimension alone
    angle = random.uniform(key2, minval=-0.1, maxval=0.1)
    augmented = jnp.rot90(augmented, k=int(jnp.round(angle * 10)), axes=(1, 2))

    # Random flip; the Python `if` is fine here because the transform runs
    # eagerly (a jit-friendly variant is sketched at the end of this tutorial)
    if random.uniform(key3) > 0.5:
        augmented = jnp.flip(augmented, axis=1)

    return augmented

  2. Apply custom augmentations:

from jax_dataloader import DataLoader, DataLoaderConfig

# Configure with custom augmentation
config = DataLoaderConfig(
    batch_size=32,
    transform=custom_augment,
    transform_key=random.PRNGKey(0)
)

dataloader = DataLoader(
    data=data,
    config=config
)

  3. Use in training:

for batch_data, batch_labels in dataloader:
    # batch_data is already augmented
    train_step(batch_data, batch_labels)
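
The Python if inside custom_augment works because the transform runs eagerly; under jax.jit the comparison would yield a traced value and the branch would fail. A jit-friendly sketch of the noise and flip steps using jnp.where (the 90-degree rotation is omitted because jnp.rot90 needs a concrete k):

import jax
import jax.numpy as jnp
import jax.random as random

@jax.jit
def custom_augment_jit(batch, key):
    key1, key2 = random.split(key)

    # Add Gaussian noise
    augmented = batch + random.normal(key1, batch.shape) * 0.1

    # Flip with 50% probability, selected without Python control flow
    flipped = jnp.flip(augmented, axis=1)
    do_flip = random.uniform(key2) > 0.5
    return jnp.where(do_flip, flipped, augmented)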

Advanced Caching Strategies

Learn how to optimize data loading with advanced caching.

  1. Set up caching:

from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.memory import Cache

# Create cache
cache = Cache(
    max_size=1000,  # Maximum number of batches to cache
    eviction_policy="lru"  # Least Recently Used
)

# Configure dataloader with cache
config = DataLoaderConfig(
    batch_size=32,
    cache=cache,
    cache_hits=True  # Track cache hits
)

dataloader = DataLoader(
    data=data,
    config=config
)

  2. Monitor cache performance:

for batch_data, batch_labels in dataloader:
    # Process batch
    process_batch(batch_data, batch_labels)

    # Print cache statistics
    print(f"Cache hits: {dataloader.cache.hits}")
    print(f"Cache misses: {dataloader.cache.misses}")
    print(f"Hit rate: {dataloader.cache.hit_rate}")