
Full NanoMoE Model

The NanoMoE class puts everything together: embeddings → stacked transformer blocks → language model head.

Complete Data Flow
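To make the shapes at each stage concrete, here is a minimal pure-JAX sketch of the same pipeline: tokens → embeddings → stacked blocks → LM head. It is not the NanoMoE source. Every name and size is hypothetical, learned position embeddings are assumed, and each transformer block is reduced to a top-1 mixture-of-experts feed-forward (no attention or layer norms). Its auxiliary loss uses the Switch Transformer load-balancing formula, which may differ from the one NanoMoE uses.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes, chosen only for illustration.
VOCAB, BLOCK_SIZE, D_MODEL, N_LAYERS, N_EXPERTS = 50, 16, 32, 2, 4


def _normal(key, shape):
    return 0.02 * jax.random.normal(key, shape)


def init_params(key):
    keys = jax.random.split(key, 3 + 2 * N_LAYERS)
    blocks = [
        {"router": _normal(keys[3 + 2 * i], (D_MODEL, N_EXPERTS)),
         "experts": _normal(keys[4 + 2 * i], (N_EXPERTS, D_MODEL, D_MODEL))}
        for i in range(N_LAYERS)
    ]
    return {
        "tok_emb": _normal(keys[0], (VOCAB, D_MODEL)),
        "pos_emb": _normal(keys[1], (BLOCK_SIZE, D_MODEL)),  # assumed learned positions
        "lm_head": _normal(keys[2], (D_MODEL, VOCAB)),
        "blocks": blocks,
    }


def moe_block(x, p):
    """Stand-in block: top-1 MoE feed-forward only (attention and norms omitted)."""
    probs = jax.nn.softmax(x @ p["router"], axis=-1)              # (B, T, E)
    mask = jax.nn.one_hot(jnp.argmax(probs, axis=-1), N_EXPERTS)  # (B, T, E), top-1 routing
    # Toy shortcut: run every expert on every token, then keep the routed one.
    expert_out = jnp.einsum("btd,edh->bteh", x, p["experts"])     # (B, T, E, D)
    chosen = jnp.einsum("bte,bteh->bth", mask, expert_out)        # (B, T, D)
    gate = jnp.sum(mask * probs, axis=-1, keepdims=True)          # (B, T, 1)
    x = x + gate * jax.nn.gelu(chosen)                            # residual connection
    # Switch-style balance loss: N_EXPERTS * sum(token fraction * mean router prob)
    frac_tokens = jnp.mean(mask, axis=(0, 1))                     # (E,)
    mean_prob = jnp.mean(probs, axis=(0, 1))                      # (E,)
    return x, N_EXPERTS * jnp.sum(frac_tokens * mean_prob)


def forward(params, tokens):
    """tokens (B, T) -> logits (B, T, VOCAB) and the mean aux loss over layers."""
    T = tokens.shape[1]
    x = params["tok_emb"][tokens] + params["pos_emb"][:T]         # (B, T, D)
    aux_losses = []
    for block in params["blocks"]:
        x, aux = moe_block(x, block)
        aux_losses.append(aux)
    logits = x @ params["lm_head"]                                # (B, T, VOCAB)
    return logits, jnp.mean(jnp.stack(aux_losses))


params = init_params(jax.random.PRNGKey(0))
tokens = jnp.zeros((2, 8), dtype=jnp.int32)
logits, aux_loss = forward(params, tokens)
print(logits.shape, aux_loss.shape)  # (2, 8, 50) ()
```

Averaging the per-layer losses, rather than summing them, keeps the auxiliary term on the same scale regardless of how many blocks are stacked.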

Two Outputs

The model's __call__ returns two things:

```python
logits, aux_loss = model.apply({"params": params}, input_tokens)
```

| Output | Shape | Description |
|---|---|---|
| `logits` | `(batch, seq_len, vocab_size)` | Unnormalized next-token scores at every position |
| `aux_loss` | scalar | Mean auxiliary load-balancing loss across all layers |
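During training, the two outputs are usually combined into a single objective. A minimal sketch, assuming next-token cross-entropy on `logits` plus `aux_loss` scaled by a hypothetical coefficient `AUX_LOSS_COEF` (the weighting NanoMoE actually uses is not shown in this section):

```python
import jax
import jax.numpy as jnp

AUX_LOSS_COEF = 0.01  # hypothetical weight; treat it as a training hyperparameter


def training_loss(logits, targets, aux_loss, aux_coef=AUX_LOSS_COEF):
    """Mean next-token cross-entropy plus the weighted load-balancing loss."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)                          # (B, T, V)
    target_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -jnp.mean(target_log_probs) + aux_coef * aux_loss


# Targets are the inputs shifted left by one position.
tokens = jnp.array([[1, 2, 3, 4, 5]])
inputs, targets = tokens[:, :-1], tokens[:, 1:]  # `inputs` would go to model.apply(...)
logits = jnp.zeros((1, 4, 10))  # stand-in output: uniform over a 10-token vocabulary
loss = training_loss(logits, targets, aux_loss=jnp.array(1.0))
print(float(loss))  # log(10) + 0.01, about 2.3126
```

Keeping the coefficient small lets the balance term nudge the router toward even expert usage without overwhelming the language-modeling loss.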

Autoregressive Generation

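After training, the model generates text autoregressively, one token at a time: each sampled token is appended to the sequence and fed back in as input for the next step.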

The generate() method:

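```python
import jax
import jax.numpy as jnp


# Method of the NanoMoE class. `block_size` is the model's context length,
# defined elsewhere in the model configuration.
def generate(self, params, rng, prompt, max_new_tokens=500, temperature=0.8, top_k=40):
    tokens = prompt  # (batch, prompt_len) integer token IDs
    for _ in range(max_new_tokens):
        # Crop to the context window: only the last block_size tokens are used
        context = tokens[:, -block_size:]

        # Get predictions for every position (no KV cache, so the whole context is re-run)
        logits, _ = self.apply({"params": params}, context)

        # Keep only the last position and apply temperature scaling
        logits = logits[:, -1, :] / temperature
        # Top-k filtering: mask out everything below the k-th largest logit
        top_vals, _ = jax.lax.top_k(logits, k=top_k)
        logits = jnp.where(logits < top_vals[:, -1:], -1e9, logits)

        # Split off a fresh key each step; reusing one key would apply
        # identical random noise to every draw
        rng, sample_rng = jax.random.split(rng)
        next_token = jax.random.categorical(sample_rng, logits, axis=-1)

        # Append the sampled token; it becomes input on the next iteration
        tokens = jnp.concatenate([tokens, next_token[:, None]], axis=1)
    return tokens
```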
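The sampling step inside the loop can be exercised on its own, without a trained model. The sketch below pulls it out into a hypothetical helper, `sample_next_token`, and adds a guard so `top_k` never exceeds the vocabulary size (that guard is an addition, not part of the method above). With `top_k=3`, every draw should come from the three highest-scoring tokens.

```python
import jax
import jax.numpy as jnp


def sample_next_token(rng, last_logits, temperature=0.8, top_k=40):
    """One sampling step from generate(): temperature, top-k mask, categorical draw."""
    logits = last_logits / temperature                    # (B, V)
    k = min(top_k, logits.shape[-1])                      # guard: k cannot exceed vocab size
    top_vals, _ = jax.lax.top_k(logits, k)                # sorted, largest first
    logits = jnp.where(logits < top_vals[:, -1:], -1e9, logits)
    return jax.random.categorical(rng, logits, axis=-1)   # (B,)


logits = jnp.arange(8.0)[None, :]  # toy scores: token 7 is the most likely
keys = jax.random.split(jax.random.PRNGKey(0), 100)
samples = jnp.stack([sample_next_token(k, logits, top_k=3) for k in keys])
print(sorted(set(samples.ravel().tolist())))  # only tokens 5, 6 and 7 can appear
```

Because each draw gets its own key from `jax.random.split`, the 100 samples are independent. Passing the same key every time would return the same token on every call.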

Generation Parameters

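| Parameter | Default | Effect |
|---|---|---|
| `temperature` | 0.8 | Divides the logits before sampling; lower values sharpen the distribution (more deterministic), higher values flatten it (more varied, "creative" output) |
| `top_k` | 40 | Samples only from the `top_k` most likely tokens; all others are masked out |
| `max_new_tokens` | 500 | Number of tokens appended to the prompt (the loop has no early stop, so exactly this many are generated) |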
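To see the temperature effect numerically, the sketch below divides a made-up logit vector by several temperatures and reports the probability softmax assigns to the top token. That probability shrinks as the temperature rises, which is what "more deterministic" versus "more creative" means in practice.

```python
import jax
import jax.numpy as jnp


def top_prob(logits, temperature):
    """Probability of the highest-scoring token after temperature scaling."""
    probs = jax.nn.softmax(logits / temperature)
    return float(probs[jnp.argmax(logits)])


logits = jnp.array([2.0, 1.0, 0.5, -1.0])  # made-up scores for a 4-token vocabulary
for t in (0.5, 0.8, 1.0, 2.0):
    print(f"temperature={t}: top-token probability={top_prob(logits, t):.3f}")
```

Dividing by a temperature below 1 stretches the gaps between logits, so the leader dominates; a temperature above 1 compresses them toward a uniform distribution.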