Source code for jax_dataloader.progress

"""Progress tracking module for JAX applications."""

from typing import Optional, List, Callable
from tqdm import tqdm
import time

[docs] class ProgressTracker: """Tracks progress of data loading."""
[docs] def __init__( self, total: int, desc: Optional[str] = None, unit: str = "it", leave: bool = True, update_interval: float = 0.1, callbacks: Optional[List[Callable[[float], None]]] = None, show_eta: bool = True, ): """Initialize the progress tracker. Args: total: Total number of items desc: Description of the progress unit: Unit of progress leave: Whether to leave the progress bar update_interval: Time interval between updates in seconds callbacks: List of callback functions to call on updates show_eta: Whether to show estimated time remaining """ self.total = total self.current = 0 self.start_time = time.time() self.update_interval = update_interval self.callbacks = callbacks or [] self.show_eta = show_eta self.pbar = tqdm( total=total, desc=desc, unit=unit, leave=leave, )
[docs] def update(self, n: int = 1): """Update the progress. Args: n: Number of items to update """ self.current += n self.pbar.update(n) # Call callbacks if enough time has passed if time.time() - self.start_time >= self.update_interval: progress = self.get_progress() for callback in self.callbacks: callback(progress) self.start_time = time.time()
[docs] def reset(self): """Reset the progress tracker.""" self.current = 0 self.start_time = time.time() self.pbar.reset()
[docs] def get_progress(self) -> float: """Get the current progress as a fraction. Returns: Progress as a fraction between 0 and 1 """ return self.current / self.total if self.total > 0 else 0
[docs] def get_eta(self) -> float: """Get the estimated time remaining. Returns: Estimated time remaining in seconds """ if self.current == 0: return float('inf') elapsed = time.time() - self.start_time rate = self.current / elapsed if elapsed > 0 else 0 remaining = self.total - self.current return remaining / rate if rate > 0 else float('inf')
[docs] def close(self): """Close the progress bar.""" self.pbar.close()
[docs] def __enter__(self): """Context manager entry.""" return self
[docs] def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit.""" self.close()