Examples
Basic Examples
Simple Data Loading
from jax_dataloader import DataLoader, DataLoaderConfig

# Create a configuration
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    batch_size=32,
    shuffle=True
)

# Create a data loader
loader = DataLoader(config)

# Iterate over batches
for batch_data, batch_labels in loader:
    # Process your batch here
    pass
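As a quick illustration of what the "process your batch here" step might contain, the sketch below converts a batch to JAX arrays and standardizes the feature columns. It assumes the loader yields array-like (features, labels) pairs that jnp.asarray accepts; adapt it to whatever your batches actually contain.

import jax.numpy as jnp

for batch_data, batch_labels in loader:
    # Assumption: batches are array-like and convertible to JAX arrays
    features = jnp.asarray(batch_data, dtype=jnp.float32)
    labels = jnp.asarray(batch_labels)

    # Placeholder work: standardize each feature column
    features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
    print(features.shape, labels.shape)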
Loading from Files
CSV Data
from jax_dataloader import DataLoader, DataLoaderConfig

# Create a configuration for CSV data
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    batch_size=32,
    shuffle=True,
    target_column="label"
)

# Create a data loader
loader = DataLoader(config)

# Get metadata about the dataset
metadata = loader.get_metadata()
print(f"Number of samples: {metadata['num_samples']}")
print(f"Number of features: {metadata['num_features']}")
print(f"Feature names: {metadata['feature_names']}")

# Iterate over batches
for batch_data, batch_labels in loader:
    # Process your batch here
    pass
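The metadata is also enough to work out how many batches one pass over the file will produce. The calculation below is a sketch that assumes the final, smaller batch is still yielded; if the loader drops incomplete batches, use floor division instead.

import math

# Batches per epoch, assuming the last partial batch is yielded
num_batches = math.ceil(metadata['num_samples'] / config.batch_size)
print(f"Batches per epoch: {num_batches}")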
JSON Data
from jax_dataloader import DataLoader, DataLoaderConfig
import jax.numpy as jnp

# Create configuration
config = DataLoaderConfig(
    loader_type="json",
    data_path="data.json",
    data_key="features",
    label_key="labels",
    batch_size=32,
    shuffle=True
)

# Create data loader
dataloader = DataLoader(config)

# Iterate over batches
for batch_data, batch_labels in dataloader:
    print(f"Batch shape: {batch_data.shape}")
    print(f"Labels shape: {batch_labels.shape}")
Image Data
from jax_dataloader import DataLoader, DataLoaderConfig

# Create a configuration for image data
config = DataLoaderConfig(
    data_path="data/images/",
    loader_type="image",
    batch_size=16,
    shuffle=True,
    image_size=(224, 224)
)

# Create a data loader
loader = DataLoader(config)

# Iterate over batches
for batch_images, batch_labels in loader:
    # Process your batch here
    pass
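Image loaders often return integer pixel values, so a common first step is to cast to float and rescale. The sketch below assumes batches arrive as uint8 arrays in [0, 255] with shape (batch, height, width, channels); check your own batches before relying on that.

import jax.numpy as jnp

for batch_images, batch_labels in loader:
    # Assumption: uint8 pixels in [0, 255]
    images = jnp.asarray(batch_images).astype(jnp.float32) / 255.0
    print(images.shape, images.dtype)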
Advanced Examples
Multi-GPU Training
from jax_dataloader import DataLoader, DataLoaderConfig
import jax

# Create a configuration
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    batch_size=32 * jax.device_count(),  # Scale batch size by number of devices
    shuffle=True
)

# Create a data loader
loader = DataLoader(config)

# Your training function
@jax.pmap
def train_step(params, batch):
    # Your training logic here
    pass
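jax.pmap maps over a leading device axis, so each global batch has to be split into one shard per device before it reaches train_step, and the parameters need the same leading axis. The loop below is a minimal sketch: params and its contents are placeholders, and it assumes batch_data and batch_labels are arrays whose first dimension is the global batch size (a multiple of the device count).

import jax.numpy as jnp

n_devices = jax.device_count()

# Placeholder parameters, replicated once per device for pmap
params = {"w": jnp.zeros((10,))}
params = jax.device_put_replicated(params, jax.devices())

for batch_data, batch_labels in loader:
    features = jnp.asarray(batch_data)
    labels = jnp.asarray(batch_labels)

    # Reshape to (n_devices, per_device_batch, ...)
    features = features.reshape(n_devices, -1, *features.shape[1:])
    labels = labels.reshape(n_devices, -1, *labels.shape[1:])

    train_step(params, (features, labels))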
Data Augmentation
from jax_dataloader import DataLoader, DataLoaderConfig, Transform

# Create transformations
transform = Transform()
transform.add("random_flip", probability=0.5)
transform.add("random_rotation", max_angle=30)
transform.add("random_brightness", max_delta=0.2)

# Create a configuration with transformations
config = DataLoaderConfig(
    data_path="data/images/",
    loader_type="image",
    batch_size=16,
    shuffle=True,
    transform=transform
)

# Create a data loader
loader = DataLoader(config)
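Iteration works exactly as in the earlier image example; the transforms are presumably applied as batches are produced, so inspecting a few batches is an easy sanity check. This loop only assumes the loader yields (images, labels) pairs like the other examples.

# Inspect a couple of augmented batches
for i, (batch_images, batch_labels) in enumerate(loader):
    print(f"Batch {i}: images {batch_images.shape}, labels {batch_labels.shape}")
    if i == 1:
        break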
Memory Management
from jax_dataloader import DataLoader, DataLoaderConfig

# Create a configuration with memory management
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    batch_size=32,
    shuffle=True,
    memory_limit="4GB",  # Limit memory usage
    cache_size="1GB"  # Set cache size
)

# Create a data loader
loader = DataLoader(config)

# Monitor memory usage
memory_stats = loader.get_memory_usage()
print(f"Current memory usage: {memory_stats['current_usage']}")
print(f"Peak memory usage: {memory_stats['peak_usage']}")
Progress Tracking
from jax_dataloader import DataLoader, DataLoaderConfig

# Create a configuration with progress tracking
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    batch_size=32,
    shuffle=True,
    show_progress=True  # Enable progress tracking
)

# Create a data loader
loader = DataLoader(config)

# Get progress information
progress = loader.get_progress()
print(f"Current batch: {progress['current_batch']}")
print(f"Total batches: {progress['total_batches']}")
print(f"Progress: {progress['progress']:.2%}")
print(f"ETA: {progress['eta']:.2f} seconds")
Error Handling
from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.exceptions import DataLoaderError, ConfigurationError

try:
    # Create a configuration
    config = DataLoaderConfig(
        data_path="nonexistent.csv",
        loader_type="csv",
        batch_size=32
    )

    # Create a data loader
    loader = DataLoader(config)
except ConfigurationError as e:
    print(f"Configuration error: {e}")
except DataLoaderError as e:
    print(f"Data loader error: {e}")
except Exception as e:
    print(f"Unexpected error: {e}")
For more examples and use cases, check out the GitHub repository.
Benchmarking
Performance Analysis
from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.benchmark import BenchmarkRunner

# Create a configuration for benchmarking
config = DataLoaderConfig(
    data_path="data/train.csv",
    loader_type="csv",
    batch_size=32,
    shuffle=True
)

# Create a benchmark runner
benchmark = BenchmarkRunner(config)

# Run CPU performance analysis
results = benchmark.run_cpu_analysis(
    num_iterations=100,
    warmup_iterations=10
)

# Print results
print(f"Average batch loading time: {results['avg_batch_time']:.4f} seconds")
print(f"Memory usage: {results['memory_usage']}")
print(f"CPU utilization: {results['cpu_utilization']:.2f}%")
Multi-Device Benchmarking
from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.benchmark import BenchmarkRunner
import jax

# Create configurations for different batch sizes
configs = [
    DataLoaderConfig(
        data_path="data/train.csv",
        loader_type="csv",
        batch_size=32 * (2 ** i),  # Test different batch sizes
        shuffle=True
    )
    for i in range(4)  # Test 32, 64, 128, 256 batch sizes
]

# Run benchmarks for each configuration
for config in configs:
    benchmark = BenchmarkRunner(config)
    results = benchmark.run_multi_device_analysis(
        num_devices=jax.device_count(),
        num_iterations=50
    )
    print(f"\nResults for batch size {config.batch_size}:")
    print(f"Throughput: {results['throughput']:.2f} samples/second")
    print(f"GPU utilization: {results['gpu_utilization']:.2f}%")
    print(f"Memory efficiency: {results['memory_efficiency']:.2f}%")