Add Gram Newton-Schulz iteration for Muon optimizer
Integrate Gram Newton-Schulz (Gram NS) as the default orthogonalization
method for Muon, with a configurable ns_method switch to fall back to
standard NS when needed (e.g., for debugging convergence issues).
Gram NS iterates on the small square Gram matrix R = X @ X.T (n x n,
with n <= m) instead of the full rectangular X (n x m), cutting FLOPs
by ~50% for typical transformer weight matrices (aspect ratio ~5).
It runs in fp16 rather than bf16, gaining mantissa precision at the
same compute cost, and restarts the iteration at step 2 to keep the
half-precision computation numerically stable.
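For reference, the Gram trick can be sketched in plain NumPy: each
standard NS step X <- a*X + b*(X@X.T)@X + c*(X@X.T)^2@X is a polynomial
in R = X @ X.T applied to X, so the polynomial can be accumulated on the
n x n side and X multiplied only once at the end. The coefficients and
step count below are illustrative (the quintic from the open-source Muon
implementation), not necessarily DeepSpeed's; the fp16 cast and the
iteration-2 restart are omitted for clarity:

```python
import numpy as np

# Illustrative quintic NS coefficients (from the open-source Muon
# implementation); DeepSpeed's actual values and step count may differ.
a, b, c = 3.4445, -4.7750, 2.0315

def gram_newton_schulz(X, steps=5, eps=1e-7):
    """Sketch of Gram NS for an n x m matrix X with n <= m.

    All iterations run on n x n matrices; the full X is touched only
    once at the end, via the accumulated polynomial P = p(R).
    """
    n = X.shape[0]
    X = X / (np.linalg.norm(X) + eps)  # scale so singular values <= 1
    R = X @ X.T                        # n x n Gram matrix (the R above)
    P = np.eye(n)                      # accumulated polynomial p(R)
    A = R                              # invariant: A == P @ R @ P
    for _ in range(steps):
        T = a * np.eye(n) + b * A + c * (A @ A)
        P = T @ P                      # p(R) <- (a*I + b*A + c*A^2) p(R)
        A = T @ A @ T                  # maintain A = P @ R @ P
    return P @ X                       # one final n x m multiply
```

Every matrix product inside the loop is n x n, versus the n x m
products of standard NS, which is where the FLOP savings come from.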
Benchmark results on A100:
- (2048, 11059): 2.25x GPU speedup, 1.85x CPU speedup
- (3584, 19353): 2.07x GPU speedup, 1.35x CPU speedup
- Falls back to standard NS for square matrices (no FLOP advantage)
Usage: set ns_method in DeepSpeed config:
{"optimizer": {"type": "muon", "params": {"ns_method": "gram"}}}
Use "standard" to disable Gram NS and revert to original behavior.
Reference: https://arxiv.org/abs/2503.02022
Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>