Source code for jax_dataloader.memory

"""Memory management module for JAX applications."""

from typing import Any, Dict, Optional, Union
import psutil
import numpy as np
import time

class MemoryManager:
    """Track usage against a byte budget with simple bookkeeping.

    The manager does not reserve or release real memory. It only maintains
    counters: the number of bytes currently accounted for (``allocated``),
    the highest value that counter has reached (``_peak_usage``), and the
    time of the last monitoring snapshot (``_last_update``).
    """

    def __init__(self, max_memory: Optional[float] = None):
        """Initialize the memory manager.

        Args:
            max_memory: Maximum memory to allocate in bytes. When ``None``,
                the total reported by ``psutil.virtual_memory()`` is used.
                An explicit ``0`` is honoured as a zero-byte budget.
        """
        # Compare against None instead of relying on truthiness, so that an
        # explicit budget of 0 is not silently replaced by total system RAM.
        if max_memory is None:
            max_memory = psutil.virtual_memory().total
        self.max_memory = max_memory
        self.allocated = 0
        self._peak_usage = 0
        self._last_update = time.time()

    def allocate(self, size: int) -> bool:
        """Account for ``size`` additional bytes if the budget allows it.

        Args:
            size: Size to allocate in bytes; must not be negative.

        Returns:
            True if the allocation was recorded, False if it would exceed
            ``max_memory`` (in which case nothing changes).

        Raises:
            ValueError: If ``size`` is negative.
        """
        if size < 0:
            # A negative allocation would shrink the counter and could push
            # it below zero, corrupting the "available" figure.
            raise ValueError(f"size must be non-negative, got {size}")
        if self.allocated + size > self.max_memory:
            return False
        self.allocated += size
        self._peak_usage = max(self._peak_usage, self.allocated)
        return True

    def deallocate(self, size: int):
        """Release ``size`` bytes from the accounted total.

        The counter is clamped at zero, so releasing more than is currently
        allocated simply resets it to zero.

        Args:
            size: Size to deallocate in bytes; must not be negative.

        Raises:
            ValueError: If ``size`` is negative.
        """
        if size < 0:
            # A negative release would grow the counter past max_memory
            # without going through the budget check in allocate().
            raise ValueError(f"size must be non-negative, got {size}")
        self.allocated = max(0, self.allocated - size)

    def free(self, size: Optional[int] = None):
        """Free memory.

        Args:
            size: Optional size to free in bytes. When omitted (``None``),
                the whole accounted total is reset to zero; the peak usage
                figure is left untouched.

        Raises:
            ValueError: If ``size`` is negative.
        """
        if size is not None:
            self.deallocate(size)
        else:
            self.allocated = 0

    def get_usage(self) -> Dict[str, Any]:
        """Get current memory usage statistics.

        Returns:
            Dictionary with the keys ``"allocated"``, ``"peak_usage"``,
            ``"available"`` (budget minus allocated) and ``"total"``
            (the budget itself).
        """
        return {
            "allocated": self.allocated,
            "peak_usage": self._peak_usage,
            "available": self.max_memory - self.allocated,
            "total": self.max_memory,
        }

    def cleanup(self):
        """Reset the allocated and peak counters and restart the monitor clock."""
        self.allocated = 0
        self._peak_usage = 0
        self._last_update = time.time()

    def monitor(self, interval: float = 1.0) -> Dict[str, Any]:
        """Return a usage snapshot at most once per ``interval``.

        Args:
            interval: Minimum number of seconds that must have elapsed since
                the last snapshot (or since construction / ``cleanup``).

        Returns:
            The ``get_usage()`` dictionary extended with a ``"timestamp"``
            key when the interval has elapsed; an empty dictionary otherwise.
        """
        current_time = time.time()
        if current_time - self._last_update >= interval:
            stats = self.get_usage()
            stats["timestamp"] = current_time
            self._last_update = current_time
            return stats
        return {}
