
Utilities API

nano_moe.utils

count_params

def count_params(params) -> int

Count total number of trainable parameters in a Flax parameter tree.

```python
from nano_moe.utils import count_params

n = count_params(params)
print(f"Parameters: {n:,}")  # Parameters: 2,409,025
```
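One plausible way to count parameters is to sum the element counts of every array leaf in the tree. The sketch below assumes `params` is the Flax `"params"` collection (a nested dict of arrays, which is where Flax usually keeps trainable weights). `count_params_sketch` is a hypothetical name, not nano_moe's actual source.

```python
import jax
import jax.numpy as jnp


def count_params_sketch(params) -> int:
    """Sum the element counts of every array leaf in a (possibly nested) parameter tree."""
    # Assumption: every leaf is an array exposing `.size`, as in a Flax "params" collection.
    return sum(int(leaf.size) for leaf in jax.tree_util.tree_leaves(params))


# Toy Flax-style tree: one Dense layer mapping 3 -> 4 features (3*4 kernel + 4 bias).
params = {"Dense_0": {"kernel": jnp.zeros((3, 4)), "bias": jnp.zeros((4,))}}
print(f"Parameters: {count_params_sketch(params):,}")  # Parameters: 16
```

Because `tree_leaves` flattens any nesting depth, the same function works no matter how many layers or experts the model has.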

get_batch

def get_batch(data, batch_size, block_size, rng) -> Tuple[x, y]

Sample a random batch of sequences from the data.

Parameters:

| Parameter | Type | Description |
|---|---|---|
| data | jnp.ndarray | Full token sequence |
| batch_size | int | Number of sequences |
| block_size | int | Sequence length |
| rng | PRNGKey | Random key |

Returns: (x, y) where y is x shifted by 1 position (next-token targets).
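A minimal sketch of this sampling follows. It assumes `data` is a 1-D integer array and that start offsets are drawn uniformly with `jax.random.randint`. `get_batch_sketch` is a hypothetical stand-in, not the library's code. Each row of `y` is the matching row of `x` advanced by one token, so `y[b, t]` is the target for the prefix ending at `x[b, t]`.

```python
import jax
import jax.numpy as jnp


def get_batch_sketch(data, batch_size, block_size, rng):
    """Sample `batch_size` windows of length `block_size` plus their next-token targets."""
    # maxval is exclusive, so the largest start is len(data) - block_size - 1,
    # which leaves room for the +1 shift used to build y.
    starts = jax.random.randint(rng, (batch_size,), 0, len(data) - block_size)
    offsets = jnp.arange(block_size)
    idx = starts[:, None] + offsets[None, :]  # shape (batch_size, block_size)
    x = data[idx]
    y = data[idx + 1]  # next-token targets
    return x, y


data = jnp.arange(100, dtype=jnp.int32)
x, y = get_batch_sketch(data, batch_size=4, block_size=8, rng=jax.random.PRNGKey(0))
print(x.shape, y.shape)  # (4, 8) (4, 8)
```

In a training loop, split the key on every step (for example `rng, batch_rng = jax.random.split(rng)`) so successive calls draw different batches.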


load_shakespeare

def load_shakespeare(data_dir="data") -> Tuple[train_data, val_data, encode, decode, vocab_size]

Download Tiny Shakespeare, split it into training and validation sets, and build character-level encode/decode functions.

Returns:

| Return | Type | Description |
|---|---|---|
| train_data | jnp.ndarray | ~1M training tokens |
| val_data | jnp.ndarray | ~111K validation tokens |
| encode | Callable | str → List[int] |
| decode | Callable | List[int] → str |
| vocab_size | int | 65 unique characters |
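The encode/decode pair is a character-level codec. Each distinct character gets an integer id, which is why `vocab_size` equals the number of unique characters. The reported sizes (~1M and ~111K tokens) are consistent with a 90/10 split of the roughly 1.1M-character corpus. The sketch below shows one way to build such a codec and split. It assumes a sorted-character vocabulary and a 90/10 split, and it omits the download step. `build_char_codec` and `split_train_val` are hypothetical helpers, not part of nano_moe.

```python
import jax.numpy as jnp


def build_char_codec(text):
    """Map each distinct character in `text` to an integer id (assumed sorted order)."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}

    def encode(s):
        # str -> List[int]
        return [stoi[c] for c in s]

    def decode(ids):
        # List[int] -> str
        return "".join(itos[i] for i in ids)

    return encode, decode, len(chars)


def split_train_val(text, encode, train_frac=0.9):
    """Encode the whole corpus, then cut it into a leading train slice and a trailing val slice."""
    data = jnp.array(encode(text), dtype=jnp.int32)
    n = int(train_frac * len(data))  # assumption: 90/10 split
    return data[:n], data[n:]


text = "hello world"
encode, decode, vocab_size = build_char_codec(text)
train_data, val_data = split_train_val(text, encode)
print(vocab_size)                        # 8
print(decode(encode("hold")))            # hold
print(train_data.shape, val_data.shape)  # (9,) (2,)
```

`decode` expects a Python list of ints, so convert model outputs first (for example `decode(tokens.tolist())`) when they come back as JAX arrays.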