Polar Factor Beyond Newton-Schulz – Fast Matrix Inverse Square Root

原始链接: https://jiha-kim.github.io/posts/polar-factor-beyond-newton-schulz-fast-matrix-inverse-square-root/

## Muon Optimizer: Fast Polar Factor Computation for Machine Learning

The Muon optimizer performs remarkably well in machine learning by efficiently approximating the polar factor of a matrix — a key operation, similar in spirit to signSGD or Lion. The goal is to compute **polar(G) = G(GᵀG)⁻¹/²** for a tall matrix G, with emphasis on speed, numerical stability (especially in bf16), and online accuracy verification.

This is achieved by avoiding a direct SVD and instead refining the approximation iteratively, using only rectangular matrix multiplications (GEMMs) plus operations on small square matrices. The core idea is to compute the inverse square root of the Gram matrix (GᵀG) and then multiply by G once.

Key features include a **Gram-side inverse square root** method, efficient iterations built from **minimax polynomials**, and an **online certificate** based on the Gram residual to verify the accuracy of the result. Jacobi scaling improves spectral conditioning without introducing bias. Stability is reinforced by symmetrization, ridging, and restart blocks — techniques borrowed from Polar Express.

Precomputed polynomial coefficient tables allow fast online selection based on the current residual, balancing aggressive iteration against controlled convergence, which is particularly important in low-precision arithmetic. The result is a fast, stable, and certifiable approximation suited to large-scale machine learning.


Original Article

The Muon optimizer has found huge empirical success in machine learning. It’s essentially signSGD (or Lion by including momentum) for matrices. For the update, we need to approximate the sign function on the singular values of the momentum matrix to compute the polar factor.

Goal: Given \(G\in \mathbb{R}^{m \times n}\) tall (\(m \ge n\)), compute the (column-)orthonormal polar factor

\[ \mathrm{polar}(G):=G(G^\top G)^{-1/2} \]

For the compact SVD \(G=U\Sigma V^\top\), \(\mathrm{polar}(G)=UV^\top\). This is the “directional” component in the polar decomposition \(G=\mathrm{polar}(G) \vert G\vert \), similar to the polar coordinates of a complex number \(z=e^{i\theta}\cdot r\):

\[ \vert G\vert := \sqrt{ G^\top G } \quad \text{("stretch" part: modulus of matrix)} \]

\[ \mathrm{polar}(G)=G\vert G\vert ^{-1} \quad\text{("direction" part: unitary polar factor)} \]

In Muon, we typically do not need high accuracy, but we do want:

  1. a fast GPU path (mostly GEMMs),
  2. numerical stability in bf16,
  3. a way to certify that the singular values \(\sigma_i\) of the returned factor are close to \(1\).

Newton-Schulz/Polar Express iterations normalize the singular values into the unit interval \([0,1]\), then iterate directly with rectangular GEMMs.

Potential opportunity for \(m \gg n\): compute \((G^\top G)^{-1/2}\) on the small side and multiply once; the result can then be refined with full polar steps. This opens up some nicer theoretical properties, e.g. (precomputed) online coefficient scheduling, compared to Polar Express's offline coefficients.

Goal

Given \(G \in \mathbb{R}^{m \times n}\) tall (\(m \ge n\)), compute the orthonormal polar factor

\[ \mathrm{polar}(G) := G(G^\top G)^{-1/2}. \]

We want a fast ML-friendly approximation that:

  • uses only 2 rectangular GEMMs (form \(B=G^\top G\), final multiply \(G\widetilde Z\)),
  • does the iterative work on small \(n \times n\) matrices,
  • is stable in bf16 (fp32 accumulate where needed),
  • provides an online certificate that singular values of the returned factor are close to \(1\).

Key idea. Gram-side inverse square root

Let

\[ B := G^\top G \in \mathbb{R}^{n \times n}. \]

Compute \(\widetilde Z \approx B^{-1/2}\) using only \(n\times n\) work, then output

\[ \widetilde U := G \widetilde Z. \]

This is the same structural win that Polar Express exploits for rectangular matrices: form a Gram matrix once, iterate on the small side, then do one final rectangular multiply. Polar Express formalizes this as “Fast Polynomial Iteration for Rectangular Matrices” (Algorithm 4) (Amsel et al., 2025).

What we can certify online (stronger than rectangular direct iterations)

Define the Gram residual

\[ E := \widetilde U^\top \widetilde U - I = \widetilde Z^\top B \widetilde Z - I. \]

If \(\Vert E\Vert _2 \le \eta\), then

\[ \sqrt{1-\eta} \le \sigma_i(\widetilde U) \le \sqrt{1+\eta}. \]

Since \(\Vert E\Vert _2 \le \Vert E\Vert _F\), we can use the cheap sufficient check \(\Vert E\Vert _F \le \eta\) (all on \(n \times n\)). This gives a reliable online proxy for “how safe/aggressive can we be”.
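A minimal sketch of this check in NumPy (the function name is illustrative); the Frobenius norm serves as the cheap sufficient bound:

```python
import numpy as np

def gram_certificate(Z, B, eta):
    """Online certificate on the small side: E = Z^T B Z - I equals
    U^T U - I for U = G Z, since B = G^T G.  If ||E||_F <= eta then
    ||E||_2 <= eta, so every sigma_i(U) lies in
    [sqrt(1 - eta), sqrt(1 + eta)]."""
    E = Z.T @ B @ Z - np.eye(B.shape[0])
    err = float(np.linalg.norm(E, "fro"))
    return err, err <= eta
```

All the work stays on \(n \times n\) matrices, so the certificate never touches the tall side of \(G\).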

Why we do NOT use AOL here (replace with unbiased Jacobi on \(B\))

Turbo-Muon’s AOL is a column scaling applied to \(G\) (so it changes the target to \(\mathrm{polar}(G S)\) and introduces bias) (Boissin et al., 2025). Since we are already working on the square SPD Gram matrix \(B\), we can get the spectrum-improving benefits without bias using an SPD congruence scaling:

\[ \widetilde B := D B D, \qquad B^{-1/2} = D \, \widetilde B^{-1/2} \, D. \]

This changes conditioning but not the mathematical target (up to numerical error).

Empirically, Jacobi scaling (unit-diagonal) is often the best simple choice:

\[ D := \mathrm{diag}(d), \qquad d_i = (B_{ii}+\epsilon)^{-1/2}. \]
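A minimal sketch of the unit-diagonal congruence scaling (NumPy; `jacobi_scale` is a hypothetical helper name):

```python
import numpy as np

def jacobi_scale(B, eps=1e-12):
    """Jacobi (unit-diagonal) congruence scaling of an SPD matrix:
    returns B_tilde = D B D with D = diag((B_ii + eps)^{-1/2}),
    so that diag(B_tilde) ~ 1, along with the scaling vector d."""
    d = 1.0 / np.sqrt(np.diag(B) + eps)
    B_tilde = (d[:, None] * B) * d[None, :]   # elementwise: d_i * B_ij * d_j
    return B_tilde, d
```

The scaling is applied as an elementwise outer-product multiply, so it costs \(O(n^2)\) rather than two extra GEMMs.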

Stability rules. bf16-safe iterations

Polar Express identifies low-precision issues when iterating via Gram-side polynomial compositions (their Algorithm 4) and suggests:

  • add a ridge early to avoid spurious indefiniteness from roundoff,
  • restart compositions to avoid ill-conditioned intermediate factors (Amsel et al., 2025).

We adopt the same philosophy:

  • always symmetrize \(B\) and ridge it,
  • use restart blocks when composing aggressive polynomials,
  • do all small-side iteration in fp32 (or at least fp32 accumulate and residual checks).

Core iteration: minimax-polynomial inverse square root for SPD matrices

Template (“drive the Gram to \(I\)”)

We compute an inverse square root of an SPD matrix \(A\) by maintaining \(Z_k \approx A^{-1/2}\) and driving

\[ S_k := Z_k^\top A Z_k \to I. \]

Update:

\[ Z_{k+1} = Z_k\,q_k(S_k), \]

so eigenvalues evolve as

\[ \lambda \mapsto \lambda' = \lambda\,q_k(\lambda)^2. \]

This matches the standard Newton-style “matrix-multiplication only” inverse-root framework (no factorizations), e.g. in analyses of inverse \(p\)th-root iterations (Guo and Higham, 2006).
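For intuition, the classical Newton-Schulz inverse-square-root step is the special case \(q(\lambda) = (3-\lambda)/2\); a scalar trace of the eigenvalue map shows the contraction toward \(1\):

```python
def eig_map(lam, q):
    """One step of the eigenvalue map: lambda -> lambda * q(lambda)^2."""
    return lam * q(lam) ** 2

newton_schulz_q = lambda lam: (3.0 - lam) / 2.0  # classical choice of q

lam = 0.5
for _ in range(5):
    lam = eig_map(lam, newton_schulz_q)
# lam is now within ~1e-9 of 1 (quadratic convergence near the fixed point)
```

Near the fixed point, an eigenvalue \(1-\varepsilon\) maps to roughly \(1 - \tfrac{3}{4}\varepsilon^2\), which is the quadratic contraction the minimax polynomials are designed to beat on wide intervals.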

Why minimax (Polar Express port)

Polar Express selects per-step polynomials using minimax optimization on an interval to get strong worst-case contraction (Amsel et al., 2025). We port that idea to the SPD eigenvalue map.

For a spectral interval \([\ell,u]\), choose degree-\(d\) polynomial \(q\) by

\[ q^\ast \in \arg\min_{q\in\mathcal{P}_d}\;\max_{\lambda\in[\ell,u]} \left\vert \sqrt{\lambda}\,q(\lambda) - 1\right\vert . \]

If \(\left\vert \sqrt{\lambda}\,q(\lambda)-1\right\vert \le\varepsilon\) on \([\ell,u]\), then

\[ \lambda' = (\sqrt{\lambda}\,q(\lambda))^2 \in [(1-\varepsilon)^2,(1+\varepsilon)^2], \]

giving a clean contraction/interval propagation rule.

We do not solve minimax online; instead we precompute a dense coefficient table offline and select online based on the measured residual.

Offline

Precompute two families:

Phase 1 (global) polynomials:

  • intervals \([\ell,1]\) with \(\ell\) log-spaced (e.g. \(\ell\in\{10^{-4},10^{-3},\dots ,0.5\}\)),
  • minimax \(q_{\ell}\) for each interval.

Phase 2 (local, symmetric-around-1) polynomials:

  • represent \(S = I + R\) and approximate \((I+R)^{-1/2}\),
  • intervals \(r\in[-\rho,\rho]\) with \(\rho\) on a grid (e.g. \(\rho\in\{0.02,0.05,0.1,0.2,0.35,0.5,0.7,0.9\}\)),
  • minimax \(p_{\rho}\) approximating \((1+r)^{-1/2}\) on \([-\rho,\rho]\).
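The Phase-2 family can be sketched offline as below. A Chebyshev least-squares fit stands in for the true minimax (Remez) solve — it is near-minimax on these small symmetric intervals — and the grid shown is an illustrative subset:

```python
import numpy as np
from numpy.polynomial import chebyshev as C

def fit_phase2_poly(rho, degree=3, samples=400):
    """Near-minimax polynomial p(r) ~ (1 + r)^{-1/2} on [-rho, rho],
    returned as Chebyshev coefficients (evaluate with C.chebval)."""
    r = np.linspace(-rho, rho, samples)
    return C.chebfit(r, (1.0 + r) ** -0.5, degree)

# dense radius grid, as in the post
phase2_table = {rho: fit_phase2_poly(rho)
                for rho in (0.02, 0.05, 0.1, 0.2, 0.35, 0.5)}
```

The coefficients can be converted to the monomial basis with `C.cheb2poly` before building the matrix polynomial \(p(S - I)\).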

Optionally impose stability constraints in the offline solve (recommended for bf16):

  • \(q(\lambda) > 0\) on the interval (SPD preservation),
  • cap overshoot: ensure \(\lambda q(\lambda)^2\) stays in a controlled range,
  • limit slope near \(1\) to avoid local amplification.

Online selection

At each step compute

\[ S = Z^\top A Z,\qquad \delta_S := \Vert S-I\Vert _F. \]

Then \(\Vert S-I\Vert _2 \le \delta_S\), so

\[ \lambda(S)\subset[1-\delta_S,\,1+\delta_S]. \]

Pick a slightly inflated design radius

\[ \rho_{\text{design}} := \gamma\,\delta_S,\qquad \gamma\in[1.1,1.5], \]

and choose the nearest polynomial \(p_{\rho_{\text{design}}}\) (Phase 2) or, in Phase 1, choose a conservative \(\ell\) schedule.
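A sketch of the online lookup, assuming a precomputed table keyed by design radius (names are illustrative):

```python
import numpy as np

def select_poly(S, table, gamma=1.25):
    """Measure the Frobenius residual of S around I, inflate it by gamma,
    and return the smallest tabulated radius that still covers it."""
    delta = float(np.linalg.norm(S - np.eye(S.shape[0]), "fro"))
    rho_design = gamma * delta
    for rho in sorted(table):
        if rho >= rho_design:
            return table[rho], delta
    return table[max(table)], delta  # fall back to the widest interval
```

The inflation by \(\gamma\) buys slack against the gap between the measured Frobenius residual and the true spectral radius of \(S - I\).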


Two-phase scheme (safe globalization, aggressive local polish)

Phase 0: Form \(B\) and apply unbiased preconditioning

  1. \(B \leftarrow G^\top G\) (fp32 accumulate)
  2. \(B \leftarrow \tfrac12(B+B^\top)\)
  3. Ridge: \(B \leftarrow B + \delta I\)
  4. Jacobi: \(D_{ii} \leftarrow (B_{ii}+\epsilon)^{-1/2}\)
  5. \(\widetilde B \leftarrow DBD\) (elementwise scaling: \(\widetilde B_{ij}=d_i B_{ij} d_j\))

Phase 1: Safe scaling to \((0,1]\) and global minimax steps

  1. Upper bound \(\Lambda \ge \lambda_{\max}(\widetilde B)\) (Gershgorin \(\Vert \widetilde B\Vert _\infty\))
  2. Scale:

    \[ \alpha := \Lambda^{-1/2},\qquad A := \alpha^2 \widetilde B \]

    so \(\lambda(A)\subset(0,1]\)

  3. Initialize \(Z \leftarrow I\)
  4. Repeat in restart blocks (\(T_{\text{block}}\in\{2,3\}\)):
    • \(S \leftarrow Z^\top A Z\)
    • if \(\Vert S-I\Vert _F \le \rho_{\text{switch}}\) (e.g. \(0.5\)): break
    • choose \(q_\ell\) (table lookup for a conservative \(\ell\)) and apply:

      \[ Z \leftarrow Z\,q_\ell(S) \]

    • restart: recompute \(S\) in fp32 and reselect coefficients

Phase 2: Local symmetric-around-1 steps (aggressive but certified)

Now \(\Vert S-I\Vert _F\) is small enough that we can safely use symmetric intervals around \(1\).

Repeat for \(t=1,2\) (often 1 is enough):

  • \(S \leftarrow Z^\top A Z\)
  • \(\delta_S \leftarrow \Vert S-I\Vert _F\)
  • if \(\delta_S \le \eta\): stop
  • \(\rho_{\text{design}} \leftarrow \gamma\delta_S\)
  • lookup \(p_{\rho_{\text{design}}}\) and apply:

    \[ Z \leftarrow Z\,p_{\rho_{\text{design}}}(S-I) \]

Finish: map back to \(B^{-1/2}\) and form \(\widetilde U\)

  1. \(\widetilde B^{-1/2} \approx \alpha Z\)
  2. Map back:

    \[ \widetilde Z := B^{-1/2} \approx D(\alpha Z)D \]

  3. Output:

    \[ \widetilde U = G\widetilde Z \]

Certification and optional polish

Compute

\[ E = \widetilde Z^\top B \widetilde Z - I \]

and check \(\Vert E\Vert _F \le \eta\).


Restarts (important for bf16)

Use short composition blocks (\(T_{\text{block}}\in\{2,3\}\)), then recompute \(S\) and reselect coefficients. This mirrors Polar Express’s practical stabilization for Gram-side rectangular acceleration (Amsel et al., 2025).


Full algorithm: unbiased Jacobi scaling, minimax steps, online selection

Input: \(G\), ridge \(\delta\), Jacobi eps \(\epsilon\), tol \(\eta\), switch \(\rho_{\text{switch}}\), inflate \(\gamma\), coefficient tables

  1. \(B \leftarrow G^\top G\) (fp32 accumulate)
  2. \(B \leftarrow \tfrac12(B+B^\top) + \delta I\)
  3. \(d_i \leftarrow (B_{ii}+\epsilon)^{-1/2}\), \(D=\mathrm{diag}(d)\)
  4. \(\widetilde B \leftarrow DBD\)
  5. \(\Lambda \leftarrow\) upper bound on \(\lambda_{\max}(\widetilde B)\)
  6. \(\alpha \leftarrow \Lambda^{-1/2}\), \(A \leftarrow \alpha^2 \widetilde B\)
  7. \(Z \leftarrow I\)

Phase 1:

  1. repeat (restart blocks):
    a. \(S \leftarrow Z^\top A Z\)
    b. if \(\Vert S-I\Vert _F \le \rho_{\text{switch}}\): break
    c. select minimax \(q_\ell\) for a conservative \([\ell,1]\)
    d. \(Z \leftarrow Z\,q_\ell(S)\)

Phase 2:

  1. for \(t=1,2\):
    1. \(S \leftarrow Z^\top A Z\)
    2. \(\delta_S \leftarrow \Vert S-I\Vert _F\)
    3. if \(\delta_S \le \eta\): break
    4. \(\rho_{\text{design}} \leftarrow \gamma\delta_S\)
    5. select minimax \(p_{\rho_{\text{design}}}\)
    6. \(Z \leftarrow Z\,p_{\rho_{\text{design}}}(S-I)\)

Finish:

  1. \(Z_{\widetilde B} \leftarrow \alpha Z\) (approx \(\widetilde B^{-1/2}\))
  2. \(\widetilde Z \leftarrow D Z_{\widetilde B} D\) (approx \(B^{-1/2}\))
  3. \(\widetilde U \leftarrow G\widetilde Z\)
  4. \(E \leftarrow \widetilde Z^\top B \widetilde Z - I\); if \(\Vert E\Vert _F > \eta\), do one more Phase-2 step

Return: \(\widetilde U\)
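Putting the pieces together, here is a self-contained NumPy sketch of the Gram-side scheme. For brevity it omits the Jacobi congruence step and substitutes plain Newton-Schulz updates \(q(S) = (3I - S)/2\) for the table-driven minimax polynomials, so it demonstrates the structure (two rectangular GEMMs, small-side iteration, online certificate) rather than the tuned coefficients:

```python
import numpy as np

def polar_gram_invsqrt(G, ridge=1e-6, eta=1e-3, max_steps=40):
    """Approximate polar(G) = G (G^T G)^{-1/2} for tall G.
    Structure: form B once, iterate on the n x n side, multiply once."""
    m, n = G.shape
    I = np.eye(n)

    # Phase 0: Gram matrix, symmetrize, ridge
    B = G.T @ G                        # rectangular GEMM 1
    B = 0.5 * (B + B.T) + ridge * I

    # Phase 1 scaling: Gershgorin row-sum bound puts the spectrum in (0, 1]
    Lam = np.max(np.sum(np.abs(B), axis=1))
    alpha = Lam ** -0.5
    A = alpha ** 2 * B

    # Drive S = Z^T A Z -> I with Newton-Schulz steps Z <- Z (3I - S)/2
    Z = I.copy()
    for _ in range(max_steps):
        S = Z.T @ A @ Z
        if np.linalg.norm(S - I, "fro") <= eta:
            break
        Z = Z @ (0.5 * (3.0 * I - S))

    # Finish: map back to B^{-1/2} and form the factor
    Z_B = alpha * Z                    # approx B^{-1/2}
    U = G @ Z_B                        # rectangular GEMM 2
    cert = np.linalg.norm(Z_B.T @ B @ Z_B - I, "fro")  # online certificate
    return U, float(cert)
```

When the returned certificate is below \(\eta\), the singular values of the returned factor lie in \([\sqrt{1-\eta},\sqrt{1+\eta}]\), matching the guarantee above.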


What “dense coefficients” buys you

A dense coefficient grid lets you select a nearly optimal minimax polynomial for the actual measured residual each step (interval-driven updates), matching the spirit of Polar Express (Amsel et al., 2025), but with a stronger online interval proxy because \(S\) is small SPD.

It improves:

  • early contraction when the spectrum is wide,
  • iteration count when the spectrum is already tight,
  • stability: you can inflate the interval by \(\gamma\) and still stay close to minimax-optimal.

This is the clean way to be “more aggressive” while controlling effective convergence radius in bf16.
