NanoMoE — Mixture-of-Experts in JAX

A lightweight, educational Mixture-of-Experts (MoE) GPT-style language model built from scratch in JAX / Flax.

Inspired by nanoGPT, NanoMoE replaces the standard feed-forward network in each transformer block with a sparse MoE layer — a set of parallel expert networks with a learned router that activates only the top-K experts per token.

What is Mixture of Experts?

Imagine you have a team of specialists. Instead of asking every specialist to look at every problem, you have a manager who looks at each problem and says "Expert #2 and Expert #4, you're the best fit for this — handle it."

That's exactly how Mixture of Experts works in neural networks:

  • Experts = independent neural networks (typically small feed-forward networks)
  • Router/Gate = a learned module that decides which experts to use for each input
  • Sparse activation = only a subset of experts run for each token
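
The three pieces above can be sketched in a few lines of JAX. This is a minimal illustration, not NanoMoE's actual implementation: the function names (`init_moe_params`, `moe_forward`), the GELU activation, the 0.02 init scale, and the choice to renormalise the gate weights over the selected experts are all assumptions made for the example. For readability it evaluates every expert and then zeroes out the unselected ones; a real sparse layer would dispatch each token only to its chosen experts so that the skipped experts cost nothing.

```python
import jax
import jax.numpy as jnp


def init_moe_params(key, d_model=128, d_ff=512, n_experts=4):
    """Random weights for a router and n_experts two-layer FFN experts (hypothetical layout)."""
    k_router, k_in, k_out = jax.random.split(key, 3)
    scale = 0.02  # assumed init scale, chosen only for the example
    return {
        "router": scale * jax.random.normal(k_router, (d_model, n_experts)),
        "w_in": scale * jax.random.normal(k_in, (n_experts, d_model, d_ff)),
        "w_out": scale * jax.random.normal(k_out, (n_experts, d_ff, d_model)),
    }


def moe_forward(params, x, top_k=2):
    """x: (n_tokens, d_model) -> (output of the same shape, chosen expert ids)."""
    # Router: one score per expert for every token.
    logits = x @ params["router"]                        # (n_tokens, n_experts)
    top_logits, top_idx = jax.lax.top_k(logits, top_k)   # (n_tokens, top_k)
    gates = jax.nn.softmax(top_logits, axis=-1)          # weights over the chosen experts only

    # Experts: evaluated densely here for clarity (NOT compute-sparse).
    hidden = jax.nn.gelu(jnp.einsum("td,edf->etf", x, params["w_in"]))
    expert_out = jnp.einsum("etf,efd->etd", hidden, params["w_out"])

    # Sparse combination: unselected experts get weight exactly 0.
    n_experts = logits.shape[-1]
    combine = jnp.sum(jax.nn.one_hot(top_idx, n_experts) * gates[..., None], axis=1)
    out = jnp.einsum("te,etd->td", combine, expert_out)  # (n_tokens, d_model)
    return out, top_idx


params = init_moe_params(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (5, 128))   # 5 tokens, d_model = 128
y, chosen = moe_forward(params, x)
print(y.shape, chosen.shape)  # (5, 128) (5, 2)
```

Each row of `chosen` lists the two experts the router picked for that token, which is the "manager" decision from the analogy above.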

Key Insight

In a standard transformer, every parameter is used for every token. In MoE, the model has many more total parameters, but only activates a fraction of them per token. This means more capacity without a proportional increase in compute cost.

Why MoE Matters

MoE is the architecture behind some of the most powerful AI models today:

Model               Total Params   Active Params   Experts
Mixtral 8x7B        46.7B          12.9B           8
GPT-4 (rumored)     ~1.8T          ~280B           16
Switch Transformer  1.6T           ~100B           2048
DeepSeek-V3         671B           37B             256

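In this table, total parameters are roughly 3.5× to 18× the active parameters: you get a much larger model whose per-token compute is close to that of a smaller dense one, although all of the parameters still have to fit in memory.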

Default Configuration

Parameter    Value   Description
d_model      128     Hidden dimension
n_layers     4       Transformer blocks
n_heads      4       Attention heads
d_ff         512     Expert FFN inner dimension
n_experts    4       Experts per MoE layer
top_k        2       Active experts per token
block_size   128     Context window
vocab_size   65      Character vocabulary

Total parameters: ~2.4M — small enough to train on a laptop CPU!
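
As a sanity check on that figure, the weight count can be estimated directly from the default configuration. The sketch below is a back-of-the-envelope calculation, not code from the repository: `count_params` is a hypothetical helper, and it assumes learned positional embeddings, an untied output projection, and ignores biases and LayerNorm parameters. It also reports how many of those weights a single token actually touches when only `top_k` experts run per layer.

```python
def count_params(d_model=128, n_layers=4, d_ff=512, n_experts=4, top_k=2,
                 block_size=128, vocab_size=65):
    """Rough weight-matrix count; biases and LayerNorm are ignored (assumption)."""
    attention = 4 * d_model * d_model       # Q, K, V and output projections
    expert = 2 * d_model * d_ff             # up- and down-projection of one expert
    router = d_model * n_experts            # one score per expert
    embeddings = vocab_size * d_model + block_size * d_model  # token + learned position (assumed)
    lm_head = d_model * vocab_size          # assumes the output layer is not tied to the embedding

    shared = embeddings + lm_head
    total = n_layers * (attention + n_experts * expert + router) + shared
    active = n_layers * (attention + top_k * expert + router) + shared
    return total, active


total, active = count_params()
print(f"total ≈ {total / 1e6:.2f}M, active per token ≈ {active / 1e6:.2f}M")
# total ≈ 2.39M, active per token ≈ 1.35M
```

Under these assumptions the estimate lands on the ~2.4M quoted above, with a little over half of the weights used for any given token.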

Install

pip install nano-moe-jax

Or from source:

git clone https://github.com/carrycooldude/MoE-JAX.git
cd MoE-JAX && pip install -e .