Welcome to JAX DataLoader’s documentation!
JAX DataLoader is a high-performance data loading and preprocessing library for JAX. It is designed to be simple, fast, and memory-efficient, which makes it well suited to deep learning and data science workflows.
Features
High Performance: Optimized data loading with minimal overhead
Memory Efficient: Smart memory management and data streaming
Flexible: Supports several data formats, including CSV, JSON, and images
Easy to Use: Simple API with a familiar interface
Type Safe: Full type hints for use with static type checkers
Extensible: Easy to add custom data loaders
Benchmarking: Comprehensive performance analysis tools for CPU and multi-device setups
Performance Optimization: Advanced tools for analyzing and optimizing data loading performance
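The streaming idea in the list above can be illustrated with a minimal sketch that uses only the standard library. It reads CSV data lazily and groups rows into fixed-size batches, so only one batch is held in memory at a time. The helper `stream_csv_batches` is hypothetical and is not part of the jax_dataloader API. It assumes a header row followed by purely numeric columns, and it does not show how the library actually implements streaming or custom loaders.

```python
import csv
import io
from typing import Iterable, Iterator, List


def stream_csv_batches(lines: Iterable[str], batch_size: int) -> Iterator[List[List[float]]]:
    """Hypothetical helper: parse CSV rows lazily and yield fixed-size batches.

    Only the rows of the current batch are kept in memory. This illustrates
    streaming in general; it is not the jax_dataloader implementation.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    reader = csv.reader(lines)   # consumes the input one line at a time
    next(reader, None)           # assumption: the first row is a header
    batch: List[List[float]] = []
    for row in reader:
        batch.append([float(value) for value in row])  # assumption: numeric columns
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # emit the final, possibly smaller, batch
        yield batch


# io.StringIO stands in for an open file handle here.
csv_text = "x,y\n1,2\n3,4\n5,6\n"
for batch in stream_csv_batches(io.StringIO(csv_text), batch_size=2):
    print(batch)  # [[1.0, 2.0], [3.0, 4.0]], then [[5.0, 6.0]]
```

Each batch is a plain list of lists. Under the same assumptions, you could convert it with `jnp.asarray(batch)` before passing it to a JAX function.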
Quick Start
import jax.numpy as jnp
from jax_dataloader import DataLoader
# Create a small example array
data = jnp.array([1, 2, 3, 4, 5])
# Wrap it in a DataLoader that yields batches of two elements
loader = DataLoader(data, batch_size=2)
# Iterate over the batches
for batch in loader:
    print(batch)
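The loop above yields batches of at most two elements. The following pure-Python sketch shows the index arithmetic a batching loader typically performs, and it runs without JAX installed. The function `iterate_batches` and its `shuffle`, `drop_last`, and `seed` parameters are hypothetical. This page does not say whether DataLoader shuffles by default or keeps the final, smaller batch, so the sketch offers both behaviors as options.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iterate_batches(
    data: Sequence[T],
    batch_size: int,
    shuffle: bool = False,
    drop_last: bool = False,
    seed: int = 0,
) -> Iterator[List[T]]:
    """Hypothetical sketch of loader-style batching (not the library's code)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    indices = list(range(len(data)))
    if shuffle:
        # The same seed always produces the same permutation.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        if drop_last and len(batch_indices) < batch_size:
            break  # skip the incomplete trailing batch
        yield [data[i] for i in batch_indices]


data = [1, 2, 3, 4, 5]
print(list(iterate_batches(data, batch_size=2)))                  # [[1, 2], [3, 4], [5]]
print(list(iterate_batches(data, batch_size=2, drop_last=True)))  # [[1, 2], [3, 4]]
```

With five elements and a batch size of two, the last batch has one element unless it is dropped. Check the DataLoader API reference for the behavior it actually uses.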
Installation
pip install jax-dataloaders
For a development installation (editable mode, with the dev extras):
git clone https://github.com/carrycooldude/JAX-Dataloader.git
cd JAX-Dataloader
pip install -e ".[dev]"
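After installing, you can confirm which version is active by querying the package metadata. This sketch assumes the distribution name `jax-dataloaders` (from the pip command above) and the import name `jax_dataloader` (from the Quick Start). The helper `installed_version` is hypothetical and uses only the standard library (Python 3.8+).

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(import_name: str, dist_name: str) -> Optional[str]:
    """Hypothetical helper: report the installed version of a package.

    Returns None if the module cannot be imported, or "unknown" if it is
    importable but no distribution metadata is found under dist_name.
    """
    if importlib.util.find_spec(import_name) is None:
        return None
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return "unknown"


# Prints a version string, or None if the package is not installed.
print(installed_version("jax_dataloader", "jax-dataloaders"))
```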
Getting Help
If you encounter any issues or have questions:
Open an issue on GitHub
Check the examples for common use cases
Join the conversation in GitHub Discussions
Contributing
We welcome contributions! Please see our Contributing Guide for details.