"""Data loading module for JAX applications."""
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple
import jax.numpy as jnp
from jax import random
import pandas as pd
import numpy as np
import json
class BaseLoader(ABC):
"""Base class for all data loaders."""
def __init__(
self,
batch_size: int = 32,
shuffle: bool = True,
seed: Optional[int] = None,
num_workers: int = 0,
prefetch_factor: int = 2,
):
"""Initialize the data loader.
Args:
batch_size: Number of samples per batch
shuffle: Whether to shuffle the data
seed: Random seed for reproducibility
num_workers: Number of worker processes
prefetch_factor: Number of batches to prefetch
"""
self.batch_size = batch_size
self.shuffle = shuffle
self.seed = seed
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
        self._rng = random.PRNGKey(seed if seed is not None else 0)
self._metadata: Dict[str, Any] = {}
@abstractmethod
def __len__(self) -> int:
"""Return the number of batches."""
pass
@abstractmethod
def __iter__(self):
"""Return an iterator over the data."""
pass
    def __next__(self):
        """Get the next batch of data.
        Concrete loaders implement iteration via a generator ``__iter__``,
        so calling ``__next__`` on the loader itself is not supported.
        """
        raise NotImplementedError("Iterate over the loader with iter(loader).")
def load(self, data_path: str) -> Any:
"""Load data from the specified path.
Args:
data_path: Path to the data file
Returns:
Loaded data
"""
raise NotImplementedError("Subclasses must implement load()")
def preprocess(self, data: Any) -> Any:
"""Preprocess the loaded data.
Args:
data: Data to preprocess
Returns:
Preprocessed data
"""
return data
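# Example: a minimal in-memory loader built on BaseLoader. This is an
# illustrative sketch, not part of the module's API; the ArrayLoader name
# and its behavior are assumptions for demonstration only.
class ArrayLoader(BaseLoader):
    """Toy loader that batches arrays already held in memory."""
    def __init__(self, data: jnp.ndarray, labels: jnp.ndarray, **kwargs):
        super().__init__(**kwargs)
        self._data = data
        self._labels = labels
    def __len__(self) -> int:
        return (len(self._data) + self.batch_size - 1) // self.batch_size
    def __iter__(self):
        indices = np.arange(len(self._data))
        if self.shuffle:
            # Draw the permutation from the loader's PRNG state.
            self._rng, subkey = random.split(self._rng)
            indices = np.asarray(random.permutation(subkey, len(self._data)))
        for i in range(0, len(indices), self.batch_size):
            batch = indices[i:i + self.batch_size]
            yield self._data[batch], self._labels[batch]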
class CSVLoader(BaseLoader):
"""Loader for CSV data."""
def __init__(
self,
data_path: str,
target_column: str,
feature_columns: Optional[List[str]] = None,
chunk_size: Optional[int] = None,
dtype: Optional[jnp.dtype] = None,
batch_size: int = 32,
shuffle: bool = True,
seed: Optional[int] = None,
num_workers: int = 0,
prefetch_factor: int = 2,
):
"""Initialize the CSV loader.
Args:
data_path: Path to the CSV file
target_column: Name of the target column
feature_columns: List of feature column names
chunk_size: Size of chunks to load at once
dtype: Optional data type for arrays
batch_size: Number of samples per batch
shuffle: Whether to shuffle the data
seed: Random seed for reproducibility
num_workers: Number of worker processes
prefetch_factor: Number of batches to prefetch
"""
super().__init__(batch_size, shuffle, seed, num_workers, prefetch_factor)
self.data_path = data_path
self.target_column = target_column
self.feature_columns = feature_columns
self.chunk_size = chunk_size
self.dtype = dtype
self._data = None
self._labels = None
def load(self, data_path: str) -> Any:
"""Load data from the CSV file.
Args:
data_path: Path to the CSV file
Returns:
Loaded data
"""
if self.chunk_size is not None:
# Load data in chunks
chunks = []
labels = []
for chunk in pd.read_csv(data_path, chunksize=self.chunk_size):
if self.feature_columns is not None:
chunk_data = chunk[self.feature_columns].values
else:
chunk_data = chunk.drop(columns=[self.target_column]).values
chunks.append(chunk_data)
labels.append(chunk[self.target_column].values)
self._data = jnp.concatenate(chunks, axis=0, dtype=self.dtype)
self._labels = jnp.concatenate(labels, axis=0, dtype=self.dtype)
else:
# Load entire file at once
df = pd.read_csv(data_path)
if self.feature_columns is not None:
self._data = jnp.array(df[self.feature_columns].values, dtype=self.dtype)
else:
self._data = jnp.array(df.drop(columns=[self.target_column]).values, dtype=self.dtype)
self._labels = jnp.array(df[self.target_column].values, dtype=self.dtype)
return self._data
def preprocess(self, data: Any) -> Any:
"""Preprocess the loaded data.
Args:
data: Data to preprocess
Returns:
Preprocessed data
"""
return data
def __len__(self) -> int:
"""Return the number of batches."""
if self._data is None:
self.load(self.data_path)
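        # Ceiling division: a final partial batch still counts as one batch.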
return (len(self._data) + self.batch_size - 1) // self.batch_size
def __iter__(self):
"""Return an iterator over the data."""
if self._data is None:
self.load(self.data_path)
        indices = np.arange(len(self._data))
        if self.shuffle:
            # Draw the permutation from the loader's PRNG state so shuffling
            # respects the configured seed.
            self._rng, subkey = random.split(self._rng)
            indices = np.asarray(random.permutation(subkey, len(self._data)))
for i in range(0, len(indices), self.batch_size):
batch_indices = indices[i:i + self.batch_size]
yield self._data[batch_indices], self._labels[batch_indices]
def get_chunk(self, start: int, size: int) -> Tuple[np.ndarray, np.ndarray]:
"""Get a chunk of data from the CSV file.
Args:
start (int): The starting index of the chunk.
size (int): The size of the chunk.
Returns:
tuple: A tuple containing the chunk data and labels as numpy arrays.
"""
        df = pd.read_csv(self.data_path, skiprows=range(1, start + 1), nrows=size)
        # Respect feature_columns, mirroring load().
        if self.feature_columns is not None:
            data = df[self.feature_columns].values
        else:
            data = df.drop(columns=[self.target_column]).values
        labels = df[self.target_column].values
        return data, labels
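# Example usage (illustrative sketch): the file path and column names below
# are hypothetical placeholders, not files shipped with this module.
def _example_csv_usage():
    loader = CSVLoader(
        data_path="train.csv",  # hypothetical path
        target_column="label",
        batch_size=64,
        seed=0,
    )
    for features, labels in loader:
        print(features.shape, labels.shape)  # e.g. (64, n_features) and (64,)
        break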
class JSONLoader(BaseLoader):
"""Loader for JSON data."""
def __init__(
self,
data_path: str,
data_key: str = "data",
label_key: str = "labels",
preprocess_fn: Optional[Callable] = None,
dtype: Optional[jnp.dtype] = None,
batch_size: int = 32,
shuffle: bool = True,
seed: Optional[int] = None,
num_workers: int = 0,
prefetch_factor: int = 2,
):
"""Initialize the JSON loader.
Args:
data_path: Path to the JSON file
data_key: Key for data in JSON
label_key: Key for labels in JSON
preprocess_fn: Optional preprocessing function
dtype: Optional data type for arrays
batch_size: Number of samples per batch
shuffle: Whether to shuffle the data
seed: Random seed for reproducibility
num_workers: Number of worker processes
prefetch_factor: Number of batches to prefetch
"""
super().__init__(batch_size, shuffle, seed, num_workers, prefetch_factor)
self.data_path = data_path
self.data_key = data_key
self.label_key = label_key
self.preprocess_fn = preprocess_fn
self.dtype = dtype
self._data = None
self._labels = None
def load(self, data_path: str) -> Any:
"""Load data from the JSON file.
Args:
data_path: Path to the JSON file
Returns:
Loaded data
"""
        with open(data_path, 'r') as f:
            data = json.load(f)
        # Apply the optional preprocessing hook to the features.
        self._data = self.preprocess(jnp.array(data[self.data_key], dtype=self.dtype))
        self._labels = jnp.array(data[self.label_key], dtype=self.dtype)
        return self._data
def preprocess(self, data: Any) -> Any:
"""Preprocess the loaded data.
Args:
data: Data to preprocess
Returns:
Preprocessed data
"""
if self.preprocess_fn is not None:
return self.preprocess_fn(data)
return data
def __len__(self) -> int:
"""Return the number of batches."""
if self._data is None:
self.load(self.data_path)
return (len(self._data) + self.batch_size - 1) // self.batch_size
def __iter__(self):
"""Return an iterator over the data."""
if self._data is None:
self.load(self.data_path)
        indices = np.arange(len(self._data))
        if self.shuffle:
            # Draw the permutation from the loader's PRNG state so shuffling
            # respects the configured seed.
            self._rng, subkey = random.split(self._rng)
            indices = np.asarray(random.permutation(subkey, len(self._data)))
for i in range(0, len(indices), self.batch_size):
batch_indices = indices[i:i + self.batch_size]
yield self._data[batch_indices], self._labels[batch_indices]
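# Example usage (illustrative sketch): assumes a JSON file shaped like
# {"data": [[...], ...], "labels": [...]}; the path and the squaring
# preprocess_fn are hypothetical.
def _example_json_usage():
    loader = JSONLoader(
        data_path="dataset.json",  # hypothetical path
        preprocess_fn=lambda x: x ** 2,
        seed=0,
    )
    for features, labels in loader:
        print(features.shape, labels.shape)
        break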
class ImageLoader(BaseLoader):
"""Loader for image data."""
def __init__(
self,
data_path: str,
image_size: Tuple[int, int] = (224, 224),
normalize: bool = True,
augment: bool = False,
augment_options: Optional[Dict[str, Any]] = None,
batch_size: int = 32,
shuffle: bool = True,
seed: Optional[int] = None,
num_workers: int = 0,
prefetch_factor: int = 2,
):
"""Initialize the image loader.
Args:
data_path: Path to the image directory
image_size: Target size for images
normalize: Whether to normalize pixel values
augment: Whether to apply data augmentation
augment_options: Options for data augmentation
batch_size: Number of samples per batch
shuffle: Whether to shuffle the data
seed: Random seed for reproducibility
num_workers: Number of worker processes
prefetch_factor: Number of batches to prefetch
"""
super().__init__(batch_size, shuffle, seed, num_workers, prefetch_factor)
self.data_path = data_path
self.image_size = image_size
self.normalize = normalize
self.augment = augment
self.augment_options = augment_options or {}
        self._image_files: List[str] = []
        self._labels: List[int] = []
        self._class_to_idx: Dict[str, int] = {}
    def load(self, data_path: str) -> Any:
        """Load data from the image directory.
        Args:
            data_path: Path to the image directory
        Returns:
            Loaded data
        """
        import os
        # Collect every image file under the directory tree.
        self._image_files = []
        labels = []
        for root, _, files in os.walk(data_path):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self._image_files.append(os.path.join(root, file))
                    # Use the containing directory name as the class label.
                    labels.append(os.path.basename(root))
        # Map class names to integer indices so labels can be batched as
        # jnp arrays (jnp.array cannot hold strings).
        self._class_to_idx = {name: i for i, name in enumerate(sorted(set(labels)))}
        self._labels = [self._class_to_idx[name] for name in labels]
        return self._image_files
    def preprocess(self, data: Any) -> Any:
        """Load and preprocess a single image.
        Args:
            data: Path to the image file
        Returns:
            Preprocessed image array
        """
        from PIL import Image
        # Load the image, force RGB, and resize to the target size.
        img = Image.open(data).convert('RGB')
        img = img.resize(self.image_size)
        img_array = np.array(img)
        # Scale pixel values to [0, 1] if requested.
        if self.normalize:
            img_array = img_array.astype(np.float32) / 255.0
        return jnp.array(img_array)
    def augment_image(self, data: Any) -> Any:
        """Apply data augmentation to the image.
        Named ``augment_image`` because ``self.augment`` holds the boolean
        flag set in ``__init__``.
        Args:
            data: Image data to augment
        Returns:
            Augmented image data
        """
        if not self.augment:
            return data
        # Split the loader's PRNG state so each call draws fresh randomness
        # instead of reusing a fixed key.
        self._rng, key = random.split(self._rng)
        # Apply random rotation; jnp.rot90 only supports quarter turns, so
        # round the sampled angle (in degrees) to the nearest multiple of 90.
        if "rotation" in self.augment_options:
            key, subkey = random.split(key)
            angle = random.uniform(
                subkey,
                minval=self.augment_options["rotation"][0],
                maxval=self.augment_options["rotation"][1],
            )
            data = jnp.rot90(data, k=int(jnp.round(angle / 90.0)) % 4)
# Apply random flip
if self.augment_options.get("flip", False):
key, subkey = random.split(key)
if random.uniform(subkey) > 0.5:
data = jnp.fliplr(data)
# Apply brightness adjustment
if "brightness" in self.augment_options:
key, subkey = random.split(key)
factor = random.uniform(
subkey,
minval=self.augment_options["brightness"][0],
maxval=self.augment_options["brightness"][1]
)
data = data * factor
return data
def __len__(self) -> int:
"""Return the number of batches."""
if not self._image_files:
self.load(self.data_path)
return (len(self._image_files) + self.batch_size - 1) // self.batch_size
def __iter__(self):
"""Return an iterator over the data."""
if not self._image_files:
self.load(self.data_path)
        indices = np.arange(len(self._image_files))
        if self.shuffle:
            # Draw the permutation from the loader's PRNG state so shuffling
            # respects the configured seed.
            self._rng, subkey = random.split(self._rng)
            indices = np.asarray(random.permutation(subkey, len(self._image_files)))
for i in range(0, len(indices), self.batch_size):
batch_indices = indices[i:i + self.batch_size]
batch_images = []
batch_labels = []
for idx in batch_indices:
img = self.preprocess(self._image_files[idx])
                if self.augment:
                    img = self.augment_image(img)
batch_images.append(img)
batch_labels.append(self._labels[idx])
yield jnp.array(batch_images), jnp.array(batch_labels)
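# Example usage (illustrative sketch): expects a hypothetical directory laid
# out as images/<class_name>/*.jpg; "rotation" is an angle range in degrees
# and "brightness" is a multiplicative factor range.
def _example_image_usage():
    loader = ImageLoader(
        data_path="images/",  # hypothetical directory
        image_size=(128, 128),
        augment=True,
        augment_options={"rotation": (0, 360), "flip": True, "brightness": (0.8, 1.2)},
        seed=0,
    )
    for images, labels in loader:
        print(images.shape, labels.shape)  # e.g. (32, 128, 128, 3) and (32,)
        break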
def get_device_count() -> int:
    """Get the number of available devices.
    Returns:
        Number of available devices
    """
    import jax
    return jax.device_count()
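# Sketch of one way the device count can be used: reshaping a batch so its
# leading axis matches the local device count (e.g. for jax.pmap). Assumes
# the batch size divides evenly by the number of devices.
def _example_sharded_batch(batch: jnp.ndarray) -> jnp.ndarray:
    n = get_device_count()
    # (batch, ...) -> (n_devices, batch // n_devices, ...)
    return batch.reshape((n, batch.shape[0] // n) + batch.shape[1:])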