Model API

nano_moe.model

NanoMoE

class NanoMoE(nn.Module):
    config: NanoMoEConfig

The complete NanoMoE language model: embeddings → transformer blocks → LM head.

__call__(x, deterministic=True)

Forward pass through the model.

Parameters:

| Parameter | Type | Description |
| --- | --- | --- |
| x | jnp.ndarray of shape (B, T) | Integer token indices |
| deterministic | bool | If True, disable dropout (for evaluation) |

Returns: Tuple[logits, aux_loss]

| Return | Shape | Description |
| --- | --- | --- |
| logits | (B, T, vocab_size) | Raw logits for next-token prediction |
| aux_loss | scalar | Mean auxiliary loss across all transformer blocks |

Example:

import jax
import jax.numpy as jnp
from nano_moe import NanoMoEConfig, NanoMoE

config = NanoMoEConfig(vocab_size=65)
model = NanoMoE(config=config)

rng = jax.random.PRNGKey(0)
tokens = jnp.ones((1, 32), dtype=jnp.int32)
params = model.init(rng, tokens)["params"]

logits, aux_loss = model.apply({"params": params}, tokens)
# logits.shape: (1, 32, 65)
# aux_loss: scalar
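
The two return values are typically combined into a single training objective. The following is a minimal sketch of one plausible way to do that, not the library's own training code: `next_token_loss` and its `aux_weight` coefficient are hypothetical names introduced here for illustration, and the stand-in tensors only mimic the documented shapes.

```python
import jax
import jax.numpy as jnp


def next_token_loss(logits, targets, aux_loss, aux_weight=0.01):
    """Cross-entropy over next-token targets plus a weighted auxiliary term.

    logits:     (B, T, vocab_size) raw scores, shaped like the model output
    targets:    (B, T) integer token ids to predict at each position
    aux_loss:   scalar auxiliary loss, shaped like the model output
    aux_weight: hypothetical weighting coefficient (an assumption, not a library default)
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability assigned to each target token.
    target_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    cross_entropy = -jnp.mean(target_log_probs)
    return cross_entropy + aux_weight * aux_loss


# Stand-in tensors with the documented shapes (B=1, T=32, vocab_size=65).
logits = jnp.zeros((1, 32, 65))
targets = jnp.ones((1, 32), dtype=jnp.int32)
aux_loss = jnp.array(0.5)

loss = next_token_loss(logits, targets, aux_loss)
```

In a real training step, `logits` and `aux_loss` would come from `model.apply`, and `targets` would be the input tokens shifted by one position.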

generate(params, rng, prompt, max_new_tokens, temperature=0.8, top_k=40)

Autoregressive text generation with temperature scaling and top-k filtering.

Parameters:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| params | dict | | Model parameters |
| rng | PRNGKey | | Random key for sampling |
| prompt | jnp.ndarray (1, T) | | Starting token sequence |
| max_new_tokens | int | | Number of tokens to generate |
| temperature | float | 0.8 | Sampling temperature (lower = more deterministic) |
| top_k | int | 40 | Number of top tokens to sample from |

Returns: jnp.ndarray of shape (1, T + max_new_tokens) — the prompt concatenated with generated tokens.

Example:

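# Reuses model, params, and rng from the forward-pass example.
prompt = jnp.array([[0]])  # newline character
generated = model.generate(params, rng, prompt, max_new_tokens=200)
# decode is assumed to be a tokenizer helper that maps token ids back to text.
text = decode(generated[0].tolist())
print(text)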
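
To make the `temperature` and `top_k` parameters of `generate` concrete, here is a minimal sketch of a single sampling step. It illustrates the general technique under stated assumptions and is not the library's actual implementation: `sample_next_token` is a hypothetical helper, and it operates on the logits of the last position only.

```python
import jax
import jax.numpy as jnp


def sample_next_token(rng, logits, temperature=0.8, top_k=40):
    """Sample one token id per row from (B, vocab_size) logits.

    Hypothetical helper for illustration; temperature must be positive.
    """
    # Temperature scaling: values below 1 sharpen the distribution.
    scaled = logits / temperature
    # Top-k filtering: find the k-th largest logit per row and mask everything below it.
    k = min(top_k, scaled.shape[-1])
    top_values, _ = jax.lax.top_k(scaled, k)
    threshold = top_values[..., -1:]
    filtered = jnp.where(scaled < threshold, -jnp.inf, scaled)
    # Masked entries have logit -inf, so they can never be drawn.
    return jax.random.categorical(rng, filtered, axis=-1)


rng = jax.random.PRNGKey(0)
# Stand-in logits of shape (1, 65): a higher token id gets a higher logit.
logits = jnp.arange(65, dtype=jnp.float32)[None, :]
next_token = sample_next_token(rng, logits, top_k=5)  # shape (1,)
```

With `top_k=5` and these stand-in logits, only the five highest-scoring token ids (60 through 64) can be drawn. An autoregressive loop would append `next_token[:, None]` to the running sequence and repeat with a fresh random key.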