<?xml version="1.0" encoding="utf-8"?><?xml-stylesheet type="text/xsl" href="atom.xsl"?>
<feed xmlns="http://www.w3.org/2005/Atom">
    <id>https://carrycooldude.github.io/Nano-MoE-JAX/blog</id>
    <title>NanoMoE Blog</title>
    <updated>2026-02-24T00:00:00.000Z</updated>
    <generator>https://github.com/jpmonette/feed</generator>
    <link rel="alternate" href="https://carrycooldude.github.io/Nano-MoE-JAX/blog"/>
    <subtitle>NanoMoE Blog</subtitle>
    <icon>https://carrycooldude.github.io/Nano-MoE-JAX/img/favicon.ico</icon>
    <entry>
        <title type="html"><![CDATA[Building a Nano MoE from Scratch in JAX]]></title>
        <id>https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive</id>
        <link href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive"/>
        <updated>2026-02-24T00:00:00.000Z</updated>
        <summary type="html"><![CDATA[A beginner-friendly deep-dive into how Mixture-of-Experts works, why it matters, and how to build one in pure JAX/Flax.]]></summary>
        <content type="html"><![CDATA[<p>A beginner-friendly deep-dive into how Mixture-of-Experts works, why it matters, and how to build one in pure JAX/Flax.</p>
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="what-is-mixture-of-experts">What is Mixture of Experts?<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#what-is-mixture-of-experts" class="hash-link" aria-label="Direct link to What is Mixture of Experts?" title="Direct link to What is Mixture of Experts?" translate="no">​</a></h2>
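<p>Imagine you have a team of specialists. Instead of asking <em>every</em> specialist to look at every problem, you have a <strong>manager</strong> who looks at each problem and says <em>"Expert #2 and Expert #4, you're the best fit for this one. Handle it."</em></p>
<p>A Mixture-of-Experts (MoE) layer works the same way. The specialists are several small feed-forward networks called <strong>experts</strong>, and the manager is a learned <strong>router</strong> (also called a gating network) that picks a few experts for each token and mixes their outputs.</p>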
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="why-moe-matters">Why MoE Matters<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#why-moe-matters" class="hash-link" aria-label="Direct link to Why MoE Matters" title="Direct link to Why MoE Matters" translate="no">​</a></h2>
<table><thead><tr><th>Model</th><th>Total Params</th><th>Active Params (per token)</th><th>Experts per MoE layer</th></tr></thead><tbody><tr><td>Mixtral 8x7B</td><td>46.7B</td><td>12.9B</td><td>8</td></tr><tr><td>GPT-4 (rumored)</td><td>~1.8T</td><td>~280B</td><td>16</td></tr><tr><td>DeepSeek-V3</td><td>671B</td><td>37B</td><td>256 routed + 1 shared</td></tr></tbody></table>
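<p><strong>More capacity, same compute.</strong> Total parameters set how much the model can store; active parameters set what each token costs to process. Mixtral stores 46.7B parameters but uses only 12.9B per token, so a forward pass costs roughly what a 13B dense model would. That's the payoff of sparse activation.</p>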
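<p>To put numbers on "active", this small sketch divides active by total parameters for the two rows with published figures. The GPT-4 row is a rumor, so it is left out, and <code>active_fraction</code> is just an illustrative helper name.</p>

```python
# Active vs. total parameters in billions, copied from the table above.
models = {
    "Mixtral 8x7B": (46.7, 12.9),
    "DeepSeek-V3": (671.0, 37.0),
}


def active_fraction(total_b, active_b):
    """Share of the stored parameters that a single token actually uses."""
    return active_b / total_b


for name, (total_b, active_b) in models.items():
    print(f"{name}: {active_fraction(total_b, active_b):.1%} of parameters active per token")
```

<p>This prints about 27.6% for Mixtral and 5.5% for DeepSeek-V3: the more experts a model has per layer, the smaller the slice each token touches.</p>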
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="our-architecture">Our Architecture<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#our-architecture" class="hash-link" aria-label="Direct link to Our Architecture" title="Direct link to Our Architecture" translate="no">​</a></h2>
<p>NanoMoE is a GPT-style, decoder-only transformer in which the feed-forward network (FFN) of every block is replaced with an MoE layer: a router plus a set of expert FFNs.</p>
<h3 class="anchor anchorTargetStickyNavbar_Vzrq" id="default-config-24m-parameters">Default Config: 2.4M parameters<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#default-config-24m-parameters" class="hash-link" aria-label="Direct link to Default Config: 2.4M parameters" title="Direct link to Default Config: 2.4M parameters" translate="no">​</a></h3>
<table><thead><tr><th>Parameter</th><th>Value</th></tr></thead><tbody><tr><td><code>d_model</code></td><td>128</td></tr><tr><td><code>n_layers</code></td><td>4</td></tr><tr><td><code>n_heads</code></td><td>4</td></tr><tr><td><code>n_experts</code></td><td>4</td></tr><tr><td><code>top_k</code></td><td>2</td></tr></tbody></table>
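<p>The 2.4M figure can be roughly reproduced by hand. The sketch below counts parameters for the config above. The FFN width (4 × <code>d_model</code>), the 65-character vocabulary, tied embeddings, and ignoring biases, LayerNorm, and position embeddings are all assumptions for illustration, not values read from the NanoMoE code.</p>

```python
# Rough parameter count for the default config. Everything marked ASSUMED
# is a guess for illustration, not read from the NanoMoE source.
d_model, n_layers, n_experts, top_k = 128, 4, 4, 2
d_ff = 4 * d_model   # ASSUMED: GPT-style 4x FFN expansion
vocab = 65           # ASSUMED: character-level Tiny Shakespeare vocabulary

attn = 4 * d_model * d_model    # Q, K, V and output projections (no biases)
expert = 2 * d_model * d_ff     # up- and down-projection of one expert FFN
router = d_model * n_experts    # linear layer producing the router logits
embed = vocab * d_model         # ASSUMED: token embedding tied with the output head

# Every expert is stored, but only top_k of them run for a given token.
total = n_layers * (attn + n_experts * expert + router) + embed
active = n_layers * (attn + top_k * expert + router) + embed

print(f"total  ~ {total / 1e6:.2f}M parameters")
print(f"active ~ {active / 1e6:.2f}M parameters per token")
```

<p>Under these assumptions the estimate comes to about 2.37M total and 1.32M active per token, close to the 2.4M in the heading; the pieces the sketch ignores would account for some of the remainder. The experts hold most of the parameters, which is why routing each token to only 2 of 4 of them nearly halves the per-token cost.</p>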
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="top-k-routing--how-it-works">Top-K Routing — How It Works<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#top-k-routing--how-it-works" class="hash-link" aria-label="Direct link to Top-K Routing — How It Works" title="Direct link to Top-K Routing — How It Works" translate="no">​</a></h2>
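<p>Routing happens independently for every token, in four steps. With the default config (4 experts, <code>top_k</code> = 2), each token is processed by 2 of the 4 experts in every MoE layer.</p>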
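<ol>
<li class="">Project each token's hidden state to <code>n_experts</code> router logits with a learned linear projection.</li>
<li class="">Select the <code>top_k</code> experts with the largest logits.</li>
<li class="">Apply a softmax over only the selected logits, so each token's gate weights sum to 1.</li>
<li class="">Return the gate-weighted sum of the selected experts' outputs; unselected experts contribute nothing for that token.</li>
</ol>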
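<p>The four steps above can be sketched in a few lines of JAX. This is a minimal, dense illustration under assumed shapes: each expert is a single linear map rather than a full FFN, every expert runs on every token, and <code>top_k_moe</code> is a hypothetical name, not NanoMoE's API.</p>

```python
import jax
import jax.numpy as jnp


def top_k_moe(x, w_router, w_experts, top_k=2):
    """Dense sketch of top-k routing (assumed shapes, not NanoMoE's code).

    x:         (tokens, d_model) token hidden states
    w_router:  (d_model, n_experts) router weights
    w_experts: (n_experts, d_model, d_model) one linear map per expert,
               a simplified stand-in for real expert FFNs
    """
    n_experts = w_router.shape[-1]
    # Step 1: project each token to n_experts router logits.
    logits = x @ w_router                                    # (T, E)
    # Step 2: keep the top_k logits and the indices of those experts.
    top_vals, top_idx = jax.lax.top_k(logits, top_k)         # (T, k), (T, k)
    # Step 3: softmax over the selected logits only, so each token's gates sum to 1.
    gates = jax.nn.softmax(top_vals, axis=-1)                # (T, k)
    # Scatter the gates into a (T, E) matrix; unselected experts get weight 0.
    combine = jnp.sum(gates[..., None] * jax.nn.one_hot(top_idx, n_experts), axis=1)
    # Step 4: weighted combination. For clarity every expert runs on every
    # token here; a sparse implementation would run only the selected ones.
    expert_out = jnp.einsum("td,edf->tef", x, w_experts)     # (T, E, d)
    return jnp.einsum("te,ted->td", combine, expert_out), combine


# Usage with random weights: 5 tokens, d_model = 8, 4 experts.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (5, 8))
w_router = jax.random.normal(k2, (8, 4))
w_experts = jax.random.normal(k3, (4, 8, 8))
y, combine = top_k_moe(x, w_router, w_experts)
print(y.shape)                        # (5, 8)
print((combine > 0).sum(axis=-1))     # expect 2 for every token
```

<p>Running all experts and zeroing out the unselected ones keeps the code simple and the shapes static, which suits <code>jax.jit</code>, but it saves no compute. Sparse MoE implementations get their savings by dispatching each token only to its selected experts.</p>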
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="load-balancing">Load Balancing<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#load-balancing" class="hash-link" aria-label="Direct link to Load Balancing" title="Direct link to Load Balancing" translate="no">​</a></h2>
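<p>Without it, MoE training tends to collapse. Experts that score slightly higher early on get picked more often, so they receive more gradient updates, improve, and get picked even more. Eventually nearly all tokens go to one expert while the others sit idle, wasting the capacity the extra experts were supposed to add.</p>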
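<p><strong>Solution:</strong> an auxiliary load-balancing loss from the Switch Transformer paper. For <em>N</em> experts it is <em>N</em> · Σ<sub>i</sub> <em>f</em><sub>i</sub> · <em>P</em><sub>i</sub>, where <em>f</em><sub>i</sub> is the fraction of tokens dispatched to expert <em>i</em> and <em>P</em><sub>i</sub> is the average router probability assigned to expert <em>i</em>. It is designed to be minimized when tokens are routed uniformly, and it is added to the cross-entropy loss with a small weight:</p>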
<div class="language-text codeBlockContainer_Ckt0 theme-code-block" style="--prism-color:#F8F8F2;--prism-background-color:#282A36"><div class="codeBlockContent_QJqH"><pre tabindex="0" class="prism-code language-text codeBlock_bY9V thin-scrollbar" style="color:#F8F8F2;background-color:#282A36"><code class="codeBlockLines_e6Vv"><span class="token-line" style="color:#F8F8F2"><span class="token plain">total_loss = CE_loss + 0.01 × aux_loss</span><br></span></code></pre></div></div>
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="training-results">Training Results<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#training-results" class="hash-link" aria-label="Direct link to Training Results" title="Direct link to Training Results" translate="no">​</a></h2>
<p>We trained on Tiny Shakespeare for 5,000 steps:</p>
<table><thead><tr><th>Step</th><th>Train Loss</th><th>Val Loss</th></tr></thead><tbody><tr><td>1</td><td>4.23</td><td>4.09</td></tr><tr><td>1000</td><td>2.14</td><td>2.08</td></tr><tr><td>3000</td><td>1.69</td><td>1.77</td></tr><tr><td>5000</td><td><strong>1.54</strong></td><td><strong>1.66</strong></td></tr></tbody></table>
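<ul>
<li class="">✅ Training loss dropped <strong>64%</strong> (4.23 → 1.54)</li>
<li class="">✅ Aux loss stayed stable at ~4.0 (experts balanced)</li>
<li class="">✅ Minimal overfitting: final validation loss is only 0.12 above training loss (1.66 vs. 1.54)</li>
</ul>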
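<p>A few sanity checks on these numbers, assuming a character-level vocabulary of 65 symbols (an assumption; the post doesn't state the tokenizer) and losses measured in nats. A model that guesses uniformly has cross-entropy ln(65), and exp(loss) turns a loss into a per-character perplexity.</p>

```python
import math

vocab = 65  # ASSUMED: character-level Tiny Shakespeare vocabulary

# Cross-entropy of uniform guessing: roughly what an untrained model scores.
uniform_loss = math.log(vocab)

# Numbers from the results table above.
train_start, train_end, val_end = 4.23, 1.54, 1.66

drop = (train_start - train_end) / train_start   # relative drop in training loss
gap = val_end - train_end                        # generalization gap at step 5000
val_ppl = math.exp(val_end)                      # per-character perplexity, if loss is in nats

print(f"uniform baseline: {uniform_loss:.2f}")   # 4.17
print(f"train loss drop:  {drop:.0%}")           # 64%
print(f"val/train gap:    {gap:.2f}")            # 0.12
print(f"val perplexity:   {val_ppl:.2f}")
```

<p>Under those assumptions the step-1 losses (4.23 train, 4.09 val) sit right around the uniform baseline, as expected for an untrained model, and a validation perplexity of about 5.3 means the model is, on average, about as uncertain as a uniform choice among five or six characters.</p>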
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="try-it-yourself">Try It Yourself<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#try-it-yourself" class="hash-link" aria-label="Direct link to Try It Yourself" title="Direct link to Try It Yourself" translate="no">​</a></h2>
<div class="language-bash codeBlockContainer_Ckt0 theme-code-block" style="--prism-color:#F8F8F2;--prism-background-color:#282A36"><div class="codeBlockContent_QJqH"><pre tabindex="0" class="prism-code language-bash codeBlock_bY9V thin-scrollbar" style="color:#F8F8F2;background-color:#282A36"><code class="codeBlockLines_e6Vv"><span class="token-line" style="color:#F8F8F2"><span class="token plain">pip </span><span class="token function" style="color:rgb(80, 250, 123)">install</span><span class="token plain"> nano-moe-jax</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token plain">python </span><span class="token parameter variable" style="color:rgb(189, 147, 249);font-style:italic">-c</span><span class="token plain"> </span><span class="token string" style="color:rgb(255, 121, 198)">"</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">from nano_moe import NanoMoEConfig, NanoMoE</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">import jax, jax.numpy as jnp</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)"></span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">model = NanoMoE(config=NanoMoEConfig())</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">params = model.init(jax.random.PRNGKey(0), jnp.ones</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">((</span><span class="token string variable number" style="color:rgb(189, 147, 249);font-style:italic">1</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">,</span><span class="token string variable number" style="color:rgb(189, 
147, 249);font-style:italic">32</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">)</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">,</span><span class="token string variable" style="color:rgb(189, 147, 249);font-style:italic"> dtype</span><span class="token string variable operator" style="color:rgb(189, 147, 249);font-style:italic">=</span><span class="token string variable" style="color:rgb(189, 147, 249);font-style:italic">jnp.int32</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">))</span><span class="token string" style="color:rgb(255, 121, 198)">['params']</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">logits, aux = model.apply({'params': params}, jnp.ones</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">((</span><span class="token string variable number" style="color:rgb(189, 147, 249);font-style:italic">1</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">,</span><span class="token string variable number" style="color:rgb(189, 147, 249);font-style:italic">32</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">)</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">,</span><span class="token string variable" style="color:rgb(189, 147, 249);font-style:italic"> dtype</span><span class="token string variable operator" style="color:rgb(189, 147, 249);font-style:italic">=</span><span class="token string variable" style="color:rgb(189, 147, 249);font-style:italic">jnp.int32</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">))</span><span class="token string" 
style="color:rgb(255, 121, 198)"></span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">print(f'Output: {logits.shape}, Aux loss: {aux:.2f}')</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">"</span><br></span></code></pre></div></div>
<p>📖 <strong>Full documentation: <a href="https://carrycooldude.github.io/Nano-MoE-JAX/" target="_blank" rel="noopener noreferrer" class="">carrycooldude.github.io/Nano-MoE-JAX</a></strong></p>
<p>⭐ <strong>Star the repo: <a href="https://github.com/carrycooldude/Nano-MoE-JAX" target="_blank" rel="noopener noreferrer" class="">github.com/carrycooldude/Nano-MoE-JAX</a></strong></p>]]></content>
        <author>
            <name>carrycooldude</name>
            <uri>https://github.com/carrycooldude</uri>
        </author>
        <category label="moe" term="moe"/>
        <category label="jax" term="jax"/>
        <category label="flax" term="flax"/>
        <category label="transformers" term="transformers"/>
        <category label="deep-learning" term="deep-learning"/>
    </entry>
</feed>