Muon Optimizer Introduction and Overview | Generated by AI


Comprehensive Introduction to the Muon Optimizer

The Muon Optimizer is a specialized, second-order-inspired optimization algorithm designed primarily for the 2D weight matrices of hidden layers in neural networks (e.g., fully connected and attention projection layers), though it can be extended to other parameter types. It was originally proposed in late 2024 by researchers including Keller Jordan and Jeremy Bernstein, with roots in geometric optimization ideas such as the polar decomposition and the modular duality framework[1][2]. Yang Zhilin, founder of Moonshot AI (the company behind the Kimi models), highlighted Muon in discussions about training their Kimi K2 model, a 1-trillion-parameter large language model (LLM), where it served as the backbone for efficient, high-rank updates that adapt to the geometry of the loss landscape[3][4]. However, the baseline version suffered from instability at that scale (e.g., loss spikes during long training runs), prompting Moonshot AI to develop MuonClip, an enhanced variant that adds stability mechanisms such as QK-clipping for attention layers[3][2].

Muon stands out for its token efficiency: it requires fewer training tokens than first-order optimizers like AdamW to reach comparable loss, which makes it valuable for resource-intensive tasks like LLM pre-training. It aims to capture some of the benefit of second-order methods (e.g., Newton- or Shampoo-style preconditioning) without their full computational cost. Concretely, rather than rescaling individual coordinates, Muon orthogonalizes the momentum matrix of each weight, flattening its singular-value spectrum so that every update has high rank instead of being dominated by a few directions; this is equivalent to preconditioning with an inverse matrix square root of the update's Gram matrix, which connects it to natural-gradient-style methods. That property is particularly useful in large-scale models where gradients are noisy.

Key Principles and Derivation
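
At a high level, following Keller Jordan's write-up[1] and Jeremy Bernstein's derivation[2], Muon applies ordinary momentum to the gradient of each weight matrix and then replaces the accumulated momentum with its nearest semi-orthogonal matrix (the orthogonal factor of its polar decomposition, i.e., U @ V.T when the momentum factors as U @ S @ V.T), which can be read as steepest descent under the spectral norm. The snippet below is an idealized single-step sketch of that rule using an exact SVD; the names (W, M, G) and hyperparameters are illustrative, and the practical implementation further down replaces the SVD with the much cheaper Newton-Schulz iteration.

import torch

# One idealized Muon step for a weight matrix W (exact orthogonalization via SVD).
mu, lr = 0.95, 0.02
W = torch.randn(512, 256)          # weight matrix
M = torch.zeros_like(W)            # momentum buffer
G = torch.randn_like(W)            # gradient (placeholder)

M = mu * M + G                                           # heavy-ball momentum
U, S, Vh = torch.linalg.svd(M, full_matrices=False)
O = U @ Vh                                               # nearest semi-orthogonal matrix: all singular values -> 1
W = W - lr * O                                           # step along the orthogonalized direction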

Advantages and Drawbacks

On the advantage side, Muon is token- and compute-efficient relative to AdamW, and it keeps only a single momentum buffer per parameter, roughly half the optimizer state of Adam's two moment estimates. The main drawbacks are the extra matrix multiplications of the Newton-Schulz iteration at every step, the fact that it only applies to matrix-shaped (2D) parameters (embeddings, the output head, and 1D parameters still need a companion optimizer such as AdamW), and the instability observed in very long, very large training runs, which is what MuonClip was built to address.

Muon has influenced the AI optimization landscape, inspiring follow-up optimizers such as Scion and drawing extensive discussion on Reddit and X, where it is often praised for its “geometric intuition.” For full derivations, see Jeremy Bernstein’s blog[2]. Now, let’s look at a practical implementation.

Code Example: Implementing the Muon Optimizer in PyTorch

Below is a PyTorch implementation of the basic Muon optimizer, adapted from the official repository (https://github.com/KellerJordan/Muon). This is a simplified, single-device version for dense weight matrices; it includes the Newton-Schulz iteration that Muon uses to orthogonalize the momentum update, while omitting the distributed and mixed-precision details of the official code.

import torch
import torch.nn as nn


def zeropower_via_newtonschulz5(G, steps=5):
    """
    Approximate the semi-orthogonal factor U @ V.T of G = U @ S @ V.T with a
    quintic Newton-Schulz iteration (coefficients from the reference repository).
    """
    assert G.ndim == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T                            # work in the wide orientation
    X = X / (X.norm() + 1e-7)              # Frobenius norm bounds the spectral norm by 1
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X                  # quintic map pushes singular values toward 1
    if transposed:
        X = X.T
    return X


class Muon(torch.optim.Optimizer):
    """
    Simplified, single-device Muon optimizer for weight matrices.
    Adapted from: https://github.com/KellerJordan/Muon
    """
    def __init__(self, params, lr=0.02, momentum=0.95, weight_decay=0.0, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum,
                        weight_decay=weight_decay, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            mu = group['momentum']
            wd = group['weight_decay']
            ns_steps = group['ns_steps']

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(p.grad)

                # Heavy-ball momentum accumulation
                buf = state['momentum_buffer']
                buf.mul_(mu).add_(p.grad)

                # Decoupled weight decay
                if wd != 0:
                    p.mul_(1 - lr * wd)

                if buf.ndim == 2:
                    # Muon's core step: orthogonalize the momentum matrix
                    update = zeropower_via_newtonschulz5(buf, steps=ns_steps)
                    # Shape-dependent scaling used by the reference implementation
                    update = update * max(1, buf.size(0) / buf.size(1)) ** 0.5
                else:
                    # Muon targets 2D matrices; fall back to plain momentum SGD
                    # for 1D parameters such as biases.
                    update = buf

                p.add_(update, alpha=-lr)
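
# Optional sanity check (illustrative, assumes the helper defined above): after
# Newton-Schulz, the singular values of the output should all sit close to 1.
# The tuned coefficients favor speed over exact convergence, so values land in
# a band around 1 rather than exactly at 1.
_M = torch.randn(256, 512)
_O = zeropower_via_newtonschulz5(_M, steps=5)
print(torch.linalg.svdvals(_O)[:5])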

# Example Usage
model = nn.Linear(768, 768)  # Dense layer
optimizer = Muon(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
data = torch.randn(32, 768)
target = torch.randn(32, 768)

for epoch in range(10):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

Notes on Code:

- This covers the single-device case only; the official repository also ships distributed (DDP-aware) variants and runs the Newton-Schulz iteration in bfloat16 for speed.
- Muon keeps a single momentum buffer per parameter, so its optimizer state is roughly half of Adam's two moment estimates.
- Learning rates do not transfer directly from AdamW; the reference implementation defaults to a comparatively large Muon learning rate (about 0.02) combined with the shape-dependent scaling shown in step().


Citations:
[1] https://kellerjordan.github.io/posts/muon/
[2] https://jeremybernste.in/writing/deriving-muon
[3] https://github.com/KellerJordan/Muon
[4] https://github.com/nil0x9/flash-muon
[5] https://www.lakernewhouse.com/writing/muon-2
[6] https://medium.com/@kyeg/building-the-muon-optimizer-in-pytorch-a-geometric-approach-to-neural-network-optimization-17f4601be548
[7] https://discuss.huggingface.co/t/tutorial-understanding-and-implementing-the-muon-optimizer/167717
[8] https://keras.io/api/optimizers/muon/

