<?xml version="1.0" encoding="utf-8"?><?xml-stylesheet type="text/xsl" href="rss.xsl"?>
<rss version="2.0" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:content="http://purl.org/rss/1.0/modules/content/">
    <channel>
        <title>NanoMoE Blog</title>
        <link>https://carrycooldude.github.io/Nano-MoE-JAX/blog</link>
        <description>NanoMoE Blog</description>
        <lastBuildDate>Tue, 24 Feb 2026 00:00:00 GMT</lastBuildDate>
        <docs>https://validator.w3.org/feed/docs/rss2.html</docs>
        <generator>https://github.com/jpmonette/feed</generator>
        <language>en</language>
        <item>
            <title><![CDATA[Building a Nano MoE from Scratch in JAX]]></title>
            <link>https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive</link>
            <guid>https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive</guid>
            <pubDate>Tue, 24 Feb 2026 00:00:00 GMT</pubDate>
            <description><![CDATA[A beginner-friendly deep-dive into how Mixture-of-Experts works, why it matters, and how to build one in pure JAX/Flax.]]></description>
            <content:encoded><![CDATA[<p>A beginner-friendly deep-dive into how Mixture-of-Experts works, why it matters, and how to build one in pure JAX/Flax.</p>
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="what-is-mixture-of-experts">What is Mixture of Experts?<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#what-is-mixture-of-experts" class="hash-link" aria-label="Direct link to What is Mixture of Experts?" title="Direct link to What is Mixture of Experts?" translate="no">​</a></h2>
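<p>Imagine you have a team of specialists. Instead of asking <em>every</em> specialist to look at every problem, you have a <strong>manager</strong> who looks at each problem and says <em>"Expert #2 and Expert #4, you're the best fit for this, so handle it."</em> In a Mixture-of-Experts (MoE) layer, the specialists are small feed-forward networks and the manager is a learned router that picks a few of them for each token.</p>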
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="why-moe-matters">Why MoE Matters<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#why-moe-matters" class="hash-link" aria-label="Direct link to Why MoE Matters" title="Direct link to Why MoE Matters" translate="no">​</a></h2>
<table><thead><tr><th>Model</th><th>Total Params</th><th>Active Params per Token</th><th>Experts per Layer</th></tr></thead><tbody><tr><td>Mixtral 8x7B</td><td>46.7B</td><td>12.9B</td><td>8</td></tr><tr><td>GPT-4 (rumored)</td><td>~1.8T</td><td>~280B</td><td>16</td></tr><tr><td>DeepSeek-V3</td><td>671B</td><td>37B</td><td>256</td></tr></tbody></table>
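<p><strong>More capacity. Similar compute per token.</strong> Each token only runs the experts it is routed to, so the total parameter count can grow without a matching rise in per-token cost. That's the magic of sparse activation.</p>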
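<p>As a quick check on the table, the share of parameters a token actually touches is just active divided by total. A minimal sketch using the figures above (the GPT-4 numbers are rumored, so treat that row as illustrative; <code>active_fraction</code> is a helper name made up for this example):</p>

```python
# Parameter counts in billions, copied from the table above.
models = {
    "Mixtral 8x7B": (46.7, 12.9),
    "GPT-4 (rumored)": (1800.0, 280.0),
    "DeepSeek-V3": (671.0, 37.0),
}


def active_fraction(total, active):
    """Share of the model's parameters that a single token uses."""
    return active / total


for name, (total, active) in models.items():
    share = active_fraction(total, active)
    print(f"{name}: {share:.1%} of parameters active per token")
```

<p>On these numbers, Mixtral uses a little over a quarter of its weights per token and DeepSeek-V3 under 6%.</p>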
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="our-architecture">Our Architecture<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#our-architecture" class="hash-link" aria-label="Direct link to Our Architecture" title="Direct link to Our Architecture" translate="no">​</a></h2>
<p>NanoMoE is a GPT-style transformer in which the feed-forward network (FFN) in each block is replaced with an MoE layer.</p>
<h3 class="anchor anchorTargetStickyNavbar_Vzrq" id="default-config-24m-parameters">Default Config: 2.4M parameters<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#default-config-24m-parameters" class="hash-link" aria-label="Direct link to Default Config: 2.4M parameters" title="Direct link to Default Config: 2.4M parameters" translate="no">​</a></h3>
<table><thead><tr><th>Parameter</th><th>Value</th></tr></thead><tbody><tr><td><code>d_model</code></td><td>128</td></tr><tr><td><code>n_layers</code></td><td>4</td></tr><tr><td><code>n_heads</code></td><td>4</td></tr><tr><td><code>n_experts</code></td><td>4</td></tr><tr><td><code>top_k</code></td><td>2</td></tr></tbody></table>
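<p>To see roughly where the 2.4M figure comes from, here is a back-of-the-envelope weight count. This is a sketch, not the library's own accounting: <code>SketchConfig</code> and <code>count_weights</code> are hypothetical names, and <code>d_ff = 512</code> and <code>vocab_size = 65</code> are assumptions (the usual 4x FFN width and a character-level Tiny Shakespeare vocabulary) that this post does not state. Biases, layer norms, and positional embeddings are ignored.</p>

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchConfig:
    # Values from the table above.
    d_model: int = 128
    n_layers: int = 4
    n_heads: int = 4       # splits d_model across heads; does not change the count
    n_experts: int = 4
    top_k: int = 2
    # Assumptions, not stated in the post.
    d_ff: int = 512        # 4 * d_model, a common GPT-style ratio
    vocab_size: int = 65   # character-level Tiny Shakespeare


def count_weights(cfg, experts_used):
    """Weight-matrix parameters when `experts_used` experts run in each layer."""
    attention = 4 * cfg.d_model * cfg.d_model   # Q, K, V and output projections
    expert = 2 * cfg.d_model * cfg.d_ff         # up- and down-projection of one expert
    router = cfg.d_model * cfg.n_experts        # one logit per expert
    per_layer = attention + experts_used * expert + router
    embedding = cfg.vocab_size * cfg.d_model
    return cfg.n_layers * per_layer + embedding


cfg = SketchConfig()
total = count_weights(cfg, cfg.n_experts)   # every expert counted
active = count_weights(cfg, cfg.top_k)      # only the experts a token is routed to
print(f"total: {total / 1e6:.2f}M, active per token: {active / 1e6:.2f}M")
```

<p>Under these assumptions the estimate comes to about 2.37M weights, close to the 2.4M in the heading, of which about 1.32M are used for any single token. Most of the parameters sit in the experts, which is why routing to 2 of 4 nearly halves the per-token cost.</p>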
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="top-k-routing--how-it-works">Top-K Routing — How It Works<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#top-k-routing--how-it-works" class="hash-link" aria-label="Direct link to Top-K Routing — How It Works" title="Direct link to Top-K Routing — How It Works" translate="no">​</a></h2>
<ol>
<li class="">Project token to <code>n_experts</code> logits</li>
<li class="">Select top-K experts</li>
<li class="">Softmax over selected values</li>
<li class="">Weighted combination of expert outputs</li>
</ol>
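<p>The four steps above can be sketched in a few lines of JAX. This is an illustration under simplifying assumptions, not NanoMoE's actual layer: <code>top_k_route</code> is a hypothetical function, each expert is reduced to a single linear map, and every expert is evaluated densely before gathering the chosen ones, which is easy to read but wastes the compute that sparse routing is meant to save.</p>

```python
import jax
import jax.numpy as jnp


def top_k_route(x, router_w, expert_w, top_k):
    """Route tokens to their top-k experts and mix the results.

    x:        (tokens, d_model)
    router_w: (d_model, n_experts)
    expert_w: (n_experts, d_model, d_model), one linear map per expert (a simplification)
    """
    # 1. Project each token to n_experts logits.
    logits = x @ router_w                                  # (tokens, n_experts)
    # 2. Select the top-K experts per token.
    top_vals, top_idx = jax.lax.top_k(logits, top_k)       # both (tokens, top_k)
    # 3. Softmax over the selected logits only, so each token's gates sum to 1.
    gates = jax.nn.softmax(top_vals, axis=-1)
    # 4. Weighted combination of the chosen experts' outputs.
    #    Dense for clarity: compute all experts, then gather the selected ones.
    all_out = jnp.einsum("td,edf->tef", x, expert_w)       # (tokens, n_experts, d_model)
    chosen = jnp.take_along_axis(all_out, top_idx[:, :, None], axis=1)
    mixed = jnp.sum(gates[:, :, None] * chosen, axis=1)    # (tokens, d_model)
    return mixed, top_idx


key = jax.random.PRNGKey(0)
k_x, k_r, k_e = jax.random.split(key, 3)
x = jax.random.normal(k_x, (8, 128))                 # 8 tokens, d_model = 128
router_w = jax.random.normal(k_r, (128, 4)) * 0.02   # n_experts = 4
expert_w = jax.random.normal(k_e, (4, 128, 128)) * 0.02
y, top_idx = top_k_route(x, router_w, expert_w, top_k=2)
print(y.shape, top_idx.shape)
```

<p>The output has the same shape as the input, <code>(8, 128)</code>, and <code>top_idx</code> records which two experts each token used. Applying the softmax after selection, rather than before, means the gate weights always sum to 1 over the chosen experts.</p>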
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="load-balancing">Load Balancing<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#load-balancing" class="hash-link" aria-label="Direct link to Load Balancing" title="Direct link to Load Balancing" translate="no">​</a></h2>
<p>Without load balancing, an MoE tends to collapse: the router sends nearly all tokens to one expert, and the remaining experts go unused.</p>
<p><strong>Solution:</strong> Auxiliary loss from the Switch Transformer paper:</p>
<div class="language-text codeBlockContainer_Ckt0 theme-code-block" style="--prism-color:#F8F8F2;--prism-background-color:#282A36"><div class="codeBlockContent_QJqH"><pre tabindex="0" class="prism-code language-text codeBlock_bY9V thin-scrollbar" style="color:#F8F8F2;background-color:#282A36"><code class="codeBlockLines_e6Vv"><span class="token-line" style="color:#F8F8F2"><span class="token plain">total_loss = CE_loss + 0.01 × aux_loss</span><br></span></code></pre></div></div>
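<p>For reference, here is one way to write a Switch-Transformer-style auxiliary loss for a single MoE layer in JAX. It is a sketch under assumptions: <code>load_balance_loss</code> and <code>total_loss</code> are hypothetical names, and extending the paper's top-1 formula to top-K by counting routing slots is one common choice, not necessarily what NanoMoE does. The loss multiplies, per expert, the fraction of routing slots it receives by its mean router probability, then sums and scales by the number of experts.</p>

```python
import jax
import jax.numpy as jnp


def load_balance_loss(router_logits, top_k):
    """Switch-style auxiliary loss for one layer. router_logits: (tokens, n_experts)."""
    n_tokens, n_experts = router_logits.shape
    probs = jax.nn.softmax(router_logits, axis=-1)       # (tokens, n_experts)
    _, top_idx = jax.lax.top_k(router_logits, top_k)     # (tokens, top_k)
    # f: fraction of routing slots given to each expert (hard counts, no gradient).
    dispatch = jax.nn.one_hot(top_idx, n_experts)        # (tokens, top_k, n_experts)
    f = dispatch.sum(axis=(0, 1)) / (n_tokens * top_k)
    # p: mean router probability per expert (this term carries the gradient).
    p = probs.mean(axis=0)
    return n_experts * jnp.sum(f * p)


def total_loss(ce_loss, aux_loss, aux_weight=0.01):
    """Combine cross-entropy and auxiliary loss with the weight used in this post."""
    return ce_loss + aux_weight * aux_loss


eye = jnp.eye(4)
# Balanced: token t prefers experts t and t+1, so every expert gets equal load.
balanced_logits = 5.0 * (eye + jnp.roll(eye, 1, axis=1))
# Collapsed: every token prefers experts 0 and 1; experts 2 and 3 are never used.
collapsed_logits = jnp.tile(jnp.array([5.0, 5.0, 0.0, 0.0]), (4, 1))

balanced = load_balance_loss(balanced_logits, top_k=2)
collapsed = load_balance_loss(collapsed_logits, top_k=2)
print(f"balanced: {float(balanced):.2f}, collapsed: {float(collapsed):.2f}")
```

<p>With this normalization a perfectly balanced layer scores about 1.0, while the collapsed router above scores about 2.0, so minimizing the term pushes the router toward even expert usage. The small 0.01 weight keeps it from overpowering the cross-entropy loss.</p>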
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="training-results">Training Results<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#training-results" class="hash-link" aria-label="Direct link to Training Results" title="Direct link to Training Results" translate="no">​</a></h2>
<p>We trained on Tiny Shakespeare for 5,000 steps:</p>
<table><thead><tr><th>Step</th><th>Train Loss</th><th>Val Loss</th></tr></thead><tbody><tr><td>1</td><td>4.23</td><td>4.09</td></tr><tr><td>1000</td><td>2.14</td><td>2.08</td></tr><tr><td>3000</td><td>1.69</td><td>1.77</td></tr><tr><td>5000</td><td><strong>1.54</strong></td><td><strong>1.66</strong></td></tr></tbody></table>
<p>✅ Train loss dropped <strong>64%</strong> (4.23 → 1.54)<br>
✅ Aux loss stable at ~4.0 (experts balanced)<br>
✅ Minimal overfitting (train/val gap = 0.12 at step 5000)</p>
<h2 class="anchor anchorTargetStickyNavbar_Vzrq" id="try-it-yourself">Try It Yourself<a href="https://carrycooldude.github.io/Nano-MoE-JAX/blog/nano-moe-deep-dive#try-it-yourself" class="hash-link" aria-label="Direct link to Try It Yourself" title="Direct link to Try It Yourself" translate="no">​</a></h2>
<div class="language-bash codeBlockContainer_Ckt0 theme-code-block" style="--prism-color:#F8F8F2;--prism-background-color:#282A36"><div class="codeBlockContent_QJqH"><pre tabindex="0" class="prism-code language-bash codeBlock_bY9V thin-scrollbar" style="color:#F8F8F2;background-color:#282A36"><code class="codeBlockLines_e6Vv"><span class="token-line" style="color:#F8F8F2"><span class="token plain">pip </span><span class="token function" style="color:rgb(80, 250, 123)">install</span><span class="token plain"> nano-moe-jax</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token plain">python </span><span class="token parameter variable" style="color:rgb(189, 147, 249);font-style:italic">-c</span><span class="token plain"> </span><span class="token string" style="color:rgb(255, 121, 198)">"</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">from nano_moe import NanoMoEConfig, NanoMoE</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">import jax, jax.numpy as jnp</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)"></span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">model = NanoMoE(config=NanoMoEConfig())</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">params = model.init(jax.random.PRNGKey(0), jnp.ones</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">((</span><span class="token string variable number" style="color:rgb(189, 147, 249);font-style:italic">1</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">,</span><span class="token string variable number" style="color:rgb(189, 
147, 249);font-style:italic">32</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">)</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">,</span><span class="token string variable" style="color:rgb(189, 147, 249);font-style:italic"> dtype</span><span class="token string variable operator" style="color:rgb(189, 147, 249);font-style:italic">=</span><span class="token string variable" style="color:rgb(189, 147, 249);font-style:italic">jnp.int32</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">))</span><span class="token string" style="color:rgb(255, 121, 198)">['params']</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">logits, aux = model.apply({'params': params}, jnp.ones</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">((</span><span class="token string variable number" style="color:rgb(189, 147, 249);font-style:italic">1</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">,</span><span class="token string variable number" style="color:rgb(189, 147, 249);font-style:italic">32</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">)</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">,</span><span class="token string variable" style="color:rgb(189, 147, 249);font-style:italic"> dtype</span><span class="token string variable operator" style="color:rgb(189, 147, 249);font-style:italic">=</span><span class="token string variable" style="color:rgb(189, 147, 249);font-style:italic">jnp.int32</span><span class="token string variable punctuation" style="color:rgb(248, 248, 242);font-style:italic">))</span><span class="token string" 
style="color:rgb(255, 121, 198)"></span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">print(f'Output: {logits.shape}, Aux loss: {aux:.2f}')</span><br></span><span class="token-line" style="color:#F8F8F2"><span class="token string" style="color:rgb(255, 121, 198)">"</span><br></span></code></pre></div></div>
<p>📖 <strong>Full documentation: <a href="https://carrycooldude.github.io/Nano-MoE-JAX/" target="_blank" rel="noopener noreferrer" class="">carrycooldude.github.io/Nano-MoE-JAX</a></strong></p>
<p>⭐ <strong>Star the repo: <a href="https://github.com/carrycooldude/Nano-MoE-JAX" target="_blank" rel="noopener noreferrer" class="">github.com/carrycooldude/Nano-MoE-JAX</a></strong></p>]]></content:encoded>
            <category>moe</category>
            <category>jax</category>
            <category>flax</category>
            <category>transformers</category>
            <category>deep-learning</category>
        </item>
    </channel>
</rss>