
Training API

nano_moe.train

create_train_state

def create_train_state(rng, model, config) -> TrainState

Initialize model parameters and optimizer state.

Parameters:

| Parameter | Type | Description |
|---|---|---|
| rng | PRNGKey | Random key for parameter initialization |
| model | NanoMoE | Model instance |
| config | NanoMoEConfig | Configuration |

Returns: flax.training.train_state.TrainState with AdamW optimizer.


train_step

@jax.jit
def train_step(state, x, y, rng) -> Tuple[TrainState, jnp.ndarray, jnp.ndarray, jnp.ndarray]

Single JIT-compiled training step.

Parameters:

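| Parameter | Type | Description |
|---|---|---|
| state | TrainState | Current parameters + optimizer state |
| x | jnp.ndarray (B, T) | Input token batch |
| y | jnp.ndarray (B, T) | Target token batch: `x` shifted left by one position (next-token labels) |
| rng | PRNGKey | Random key for dropout |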

Returns: (updated_state, total_loss, ce_loss, aux_loss)


eval_step

@jax.jit
def eval_step(state, x, y) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]

Single JIT-compiled evaluation step (no dropout, no gradient update).


train_loop

def train_loop(model, config, train_data, val_data, rng) -> TrainState

Full training loop with periodic evaluation and logging.

Parameters:

| Parameter | Type | Description |
|---|---|---|
| model | NanoMoE | Model instance |
| config | NanoMoEConfig | Configuration |
| train_data | jnp.ndarray | Training token sequence |
| val_data | jnp.ndarray | Validation token sequence |
| rng | PRNGKey | Random key |

Returns: Final TrainState with trained parameters.