
🧠 NanoMoE

A lightweight Mixture-of-Experts language model in JAX/Flax

🔀 Sparse MoE Routing

Top-K gating with softmax weighting activates only 2 of 4 experts per token, so total capacity grows with the number of experts while per-token compute stays at two experts' worth.
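To make the routing step concrete, here is a minimal sketch of top-2 gating in JAX. It is not taken from the NanoMoE source: the function name `top_k_gate` and the choice to apply softmax over only the selected logits (rather than over all experts before selection) are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def top_k_gate(router_logits, k=2):
    """Hypothetical helper: pick the k highest-scoring experts per token.

    router_logits has shape (num_tokens, num_experts).
    Returns expert indices and mixing weights, each of shape (num_tokens, k).
    """
    # jax.lax.top_k returns (values, indices) sorted in descending order.
    top_logits, top_idx = jax.lax.top_k(router_logits, k)
    # Assumption: softmax is taken over the k selected logits only,
    # so the weights for each token sum to 1.
    weights = jax.nn.softmax(top_logits, axis=-1)
    return top_idx, weights

# Two tokens, four experts.
logits = jnp.array([[2.0, 0.5, 1.0, -1.0],
                    [0.0, 3.0, -2.0, 1.0]])
idx, w = top_k_gate(logits, k=2)
print(idx)  # selected expert ids per token
print(w)    # mixing weights per token
```

Each token's output would then be the weighted sum of its two selected experts' outputs; the other two experts are never evaluated for that token.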

⚡ Pure JAX/Flax

Built from scratch with jax.vmap for parallel experts, jax.jit for XLA compilation, and jax.grad for automatic differentiation.
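The three transforms named above compose as in the following sketch. The expert architecture (a two-layer GELU MLP), the dimensions, and the names `expert_ffn`, `all_experts`, and `loss_fn` are assumptions chosen for illustration, not NanoMoE's actual code. For simplicity this version evaluates every expert on every token; a sparse layer would combine it with a gating step.

```python
import jax
import jax.numpy as jnp

def expert_ffn(params, x):
    """One hypothetical expert: a two-layer MLP over x of shape (tokens, d_model)."""
    w1, w2 = params
    return jax.nn.gelu(x @ w1) @ w2

# jax.vmap maps over the leading (expert) axis of the stacked parameters,
# while every expert receives the same input x (in_axes=None).
all_experts = jax.vmap(expert_ffn, in_axes=(0, None))

@jax.jit  # compile the whole computation with XLA
def loss_fn(params, x):
    out = all_experts(params, x)  # (num_experts, tokens, d_model)
    return jnp.mean(out ** 2)     # toy loss, only to have something to differentiate

num_experts, d_model, d_hidden, tokens = 4, 8, 16, 3
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    jax.random.normal(k1, (num_experts, d_model, d_hidden)) * 0.1,
    jax.random.normal(k2, (num_experts, d_hidden, d_model)) * 0.1,
)
x = jax.random.normal(k3, (tokens, d_model))

expert_out = all_experts(params, x)
# jax.grad differentiates with respect to the first argument (params).
grads = jax.grad(loss_fn)(params, x)
print(expert_out.shape)
print(jax.tree_util.tree_map(lambda g: g.shape, grads))
```

Stacking expert weights along a leading axis is what lets a single `vmap` call replace a Python loop over experts, and the gradients come back in the same stacked layout as the parameters.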

📖 Educational & Hackable

2.4M parameters, ~500 lines of code, fully tested. Perfect for learning MoE internals and experimenting with new ideas.

Install in One Command

pip install nano-moe-jax