Examples ======== Basic Examples ------------- Simple Data Loading ~~~~~~~~~~~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig # Create a configuration config = DataLoaderConfig( data_path="data/train.csv", loader_type="csv", batch_size=32, shuffle=True ) # Create a data loader loader = DataLoader(config) # Iterate over batches for batch_data, batch_labels in loader: # Process your batch here pass Loading from Files ----------------- CSV Data ~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig # Create a configuration for CSV data config = DataLoaderConfig( data_path="data/train.csv", loader_type="csv", batch_size=32, shuffle=True, target_column="label" ) # Create a data loader loader = DataLoader(config) # Get metadata about the dataset metadata = loader.get_metadata() print(f"Number of samples: {metadata['num_samples']}") print(f"Number of features: {metadata['num_features']}") print(f"Feature names: {metadata['feature_names']}") # Iterate over batches for batch_data, batch_labels in loader: # Process your batch here pass JSON Data ~~~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig import jax.numpy as jnp # Create configuration config = DataLoaderConfig( loader_type="json", data_path="data.json", data_key="features", label_key="labels", batch_size=32, shuffle=True ) # Create data loader dataloader = DataLoader(config) # Iterate over batches for batch_data, batch_labels in dataloader: print(f"Batch shape: {batch_data.shape}") print(f"Labels shape: {batch_labels.shape}") Image Data ~~~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig # Create a configuration for image data config = DataLoaderConfig( data_path="data/images/", loader_type="image", batch_size=16, shuffle=True, image_size=(224, 224) ) # Create a data loader loader = DataLoader(config) # Iterate over batches for batch_images, batch_labels in loader: # Process your batch here pass Advanced Examples -------------- Multi-GPU Training ~~~~~~~~~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig import jax # Create a configuration config = DataLoaderConfig( data_path="data/train.csv", loader_type="csv", batch_size=32 * jax.device_count(), # Scale batch size by number of devices shuffle=True ) # Create a data loader loader = DataLoader(config) # Your training function @jax.pmap def train_step(params, batch): # Your training logic here pass Data Augmentation ~~~~~~~~~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig, Transform # Create transformations transform = Transform() transform.add("random_flip", probability=0.5) transform.add("random_rotation", max_angle=30) transform.add("random_brightness", max_delta=0.2) # Create a configuration with transformations config = DataLoaderConfig( data_path="data/images/", loader_type="image", batch_size=16, shuffle=True, transform=transform ) # Create a data loader loader = DataLoader(config) Memory Management ~~~~~~~~~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig # Create a configuration with memory management config = DataLoaderConfig( data_path="data/train.csv", loader_type="csv", batch_size=32, shuffle=True, memory_limit="4GB", # Limit memory usage cache_size="1GB" # Set cache size ) # Create a data loader loader = DataLoader(config) # Monitor memory usage memory_stats = loader.get_memory_usage() print(f"Current memory usage: {memory_stats['current_usage']}") print(f"Peak memory usage: {memory_stats['peak_usage']}") Progress Tracking ~~~~~~~~~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig # Create a configuration with progress tracking config = DataLoaderConfig( data_path="data/train.csv", loader_type="csv", batch_size=32, shuffle=True, show_progress=True # Enable progress tracking ) # Create a data loader loader = DataLoader(config) # Get progress information progress = loader.get_progress() print(f"Current batch: {progress['current_batch']}") print(f"Total batches: {progress['total_batches']}") print(f"Progress: {progress['progress']:.2%}") print(f"ETA: {progress['eta']:.2f} seconds") Error Handling ~~~~~~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig from jax_dataloader.exceptions import DataLoaderError, ConfigurationError try: # Create a configuration config = DataLoaderConfig( data_path="nonexistent.csv", loader_type="csv", batch_size=32 ) # Create a data loader loader = DataLoader(config) except ConfigurationError as e: print(f"Configuration error: {e}") except DataLoaderError as e: print(f"Data loader error: {e}") except Exception as e: print(f"Unexpected error: {e}") For more examples and use cases, check out the `GitHub repository `_. Benchmarking ----------- Performance Analysis ~~~~~~~~~~~~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig from jax_dataloader.benchmark import BenchmarkRunner # Create a configuration for benchmarking config = DataLoaderConfig( data_path="data/train.csv", loader_type="csv", batch_size=32, shuffle=True ) # Create a benchmark runner benchmark = BenchmarkRunner(config) # Run CPU performance analysis results = benchmark.run_cpu_analysis( num_iterations=100, warmup_iterations=10 ) # Print results print(f"Average batch loading time: {results['avg_batch_time']:.4f} seconds") print(f"Memory usage: {results['memory_usage']}") print(f"CPU utilization: {results['cpu_utilization']:.2f}%") Multi-Device Benchmarking ~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: python from jax_dataloader import DataLoader, DataLoaderConfig from jax_dataloader.benchmark import BenchmarkRunner import jax # Create configurations for different batch sizes configs = [ DataLoaderConfig( data_path="data/train.csv", loader_type="csv", batch_size=32 * (2 ** i), # Test different batch sizes shuffle=True ) for i in range(4) # Test 32, 64, 128, 256 batch sizes ] # Run benchmarks for each configuration for config in configs: benchmark = BenchmarkRunner(config) results = benchmark.run_multi_device_analysis( num_devices=jax.device_count(), num_iterations=50 ) print(f"\nResults for batch size {config.batch_size}:") print(f"Throughput: {results['throughput']:.2f} samples/second") print(f"GPU utilization: {results['gpu_utilization']:.2f}%") print(f"Memory efficiency: {results['memory_efficiency']:.2f}%")