class Cache:
    """Size-bounded in-memory key/value cache with optional expiry.

    Entries are kept in a plain ``dict`` whose insertion order doubles as
    the eviction queue: the first key is always the next one to be evicted.
    Under the ``"lru"`` policy a successful ``get`` moves the key to the
    back of that queue; under any other policy the order is purely
    first-in-first-out.
    """

    def __init__(
        self,
        max_size: int,
        eviction_policy: str = "lru",
        track_stats: bool = True,
        max_age: Optional[float] = None,
    ):
        """Initialize the cache.

        Args:
            max_size: Maximum size of the cache in bytes.
            eviction_policy: Cache eviction policy ("lru" or "fifo"). Any
                value other than "lru" behaves as "fifo".
            track_stats: Whether to track cache statistics.
            max_age: Maximum age of cached items in seconds, measured from
                the time they were stored. ``None`` disables expiry.
        """
        self.max_size = max_size
        self.eviction_policy = eviction_policy
        self.track_stats = track_stats
        self.max_age = max_age
        self._cache: Dict[str, Any] = {}
        self._sizes: Dict[str, int] = {}
        self._timestamps: Dict[str, float] = {}
        self._hits = 0
        self._misses = 0
        self._evictions = 0

    def get(self, key: str) -> Optional[Any]:
        """Get a value from the cache.

        An entry older than ``max_age`` is evicted and reported as a miss.

        Args:
            key: Cache key.

        Returns:
            Cached value if found and not expired, None otherwise.
        """
        if key not in self._cache:
            self._record_miss()
            return None

        expired = (
            self.max_age is not None
            and time.time() - self._timestamps[key] > self.max_age
        )
        if expired:
            self.evict(key)
            self._record_miss()
            return None

        if self.track_stats:
            self._hits += 1
        if self.eviction_policy == "lru":
            # Re-insert so the key becomes the most recently used one and is
            # the last candidate for eviction. The storage timestamp is left
            # alone because it drives max_age expiry, not recency.
            self._cache[key] = self._cache.pop(key)
        return self._cache[key]

    def put(self, key: str, value: Any):
        """Put a value in the cache, evicting older entries to make room.

        Any previous entry stored under ``key`` is replaced. A value whose
        estimated size exceeds ``max_size`` can never fit and is not stored
        (the previous entry for ``key``, if any, is still dropped so that a
        stale value is not served).

        Args:
            key: Cache key.
            value: Value to cache.
        """
        size = self._estimate_size(value)

        # Drop the entry being replaced first: its size must not count
        # against the new value, and re-inserting moves the key to the back
        # of the eviction queue. A replacement is not an eviction.
        self._discard(key)

        if size > self.max_size:
            # Evicting everything would still not make room; without this
            # guard the loop below would never terminate.
            return

        while self._cache and self._total_size + size > self.max_size:
            self._evict()

        self._cache[key] = value
        self._sizes[key] = size
        self._timestamps[key] = time.time()

    def clear(self):
        """Remove every entry; also reset the counters when stats are tracked."""
        self._cache.clear()
        self._sizes.clear()
        self._timestamps.clear()
        if self.track_stats:
            self._hits = 0
            self._misses = 0
            self._evictions = 0

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with the keys ``"size"`` (total estimated bytes),
            ``"items"``, ``"hits"``, ``"misses"``, ``"evictions"`` and
            ``"hit_rate"`` (hits divided by lookups, 0 when there were none).
        """
        lookups = self._hits + self._misses
        return {
            "size": self._total_size,
            "items": len(self._cache),
            "hits": self._hits,
            "misses": self._misses,
            "evictions": self._evictions,
            "hit_rate": self._hits / lookups if lookups > 0 else 0.0,
        }

    def evict(self, key: str):
        """Evict a specific key from the cache.

        Does nothing when the key is absent. A successful removal increments
        the eviction counter when stats are tracked.

        Args:
            key: Key to evict.
        """
        if self._discard(key) and self.track_stats:
            self._evictions += 1

    def _estimate_size(self, value: Any) -> int:
        """Estimate the size of a value in bytes.

        Args:
            value: Value to estimate size of.

        Returns:
            ``value.nbytes`` when the object exposes that attribute,
            otherwise the length of its encoded ``str()`` representation.
        """
        if hasattr(value, "nbytes"):
            return value.nbytes
        return len(str(value).encode())

    @property
    def _total_size(self) -> int:
        """Total estimated size in bytes of all cached entries."""
        return sum(self._sizes.values())

    def _evict(self):
        """Evict one entry according to the eviction policy.

        The head of the insertion-ordered dict is the oldest entry under
        FIFO and the least recently used one under LRU (``get`` moves hits
        to the back), so both policies evict the first key.
        """
        if not self._cache:
            return
        self.evict(next(iter(self._cache)))

    def _record_miss(self):
        """Count a failed lookup when stats are tracked."""
        if self.track_stats:
            self._misses += 1

    def _discard(self, key: str) -> bool:
        """Remove ``key`` from all internal maps without touching the stats.

        Returns:
            True if the key was present and has been removed.
        """
        if key not in self._cache:
            return False
        del self._cache[key]
        del self._sizes[key]
        del self._timestamps[key]
        return True
def get_available_memory() -> float:
    """Report how much system memory is currently available.

    Returns:
        The ``available`` field of ``psutil.virtual_memory()``, in bytes.
    """
    snapshot = psutil.virtual_memory()
    return snapshot.available