Welcome to JAX DataLoader’s documentation!

JAX DataLoader is a high-performance data loading library for JAX that provides efficient data loading and preprocessing. It is designed to be simple, fast, and memory-efficient, which makes it well suited to deep learning and data science workflows.

Features

  • High Performance: Optimized data loading with minimal overhead

  • Memory Efficient: Smart memory management and data streaming

  • Flexible: Support for various data formats (CSV, JSON, Images)

  • Easy to Use: Simple API with familiar interface

  • Type Safe: Full type hints for use with static type checkers

  • Extensible: Easy to add custom data loaders

  • Benchmarking: Comprehensive performance analysis tools for CPU and multi-device setups

  • Performance Optimization: Advanced tools for analyzing and optimizing data loading performance
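The "Extensible" and "Memory Efficient" points above can be illustrated with a minimal sketch of a custom loader written in plain JAX. The class ShuffledArrayLoader and its parameters (seed, drop_last) are hypothetical and are not part of the jax_dataloader API. The sketch only assumes that a custom loader is an object implementing Python's iteration protocol (__iter__ and __len__). It shuffles with jax.random.permutation, so each epoch materializes a permuted index array rather than a shuffled copy of the whole dataset.

```python
from typing import Iterator

import jax
import jax.numpy as jnp


class ShuffledArrayLoader:
    """Hypothetical custom loader (not part of jax_dataloader).

    Yields shuffled mini-batches from an in-memory array, drawing a fresh
    permutation each time iteration starts.
    """

    def __init__(self, data: jax.Array, batch_size: int, *, seed: int = 0,
                 drop_last: bool = False) -> None:
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        self.data = data
        self.batch_size = batch_size
        self.drop_last = drop_last
        self._key = jax.random.PRNGKey(seed)

    def __len__(self) -> int:
        n = self.data.shape[0]
        if self.drop_last:
            return n // self.batch_size  # full batches only
        return -(-n // self.batch_size)  # ceiling division keeps the partial batch

    def __iter__(self) -> Iterator[jax.Array]:
        # Split the key so every epoch sees a different permutation.
        self._key, subkey = jax.random.split(self._key)
        perm = jax.random.permutation(subkey, self.data.shape[0])
        for i in range(len(self)):
            idx = perm[i * self.batch_size:(i + 1) * self.batch_size]
            # Gather only this batch's rows from the original array.
            yield self.data[idx]


loader = ShuffledArrayLoader(jnp.arange(10), batch_size=4)
for batch in loader:
    print(batch.shape)
```

With 10 elements and batch_size=4 this prints (4,), (4,) and (2,); passing drop_last=True discards the final partial batch instead.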

Quick Start

from jax_dataloader import DataLoader
import jax.numpy as jnp

# Create a small in-memory dataset
data = jnp.array([1, 2, 3, 4, 5])

# Wrap it in a DataLoader that yields batches of 2 elements
loader = DataLoader(data, batch_size=2)

# Iterate over batches
for batch in loader:
    print(batch)
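To get a rough sense of iteration throughput for a loop like the one above (the kind of measurement the Benchmarking feature targets), here is a minimal timing sketch in plain JAX. measure_throughput is a hypothetical helper, not the library's benchmarking API. It calls jax.block_until_ready on each batch because JAX dispatches work asynchronously, and timing without blocking would measure dispatch rather than completion.

```python
import time
from typing import Iterable

import jax
import jax.numpy as jnp


def measure_throughput(batches: Iterable, warmup: int = 1) -> float:
    """Hypothetical helper: return batches per second over `batches`.

    The first `warmup` batches are consumed before the clock starts so that
    one-off costs (for example compilation) do not skew the result.
    """
    it = iter(batches)
    # zip stops once range(warmup) is exhausted, without pulling an extra batch.
    for _, batch in zip(range(warmup), it):
        jax.block_until_ready(batch)

    count = 0
    start = time.perf_counter()
    for batch in it:
        # Wait for the batch to be fully computed before counting it.
        jax.block_until_ready(batch)
        count += 1
    elapsed = time.perf_counter() - start

    if count == 0:
        return 0.0  # nothing was timed
    return count / elapsed


data = jnp.arange(1_000_000, dtype=jnp.float32)
batches = (data[i:i + 1024] * 2.0 for i in range(0, data.shape[0], 1024))
print(f"{measure_throughput(batches):.1f} batches/s")
```

Any iterable of arrays can be passed, including a loader object like the one in the Quick Start. Results depend on hardware and batch size, so compare numbers from the same machine.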

Installation

pip install jax-dataloaders

For a development (editable) installation:

git clone https://github.com/carrycooldude/JAX-Dataloader.git
cd JAX-Dataloader
pip install -e ".[dev]"

Getting Help

If you encounter any issues or have questions, please open an issue on the GitHub repository: https://github.com/carrycooldude/JAX-Dataloader

Contributing

We welcome contributions! Please see our Contributing Guide for details.