DeepSpeed
e5de42ce - Fix non-contiguous tensor output from Gram NS for tall matrices

Commit
6 hours ago
Fix non-contiguous tensor output from Gram NS for tall matrices Gram Newton-Schulz produces non-contiguous tensors via .mT for tall weight matrices (e.g., gate_proj/up_proj in LLaMA). This caused downstream grad norm computation (g.data.double()) to be ~1.8x slower due to strided memory access, adding ~75ms to optimizer step time. Add .contiguous() to the Gram NS return path for tall matrices, and ensure muon_update casts back to the original gradient dtype (Gram NS uses fp16 internally while gradients are bf16). Benchmark (Qwen2.5-3B, 2xA100, ZeRO-2, 3 runs avg): Before fix: 945.1ms/step (optimizer: 229.9ms) After fix: 936.6ms/step (optimizer: 204.3ms) Standard NS baseline: 1054.5ms/step Gram NS speedup: 10.4% -> 11.2% Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
Author
Parents
Loading