Building a Nano MoE from Scratch in JAX
A beginner-friendly deep-dive into how Mixture-of-Experts works, why it matters, and how to build one in pure JAX/Flax.
What is Mixture of Experts?
Imagine you have a team of specialists. Instead of asking every specialist to look at every problem, you have a manager who looks at each problem and says "Expert #2 and Expert #4, you're the best fit for this — handle it."
Why MoE Matters
| Model | Total Params | Active Params | Experts |
|---|---|---|---|
| Mixtral 8x7B | 46.7B | 12.9B | 8 |
| GPT-4 (rumored) | ~1.8T | ~280B | 16 |
| DeepSeek-V3 | 671B | 37B | 256 |
More capacity. Same compute. That's the magic of sparse activation.
Our Architecture
NanoMoE is a GPT-style transformer where the FFN in each block is replaced with a MoE layer:
Default Config: 2.4M parameters
| Parameter | Value |
|---|---|
d_model | 128 |
n_layers | 4 |
n_heads | 4 |
n_experts | 4 |
top_k | 2 |
Top-K Routing — How It Works
- Project token to
n_expertslogits - Select top-K experts
- Softmax over selected values
- Weighted combination of expert outputs
Load Balancing
Without it, MoE collapses — all tokens go to one expert:
Solution: Auxiliary loss from the Switch Transformer paper:
total_loss = CE_loss + 0.01 × aux_loss
Training Results
We trained on Tiny Shakespeare for 5,000 steps:
| Step | Train Loss | Val Loss |
|---|---|---|
| 1 | 4.23 | 4.09 |
| 1000 | 2.14 | 2.08 |
| 3000 | 1.69 | 1.77 |
| 5000 | 1.54 | 1.66 |
✅ Loss dropped 64% (4.23 → 1.54) ✅ Aux loss stable at ~4.0 (experts balanced) ✅ Minimal overfitting (gap = 0.12)
Try It Yourself
pip install nano-moe-jax
python -c "
from nano_moe import NanoMoEConfig, NanoMoE
import jax, jax.numpy as jnp
model = NanoMoE(config=NanoMoEConfig())
params = model.init(jax.random.PRNGKey(0), jnp.ones((1,32), dtype=jnp.int32))['params']
logits, aux = model.apply({'params': params}, jnp.ones((1,32), dtype=jnp.int32))
print(f'Output: {logits.shape}, Aux loss: {aux:.2f}')
"
📖 Full documentation: carrycooldude.github.io/Nano-MoE-JAX
⭐ Star the repo: github.com/carrycooldude/Nano-MoE-JAX
