🔀 Sparse MoE Routing
Top-K gating with softmax weighting activates only 2 of 4 experts per token, so capacity grows with the number of experts while per-token compute depends only on K.
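As an illustration of the routing step, here is a minimal sketch of top-K gating in JAX. It assumes a linear router (`w_gate`) and applies the softmax over the K selected logits only, so each token's gate weights sum to 1. The function name `top_k_gating` and the tensor sizes are hypothetical and are not taken from the nano-moe-jax source.

```python
import jax
import jax.numpy as jnp

def top_k_gating(x, w_gate, k=2):
    """Route each token to its k highest-scoring experts (illustrative sketch).

    Assumption: softmax is taken over the k selected logits only.
    x: [tokens, d_model], w_gate: [d_model, n_experts].
    """
    logits = x @ w_gate                            # [tokens, n_experts] router scores
    top_vals, top_idx = jax.lax.top_k(logits, k)   # both [tokens, k], sorted descending
    weights = jax.nn.softmax(top_vals, axis=-1)    # renormalize over the k picks
    return weights, top_idx

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 16))        # 8 tokens, d_model = 16 (hypothetical sizes)
w_gate = jax.random.normal(key_w, (16, 4))   # 4 experts, matching the text
weights, idx = top_k_gating(x, w_gate, k=2)
print(weights.shape, idx.shape)  # (8, 2) (8, 2)
```

Each row of `idx` names the two experts a token is sent to, and the matching row of `weights` says how much each expert's output contributes to that token.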
⚡ Pure JAX/Flax
Built from scratch with jax.vmap for parallel experts, jax.jit for XLA compilation, and jax.grad for automatic differentiation.
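The three transforms fit together roughly as in the sketch below. Expert weights are stacked along a leading axis so that `jax.vmap` can evaluate all experts in parallel, `jax.grad` differentiates a loss through the router and experts, and `jax.jit` compiles the gradient function with XLA. This is a hypothetical layout (the names `expert_ffn`, `moe_layer`, and `init_params`, the GELU MLP experts, and all sizes are assumptions), not the project's actual code. For clarity it also evaluates every expert densely and then keeps each token's top-K outputs. That makes it easy to read, but it does not by itself deliver the compute savings of sparse dispatch.

```python
import jax
import jax.numpy as jnp

N_EXPERTS, D_MODEL, D_HIDDEN, K = 4, 16, 32, 2  # hypothetical sizes, not the project's config

def init_params(key):
    k_gate, k_w1, k_w2 = jax.random.split(key, 3)
    return {
        "w_gate": jax.random.normal(k_gate, (D_MODEL, N_EXPERTS)) * 0.1,
        # Expert weights are stacked along a leading expert axis so vmap can map over it.
        "w1": jax.random.normal(k_w1, (N_EXPERTS, D_MODEL, D_HIDDEN)) * 0.1,
        "w2": jax.random.normal(k_w2, (N_EXPERTS, D_HIDDEN, D_MODEL)) * 0.1,
    }

def expert_ffn(w1, w2, x):
    """One expert: a two-layer MLP applied to tokens x: [tokens, d_model]."""
    return jax.nn.gelu(x @ w1) @ w2

# Map over the expert axis of w1/w2; x is shared by all experts (in_axes=None).
all_experts = jax.vmap(expert_ffn, in_axes=(0, 0, None))  # -> [n_experts, tokens, d_model]

def moe_layer(params, x):
    logits = x @ params["w_gate"]
    top_vals, top_idx = jax.lax.top_k(logits, K)             # [tokens, K]
    weights = jax.nn.softmax(top_vals, axis=-1)              # [tokens, K]
    expert_out = all_experts(params["w1"], params["w2"], x)  # [n_experts, tokens, d_model]
    # Dense sketch: every expert runs, then each token gathers only its top-K outputs.
    tokens = jnp.arange(x.shape[0])[:, None]                 # [tokens, 1], broadcasts to [tokens, K]
    picked = expert_out[top_idx, tokens]                     # [tokens, K, d_model]
    return jnp.sum(weights[..., None] * picked, axis=1)      # weighted sum -> [tokens, d_model]

def loss_fn(params, x, y):
    return jnp.mean((moe_layer(params, x) - y) ** 2)

# grad differentiates w.r.t. params (argument 0); jit compiles the result with XLA.
grad_fn = jax.jit(jax.grad(loss_fn))

params = init_params(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (8, D_MODEL))
y = jnp.zeros_like(x)
grads = grad_fn(params, x, y)
print(jax.tree_util.tree_map(jnp.shape, grads))
```

The gradient reaches `w_gate` through the softmax over the selected logits, so the router is trained jointly with the experts. The integer indices from `top_k` are not differentiated.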
📖 Educational & Hackable
2.4M parameters, ~500 lines of code, fully tested. Perfect for learning MoE internals and experimenting with new ideas.
Install in One Command
pip install nano-moe-jax