Installation Guide
Requirements
JAX DataLoader requires Python 3.7 or later and the following dependencies:
JAX >= 0.3.0
jaxlib >= 0.3.0
NumPy >= 1.19.0
Pandas >= 1.2.0
Pillow >= 8.0.0
psutil >= 5.8.0
tqdm >= 4.50.0
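To check an existing environment against the minimums listed above, you can compare installed versions from package metadata. This is a minimal sketch that assumes Python 3.8+ (for `importlib.metadata`). It uses a simplified version parser, not full PEP 440, so pre-releases such as `2.0.0rc1` are treated as `2.0.0`. The helper names are hypothetical, not part of JAX DataLoader.

```python
"""Check installed dependency versions against the documented minimums (sketch)."""
import re
from importlib import metadata
from typing import Dict, Tuple

# Distribution names mapped to the minimum versions from the requirements list.
MINIMUM_VERSIONS: Dict[str, str] = {
    "jax": "0.3.0",
    "jaxlib": "0.3.0",
    "numpy": "1.19.0",
    "pandas": "1.2.0",
    "Pillow": "8.0.0",
    "psutil": "5.8.0",
    "tqdm": "4.50.0",
}


def parse_version(text: str) -> Tuple[int, ...]:
    """Keep the leading numeric release components: '2.0.0rc1' -> (2, 0, 0)."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # non-numeric component such as 'post1' or 'dev0'
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # stop at a suffix such as 'rc1' or '+cuda12'
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '1.19' compares equal to '1.19.0'.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_dependencies(minimums: Dict[str, str] = MINIMUM_VERSIONS) -> Dict[str, str]:
    """Map each failing dependency to a short description of the problem."""
    problems: Dict[str, str] = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if not meets_minimum(installed, minimum):
            problems[name] = f"found {installed}, need >= {minimum}"
    return problems


if __name__ == "__main__":
    issues = check_dependencies()
    for name, problem in issues.items():
        print(f"{name}: {problem}")
    print("All dependencies satisfied." if not issues else f"{len(issues)} problem(s) found.")
```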
Installation Methods
Using pip
The easiest way to install JAX DataLoader is using pip:
pip install jax-dataloaders
Development Installation
For development, or to get the latest unreleased features, install from source in editable mode:
git clone https://github.com/carrycooldude/JAX-Dataloader.git
cd JAX-Dataloader
pip install -e .
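After an editable install, you can confirm that Python imports the package from your clone rather than from an older copy in site-packages. This is a minimal standard-library sketch. It assumes the import name is `jax_dataloader` and that you run it from the directory where you ran `git clone`, so the clone is at `./JAX-Dataloader`. The helper names and `REPO_DIR` are hypothetical.

```python
"""Confirm that an editable install resolves to the cloned checkout (sketch)."""
import importlib.util
from pathlib import Path
from typing import Optional

REPO_DIR = "JAX-Dataloader"  # hypothetical: path to your clone


def module_origin(module_name: str) -> Optional[Path]:
    """Return the file a module would be imported from, or None if it cannot be found."""
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        return None
    return Path(spec.origin).resolve()


def imported_from(module_name: str, directory: str) -> bool:
    """True if `module_name` would be loaded from somewhere inside `directory`."""
    origin = module_origin(module_name)
    if origin is None:
        return False
    root = Path(directory).resolve()
    return root == origin or root in origin.parents


if __name__ == "__main__":
    print("Imported from:", module_origin("jax_dataloader"))
    print("Points at the clone:", imported_from("jax_dataloader", REPO_DIR))
```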
Using conda
You can also install JAX DataLoader using conda:
conda install -c conda-forge jax-dataloaders
Verifying Installation
To verify that JAX DataLoader is installed correctly:
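import jax_dataloader
from jax_dataloader import DataLoader  # raises ImportError if the install is broken
print(jax_dataloader.__version__)  # assumes the package defines a module-level __version__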
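If the package does not expose a `__version__` attribute, you can read the installed version from the distribution metadata instead. This sketch assumes Python 3.8+ (for `importlib.metadata`) and the distribution name `jax-dataloaders` used with pip above. `installed_version` is a hypothetical helper.

```python
"""Read an installed distribution's version from its metadata (sketch)."""
from importlib import metadata
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version string, or None if the distribution is absent."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version("jax-dataloaders")
    print(version if version is not None else "jax-dataloaders is not installed")
```

Reading the metadata does not import the package, so it also works when the import itself fails.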
Troubleshooting
Common Issues
JAX Installation
- If you encounter issues installing JAX, refer to the official JAX installation guide.
- For CUDA support, make sure the installed CUDA version matches the one your jaxlib build was compiled for.
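To see whether JAX actually picked up your GPU, you can inspect its default backend and device list. This is a minimal diagnostic sketch. `describe_jax_backend` is a hypothetical helper built on `jax.devices()` and `jax.default_backend()`. A `"cpu"` backend on a GPU machine usually means a CPU-only jaxlib is installed or CUDA was not found.

```python
"""Report which backend and devices JAX is using (sketch)."""
from typing import Dict, Optional


def describe_jax_backend() -> Optional[Dict[str, object]]:
    """Return backend details, or None when JAX cannot be imported."""
    try:
        import jax
    except ImportError:
        return None  # JAX itself is not installed
    devices = jax.devices()
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu", or "tpu"
        "device_count": len(devices),
        "devices": [str(d) for d in devices],
    }


if __name__ == "__main__":
    info = describe_jax_backend()
    print(info if info is not None else "JAX is not installed")
```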
Memory Issues
- If you encounter out-of-memory errors, try reducing the batch size or enabling memory management.
- Use the memory_limit parameter in DataLoaderConfig to control memory usage.
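When choosing a smaller batch size, a rough estimate of how many examples fit in a memory budget can help. This back-of-the-envelope sketch counts only the raw array bytes of each batch held in memory. It ignores framework overhead and intermediate buffers, so treat the result as an upper bound. `max_batch_size` and its `batches_in_memory` parameter are hypothetical and are not part of JAX DataLoader or `DataLoaderConfig`.

```python
"""Estimate the largest batch size that fits a memory budget (sketch)."""
from typing import Sequence

import numpy as np


def bytes_per_example(shape: Sequence[int], dtype) -> int:
    """Raw size in bytes of one example with the given shape and dtype."""
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize


def max_batch_size(shape: Sequence[int], dtype, budget_bytes: int,
                   batches_in_memory: int = 1) -> int:
    """Largest batch size such that `batches_in_memory` batches fit in `budget_bytes`."""
    return budget_bytes // (bytes_per_example(shape, dtype) * batches_in_memory)


if __name__ == "__main__":
    import psutil  # listed as a dependency above

    available = psutil.virtual_memory().available
    # Leave headroom: budget only a quarter of the currently available RAM.
    print(max_batch_size((224, 224, 3), np.float32, available // 4, batches_in_memory=2))
```

For example, 32x32x3 `float32` images take 12,288 bytes each, so a 1,000,000-byte budget holds at most 81 of them in a single batch.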
Multi-GPU Support
- Ensure JAX can see all of your GPUs (for example, check that jax.devices() lists each one).
- Check that your batch size is evenly divisible by the number of devices.
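Data-parallel code that maps over a leading device axis, for example with `jax.pmap`, needs each global batch split into equal per-device shards. That is why the batch size must be a multiple of the device count. This is a minimal sketch of that check and reshape. `shard_batch` is a hypothetical helper, not a JAX DataLoader API.

```python
"""Split a global batch into equal per-device shards (sketch)."""
import numpy as np


def shard_batch(batch: np.ndarray, num_devices: int) -> np.ndarray:
    """Reshape (batch, ...) into (num_devices, batch // num_devices, ...)."""
    batch_size = batch.shape[0]
    if batch_size % num_devices != 0:
        raise ValueError(
            f"batch size {batch_size} is not divisible by {num_devices} devices; "
            f"use a multiple of {num_devices}"
        )
    return batch.reshape((num_devices, batch_size // num_devices) + batch.shape[1:])


if __name__ == "__main__":
    import jax

    n = jax.local_device_count()
    batch = np.zeros((8 * n, 28, 28), dtype=np.float32)
    print(shard_batch(batch, n).shape)  # leading axis equals the device count
```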
Getting Help
If you encounter any issues:
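- Search the GitHub issues to see if your problem has already been reported.
- If not, open a new issue that includes your Python, JAX, and jaxlib versions, the full error message, and a minimal example that reproduces the problem.
- Join our Discord community for real-time support.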
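When you report a problem, a short script can collect environment details to paste into the issue. This sketch assumes Python 3.8+ (for `importlib.metadata`). The package list is illustrative and mirrors the requirements at the top of this guide. `environment_report` is a hypothetical helper.

```python
"""Collect environment details for a bug report (sketch)."""
import platform
import sys
from importlib import metadata
from typing import Iterable

PACKAGES = ("jax-dataloaders", "jax", "jaxlib", "numpy", "pandas", "Pillow", "psutil", "tqdm")


def environment_report(packages: Iterable[str] = PACKAGES) -> str:
    """Return one 'name: version' line per package, preceded by Python and platform info."""
    lines = [
        f"Python: {sys.version.split()[0]}",
        f"Platform: {platform.platform()}",
    ]
    for name in packages:
        try:
            version = metadata.version(name)
        except metadata.PackageNotFoundError:
            version = "not installed"
        lines.append(f"{name}: {version}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(environment_report())
```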