Tutorials
=========

This section provides step-by-step tutorials for common use cases with JAX DataLoader.

Getting Started with Image Classification
------------------------------------------

In this tutorial, we'll build a complete image classification pipeline using JAX DataLoader.

1. First, set up the environment:

   .. code-block:: python

      import jax
      import jax.numpy as jnp

      from jax_dataloader import DataLoader, DataLoaderConfig
      from jax_dataloader.data import ImageLoader

2. Create a data loader for the image dataset:

   .. code-block:: python

      # Create the image loader
      loader = ImageLoader(
          "path/to/image/dataset",
          image_size=(224, 224),
          normalize=True,
          augment=True  # Enable built-in augmentations
      )

      # Configure the dataloader
      config = DataLoaderConfig(
          batch_size=32,
          shuffle=True,
          num_workers=4,
          pin_memory=True
      )

      # Create the dataloader
      dataloader = DataLoader(
          loader=loader,
          config=config
      )

3. Define a simple model:

   .. code-block:: python

      def model(params, x):
          # Simple fully-connected network
          x = x.reshape(x.shape[0], -1)  # Flatten each image to a vector
          x = jax.nn.relu(jnp.dot(x, params['w1']) + params['b1'])
          x = jax.nn.relu(jnp.dot(x, params['w2']) + params['b2'])
          return jnp.dot(x, params['w3']) + params['b3']

4. Write the training loop:

   .. code-block:: python

      # Hyperparameters
      num_epochs = 10
      learning_rate = 1e-3

      # Initialize parameters
      params = {
          'w1': jax.random.normal(jax.random.PRNGKey(0), (224 * 224 * 3, 128)),
          'b1': jnp.zeros(128),
          'w2': jax.random.normal(jax.random.PRNGKey(1), (128, 64)),
          'b2': jnp.zeros(64),
          'w3': jax.random.normal(jax.random.PRNGKey(2), (64, 10)),
          'b3': jnp.zeros(10)
      }

      # Training loop
      for epoch in range(num_epochs):
          for batch_data, batch_labels in dataloader:
              # Forward pass
              predictions = model(params, batch_data)

              # Compute loss (assumes one-hot labels of shape (batch_size, 10))
              loss = jnp.mean((predictions - batch_labels) ** 2)

              # Backward pass (using JAX's grad)
              grads = jax.grad(
                  lambda p: jnp.mean((model(p, batch_data) - batch_labels) ** 2)
              )(params)

              # Update parameters
              params = jax.tree_util.tree_map(
                  lambda p, g: p - learning_rate * g, params, grads
              )

Large-Scale Data Processing
---------------------------

This tutorial demonstrates how to handle large datasets efficiently.

1. Set up memory-efficient data loading:

   .. code-block:: python

      from jax_dataloader import DataLoader, DataLoaderConfig
      from jax_dataloader.data import CSVLoader

      # Create a CSV loader that reads the file in chunks
      loader = CSVLoader(
          "large_dataset.csv",
          chunk_size=10000,  # Process in chunks
          target_column="target"
      )

      # Configure for memory efficiency
      config = DataLoaderConfig(
          batch_size=32,
          memory_fraction=0.8,
          auto_batch_size=True,
          cache_size=1000,
          num_workers=4
      )

      dataloader = DataLoader(
          loader=loader,
          config=config
      )

2. Process data in batches:

   .. code-block:: python

      # Enable memory optimization
      dataloader.optimize_memory()

      # Process data
      for batch_data, batch_labels in dataloader:
          # Process the batch
          process_batch(batch_data, batch_labels)

          # Monitor memory usage
          print(f"Memory usage: {dataloader.memory_manager.get_memory_usage()}")

Multi-GPU Training
------------------

Learn how to distribute training across multiple GPUs.

1. Set up the multi-GPU configuration:

   .. code-block:: python

      import jax
      import jax.numpy as jnp

      from jax_dataloader import DataLoader, DataLoaderConfig

      # Get available devices
      devices = jax.devices()

      # Create sample data
      data = jnp.arange(10000)
      labels = jnp.arange(10000)

      # Configure for multi-GPU
      config = DataLoaderConfig(
          batch_size=32,
          num_devices=len(devices),
          device_map="auto",
          pin_memory=True
      )

      dataloader = DataLoader(
          data=data,
          labels=labels,
          config=config
      )

2. Implement distributed training (a ``jax.pmap``-based alternative is sketched after this step):

   .. code-block:: python

      # Training loop
      for batch_data, batch_labels in dataloader:
          # Each batch is paired with the device it was assigned to
          data, device_id = batch_data

          # Run the training step on that device
          with jax.default_device(devices[device_id]):
              # Your training code here
              train_step(data, batch_labels)
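As an alternative to dispatching batches to one device at a time, plain JAX can run the same step on every device in parallel with ``jax.pmap``. The following is a minimal, self-contained sketch rather than part of the JAX DataLoader API: it assumes the batch can be reshaped on the host so that axis 0 indexes devices, and ``mse_step`` is a hypothetical stand-in for a real training step.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   num_devices = len(jax.devices())

   @jax.pmap
   def mse_step(data, labels):
       # Toy per-device computation; replace with your model's loss/update.
       return jnp.mean((data - labels) ** 2)

   # Toy host-side batch, reshaped so that axis 0 indexes devices.
   batch_data = jnp.arange(num_devices * 32, dtype=jnp.float32)
   batch_labels = jnp.ones_like(batch_data)
   sharded_data = batch_data.reshape(num_devices, -1)
   sharded_labels = batch_labels.reshape(num_devices, -1)

   # Returns one loss value per device.
   losses = mse_step(sharded_data, sharded_labels)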
Custom Data Augmentation
------------------------

Learn how to create custom data augmentation pipelines.

1. Define augmentation functions:

   .. code-block:: python

      import jax.numpy as jnp
      import jax.random as random

      def custom_augment(batch, key):
          # Assumes a batch of images with shape (batch, height, width, channels)
          # Split the key for the individual augmentations
          key1, key2, key3 = random.split(key, 3)

          # Add Gaussian noise
          noise = random.normal(key1, batch.shape) * 0.1
          augmented = batch + noise

          # Random quarter-turn rotation of the image axes
          k = random.randint(key2, (), 0, 4)
          augmented = jnp.rot90(augmented, k=int(k), axes=(1, 2))

          # Random flip along the height axis
          if random.uniform(key3) > 0.5:
              augmented = jnp.flip(augmented, axis=1)

          return augmented

2. Apply the custom augmentation:

   .. code-block:: python

      from jax_dataloader import DataLoader, DataLoaderConfig

      # Configure with the custom augmentation
      config = DataLoaderConfig(
          batch_size=32,
          transform=custom_augment,
          transform_key=random.PRNGKey(0)
      )

      dataloader = DataLoader(
          data=data,
          config=config
      )

3. Use it in training:

   .. code-block:: python

      for batch_data, batch_labels in dataloader:
          # batch_data is already augmented
          train_step(batch_data, batch_labels)

Advanced Caching Strategies
---------------------------

Learn how to optimize data loading with advanced caching. A standalone illustration of how LRU eviction behaves follows the steps below.

1. Set up caching:

   .. code-block:: python

      from jax_dataloader import DataLoader, DataLoaderConfig
      from jax_dataloader.memory import Cache

      # Create the cache
      cache = Cache(
          max_size=1000,          # Maximum number of batches to cache
          eviction_policy="lru"   # Least Recently Used
      )

      # Configure the dataloader with the cache
      config = DataLoaderConfig(
          batch_size=32,
          cache=cache,
          cache_hits=True  # Track cache hits
      )

      dataloader = DataLoader(
          data=data,
          config=config
      )

2. Monitor cache performance:

   .. code-block:: python

      for batch_data, batch_labels in dataloader:
          # Process the batch
          process_batch(batch_data, batch_labels)

          # Print cache statistics
          print(f"Cache hits: {dataloader.cache.hits}")
          print(f"Cache misses: {dataloader.cache.misses}")
          print(f"Hit rate: {dataloader.cache.hit_rate}")
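The ``eviction_policy="lru"`` setting drops the least recently used batch when the cache is full. To see how that policy and the hit/miss counters behave, here is a minimal standalone toy; ``TinyLRUCache`` is an illustrative class, not the library's actual ``Cache`` implementation.

.. code-block:: python

   from collections import OrderedDict

   class TinyLRUCache:
       """Minimal LRU cache: evicts the least recently used entry when full."""

       def __init__(self, max_size):
           self.max_size = max_size
           self._store = OrderedDict()
           self.hits = 0
           self.misses = 0

       def get(self, key):
           if key in self._store:
               self.hits += 1
               self._store.move_to_end(key)  # Mark as most recently used
               return self._store[key]
           self.misses += 1
           return None

       def put(self, key, value):
           self._store[key] = value
           self._store.move_to_end(key)
           if len(self._store) > self.max_size:
               self._store.popitem(last=False)  # Drop the least recently used entry

       @property
       def hit_rate(self):
           total = self.hits + self.misses
           return self.hits / total if total else 0.0

   # A cache of size 2: batch 1 is evicted once batch 2 arrives,
   # because batch 0 was touched more recently.
   cache = TinyLRUCache(max_size=2)
   cache.put(0, "batch-0")
   cache.put(1, "batch-1")
   cache.get(0)             # hit
   cache.put(2, "batch-2")  # evicts batch 1
   print(cache.hit_rate)    # 1.0 so far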