μ子优化器介绍与概述 | AI生成和翻译

Home 2025.09

Muon优化器全面解析

Muon优化器是一种专为神经网络线性层(如全连接层或嵌入层)训练设计的二阶启发式优化算法,亦可扩展至其他层结构。该算法最初由Keller Jordan、Jeremy Bernstein等研究人员于2024年底提出,其理论根源可追溯至几何优化技术中的极坐标初始化和模对偶框架[1][2]。月之暗面(Moonshot AI)创始人兼Kimi AI首席执行官杨植麟在讨论其千亿参数大模型Kimi K2训练时特别指出,Muon通过适应损失函数几何形态的高秩更新机制,成为模型高效训练的核心支柱[3][4]。然而基础版本存在稳定性缺陷(如长时训练中的损失值突变),促使月之暗面开发了增强版本MuonClip,通过引入注意力层QK裁剪等稳定化机制提升训练鲁棒性[3][2]。

Muon的突出优势在于令牌效率:相较于AdamW等一阶优化器,它以更少的训练令牌量即可达到相当性能,这对LLM预训练等计算密集型任务极具价值。该算法旨在以低于传统二阶方法(如牛顿法)的计算成本实现近似效果,重点通过高秩矩阵更新实现特征值自适应。在大规模模型的梯度噪声环境中,Muon基于自然梯度和矩阵平方根思想的预条件技术展现出独特优势。

核心原理与推导

优势与局限

Muon已影响AI优化领域发展,出现在Scion基准测试和Reddit/X平台讨论中,常因其”几何直观性”受赞誉。完整推导请参阅Jeremy Bernstein的技术博客[2]。下面我们来看具体实现方案。

代码实例:PyTorch实现Muon优化器

以下是根据官方仓库(https://github.com/KellerJordan/Muon)改编的基础Muon优化器PyTorch实现,这是针对稠密线性层的简化版本,包含预条件器的Newton-Schulz迭代。

import torch
import torch.nn as nn

class Muon(torch.optim.Optimizer):
    """
    线性层专用Muon优化器
    改编自:https://github.com/KellerJordan/Muon
    """
    def __init__(self, params, lr=1e-3, lr_b=2e-3, b2=0.95, wd=0.0):
        defaults = dict(lr=lr, lr_b=lr_b, b2=b2, wd=wd)
        super().__init__(params, defaults)

    def step(self):
        for group in self.param_groups:
            lr = group['lr']
            lr_b = group['lr_b']
            b2 = group['b2']
            wd = group['wd']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data.float()
                state = self.state[p]
                if 'momentum' not in state:
                    state['momentum'] = torch.zeros_like(grad)

                # 动量更新
                state['momentum'].mul_(b2).add_(grad)

                # 权重衰减
                if wd != 0:
                    p.data.mul_(1 - lr * wd)

                # Muon正交化(秩自适应)
                grad_vec = state['momentum'].view(-1, grad.shape[-1])
                p_vec = p.data.view(-1, p.shape[-1])

                # Newton-Schulz矩阵平方根近似(简化版)
                G = grad_vec @ grad_vec.t() / grad_vec.shape[0]
                # 完整实现为迭代过程,此处采用幂级数近似
                sqrt_G = torch.sqrt(G + 1e-6 * torch.eye(G.shape[0], device=G.device))

                # 参数更新
                update = grad_vec.t() @ sqrt_G @ grad_vec / sqrt_G.shape[0]
                p.data.sub_(lr_b * update.view(p.shape))

# 使用示例
model = nn.Linear(768, 768)  # 稠密层
optimizer = Muon(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
data = torch.randn(32, 768)
target = torch.randn(32, 768)

for epoch in range(10):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

代码说明

如需扩展实现或基准测试资料,欢迎进一步交流!

参考文献: [1] https://kellerjordan.github.io/posts/muon/ [2] https://jeremybernste.in/writing/deriving-muon [3] https://github.com/KellerJordan/Muon [4] https://github.com/nil0x9/flash-muon [5] https://www.lakernewhouse.com/writing/muon-2 [6] https://medium.com/@kyeg/building-the-muon-optimizer-in-pytorch-a-geometric-approach-to-neural-network-optimization-17f4601be548 [7] https://discuss.huggingface.co/t/tutorial-understanding-and-implementing-the-muon-optimizer/167717 [8] https://keras.io/api/optimizers/muon/


Back

x-ai/grok-code-fast-1

Donate