
Building a Nano MoE from Scratch in JAX


A beginner-friendly deep-dive into how Mixture-of-Experts works, why it matters, and how to build one in pure JAX/Flax.

What is Mixture of Experts?

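Imagine you have a team of specialists. Instead of asking every specialist to look at every problem, you have a manager who looks at each problem and says, "Expert #2 and Expert #4, you're the best fit for this, so handle it." In an MoE layer, the specialists are expert feed-forward networks and the manager is a small learned router that picks a few experts for each token.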

Why MoE Matters

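| Model           | Total Params | Active Params | Experts |
|-----------------|--------------|---------------|---------|
| Mixtral 8x7B    | 46.7B        | 12.9B         | 8       |
| GPT-4 (rumored) | ~1.8T        | ~280B         | 16      |
| DeepSeek-V3     | 671B         | 37B           | 256     |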

More capacity at a fraction of the per-token compute: each token only runs through the active parameters, not the full model. That's the magic of sparse activation.
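To make the capacity-versus-compute point concrete, here is a small arithmetic sketch that computes what share of each model's parameters is active per token. The numbers are copied from the table above (the GPT-4 row is a rumor, as noted there); the helper name `active_fraction` is made up for this illustration.

```python
# Total and active parameter counts in billions, copied from the table above.
# The GPT-4 figures are rumored, not confirmed.
models = {
    "Mixtral 8x7B": (46.7, 12.9),
    "GPT-4 (rumored)": (1800.0, 280.0),
    "DeepSeek-V3": (671.0, 37.0),
}


def active_fraction(total_b, active_b):
    """Share of a model's parameters that run for a single token."""
    return active_b / total_b


for name, (total_b, active_b) in models.items():
    share = active_fraction(total_b, active_b)
    print(f"{name}: {share:.1%} of parameters active per token")
```

The smaller this share, the more total capacity the model carries relative to the work done for each token.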

Our Architecture

NanoMoE is a GPT-style transformer in which the FFN in each block is replaced with a MoE layer.

Default Config: 2.4M parameters

| Parameter | Value |
|-----------|-------|
| d_model   | 128   |
| n_layers  | 4     |
| n_heads   | 4     |
| n_experts | 4     |
| top_k     | 2     |
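As a rough illustration, the defaults above could be captured in a small config object like the one below. `SketchConfig` is a hypothetical stand-in written for this post, not the package's actual `NanoMoEConfig`, which may have more fields (vocabulary size, sequence length, and so on) and different names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical config mirroring the default values in the table above."""

    d_model: int = 128
    n_layers: int = 4
    n_heads: int = 4
    n_experts: int = 4
    top_k: int = 2

    def __post_init__(self):
        # Attention heads must split the model width evenly.
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        # A token cannot be routed to more experts than exist.
        if not 1 <= self.top_k <= self.n_experts:
            raise ValueError("top_k must be between 1 and n_experts")

    @property
    def head_dim(self):
        """Width of each attention head."""
        return self.d_model // self.n_heads


cfg = SketchConfig()
print(f"head_dim={cfg.head_dim}, experts used per token={cfg.top_k}/{cfg.n_experts}")
```

With these defaults, each token runs through half of the experts in every MoE layer.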

Top-K Routing — How It Works

  1. Project token to n_experts logits
  2. Select top-K experts
  3. Softmax over selected values
  4. Weighted combination of expert outputs
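The four steps above can be sketched in a few lines of JAX. This is a minimal illustration, not the package's implementation: the names `top_k_route`, `combine`, and `w_router` are made up here, the experts are stand-in linear maps, and every expert is evaluated on every token for simplicity (a real MoE layer would only run the selected experts).

```python
import jax
import jax.numpy as jnp


def top_k_route(x, w_router, top_k=2):
    """Pick top_k experts per token and compute their mixing weights.

    x:        (n_tokens, d_model) token activations
    w_router: (d_model, n_experts) router weights (hypothetical name)
    Returns (weights, top_idx), both shaped (n_tokens, top_k).
    """
    logits = x @ w_router                              # 1. project to n_experts logits
    top_vals, top_idx = jax.lax.top_k(logits, top_k)   # 2. select top-K experts
    weights = jax.nn.softmax(top_vals, axis=-1)        # 3. softmax over selected values
    return weights, top_idx


def combine(expert_outputs, weights, top_idx):
    """4. Weighted combination of the selected experts' outputs.

    expert_outputs: (n_experts, n_tokens, d_model), computed densely here.
    """
    n_experts = expert_outputs.shape[0]
    one_hot = jax.nn.one_hot(top_idx, n_experts)            # (n_tokens, top_k, n_experts)
    gates = jnp.sum(weights[..., None] * one_hot, axis=1)   # (n_tokens, n_experts), zero for unselected experts
    return jnp.einsum("te,etd->td", gates, expert_outputs)


# Toy usage with the default sizes: d_model=128, n_experts=4, top_k=2.
k_x, k_w, k_e = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (6, 128))
w_router = jax.random.normal(k_w, (128, 4)) * 0.02
w_experts = jax.random.normal(k_e, (4, 128, 128)) * 0.02    # stand-in experts: one linear map each

weights, top_idx = top_k_route(x, w_router, top_k=2)
expert_outputs = jnp.einsum("td,edf->etf", x, w_experts)
y = combine(expert_outputs, weights, top_idx)
print(weights.shape, top_idx.shape, y.shape)
```

Because the softmax runs only over the selected logits, the two mixing weights for each token always sum to 1, regardless of how the unselected experts scored.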

Load Balancing

Without load balancing, an MoE can collapse: the router sends nearly all tokens to one expert.

Solution: Auxiliary loss from the Switch Transformer paper:

total_loss = CE_loss + 0.01 × aux_loss
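A minimal sketch of an auxiliary loss in the Switch Transformer style is shown below: for each expert, multiply the fraction of routing slots it receives by its mean router probability, sum over experts, and scale by the number of experts. The function names are made up for this illustration, and extending the "fraction of tokens" term to top-K slots is an assumption; the package may compute or aggregate its aux loss differently.

```python
import jax
import jax.numpy as jnp


def load_balance_loss(router_logits, top_idx):
    """Switch-style load-balancing loss for one MoE layer (illustrative sketch).

    router_logits: (n_tokens, n_experts) raw router scores
    top_idx:       (n_tokens, top_k) indices of the selected experts
    """
    n_experts = router_logits.shape[-1]
    probs = jax.nn.softmax(router_logits, axis=-1)
    # f_i: fraction of routing slots assigned to expert i (sums to 1).
    f = jnp.mean(jax.nn.one_hot(top_idx, n_experts), axis=(0, 1))
    # P_i: mean router probability given to expert i.
    p = jnp.mean(probs, axis=0)
    return n_experts * jnp.sum(f * p)


def total_loss(ce_loss, aux_loss, aux_weight=0.01):
    """Cross-entropy plus the weighted auxiliary term, as in the formula above."""
    return ce_loss + aux_weight * aux_loss


n_tokens, n_experts = 8, 4

# Balanced routing: uniform router probabilities, slots spread evenly.
balanced_logits = jnp.zeros((n_tokens, n_experts))
balanced_idx = jnp.array([[0, 1], [2, 3]] * 4)          # shape (8, 2)

# Collapsed routing: every token strongly prefers expert 0.
collapsed_logits = jnp.zeros((n_tokens, n_experts)).at[:, 0].set(10.0)
collapsed_idx = jnp.zeros((n_tokens, 1), dtype=jnp.int32)

aux_balanced = load_balance_loss(balanced_logits, balanced_idx)
aux_collapsed = load_balance_loss(collapsed_logits, collapsed_idx)
print(f"balanced: {float(aux_balanced):.3f}, collapsed: {float(aux_collapsed):.3f}")
```

Under this formulation, perfectly balanced routing gives 1.0 for a single layer, and the value grows as routing concentrates on fewer experts. The absolute scale depends on how an implementation normalizes the term and aggregates it across layers, so values from this sketch are not directly comparable with numbers reported by other code.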

Training Results

We trained on Tiny Shakespeare for 5,000 steps:

| Step | Train Loss | Val Loss |
|------|------------|----------|
| 1    | 4.23       | 4.09     |
| 1000 | 2.14       | 2.08     |
| 3000 | 1.69       | 1.77     |
| 5000 | 1.54       | 1.66     |

✅ Train loss dropped 64% (4.23 → 1.54)
✅ Aux loss stable at ~4.0 (experts balanced)
✅ Minimal overfitting (train/val gap = 0.12 at step 5000)

Try It Yourself

pip install nano-moe-jax
python -c "
from nano_moe import NanoMoEConfig, NanoMoE
import jax, jax.numpy as jnp

# Dummy batch: one sequence of 32 token ids
tokens = jnp.ones((1, 32), dtype=jnp.int32)

model = NanoMoE(config=NanoMoEConfig())
params = model.init(jax.random.PRNGKey(0), tokens)['params']
# The forward pass returns the logits and the auxiliary loss
logits, aux = model.apply({'params': params}, tokens)
print(f'Output: {logits.shape}, Aux loss: {aux:.2f}')
"

📖 Full documentation: carrycooldude.github.io/Nano-MoE-JAX

Star the repo: github.com/carrycooldude/Nano-MoE-JAX