fix(zero3): use current_stream() instead of default_stream() for gradient reduce-scatter sync (#7898)
## Problem
DeepSpeed ZeRO Stage 3 produces NaN values in weight partitions after
the optimizer step when used with PyTorch 2.10+ and bf16 training. The
NaN appears stochastically across ranks and affects different layers on
each run (typically attention projections and MLP weights). Training
works correctly with identical configuration on older PyTorch versions.
## Root Cause
In `__add_grad_to_ipg_bucket` (stage3.py), the reduce-and-partition
stream synchronizes with the **default** CUDA stream before copying
gradients:
```python
self.reduce_and_partition_stream.wait_stream(get_accelerator().default_stream())
```
However, PyTorch's autograd engine executes each backward op on the
**same stream its forward op used**. In ZeRO-3, forward ops involve
all-gather operations that may execute on non-default streams. When the
corresponding backward ops produce gradients on those non-default
streams, `wait_stream(default_stream)` creates a dependency on the wrong
stream — the reduce-scatter proceeds before the gradient CUDA kernel has
finished writing, reading uninitialized memory from the gradient buffer.
This was not observable on older PyTorch versions because the autograd
engine's stream assignment was more conservative. PyTorch 2.10 changed
autograd stream handling in a way that makes this race trigger reliably,
and the stale buffer contents surface as NaN in the reduced gradients.
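The ordering bug can be illustrated with ordinary threads standing in for CUDA streams (a toy sketch, not DeepSpeed code; the stream names and the `0.125` gradient value are invented for illustration):

```python
import threading
import time

# Toy model of the race: each "stream" is a thread, and wait_stream(s) is
# modeled as waiting on an event recorded by s. If the reduce stream waits
# on the idle default stream instead of the stream that actually writes
# the gradient, it can read the buffer before the write lands.

def run(wait_on_compute_stream: bool) -> float:
    grad_buffer = [float("nan")]      # stands in for uninitialized gradient memory
    default_done = threading.Event()  # "default stream" has nothing queued
    compute_done = threading.Event()  # stream that actually computes the gradient
    result = []

    def default_stream():
        default_done.set()            # idle stream: signals completion immediately

    def compute_stream():
        time.sleep(0.05)              # backward kernel still running
        grad_buffer[0] = 0.125        # the gradient write finally lands
        compute_done.set()

    def reduce_stream():
        # buggy: wait_stream(default_stream) / fixed: wait_stream(current_stream)
        (compute_done if wait_on_compute_stream else default_done).wait()
        result.append(grad_buffer[0])  # "reduce-scatter" reads the buffer

    threads = [threading.Thread(target=f)
               for f in (default_stream, compute_stream, reduce_stream)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result[0]

print(run(wait_on_compute_stream=False))  # almost always nan: read raced the write
print(run(wait_on_compute_stream=True))   # 0.125: correct ordering
```

Waiting on the correct producer is the entire fix; the consumer code is unchanged.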
## Minimal Reproduction
1. Configure ZeRO Stage 3 with bf16, `overlap_comm: true`,
`contiguous_gradients: true`
2. Train any model where backward produces gradients on non-default CUDA
streams (e.g., multi-GPU training with PyTorch 2.10+)
3. After the first optimizer step, some weight partitions contain NaN on
a subset of ranks
4. The affected layers and ranks vary between runs (stochastic, depends
on CUDA kernel scheduling)
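A DeepSpeed config fragment matching these settings might look like the following (the batch size and any omitted fields are placeholders, not taken from this PR):

```json
{
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "train_micro_batch_size_per_gpu": 1
}
```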
## Fix
Replace `default_stream()` with `current_stream()` in the
synchronization call. Since `__add_grad_to_ipg_bucket` is called from
DeepSpeed's gradient hook (which executes on the same stream as the
backward op), `current_stream()` correctly captures whatever stream
actually computed the gradient:
```python
self.reduce_and_partition_stream.wait_stream(get_accelerator().current_stream())
```
This change is overhead-neutral: `wait_stream` is implemented with CUDA
events, which cost roughly 1 µs per parameter regardless of which stream
is waited on.
## Verification
Tested with Qwen3-4B on 7×H200 GPUs, DeepSpeed 0.18.7, PyTorch 2.10.0,
CUDA 12.8, NCCL 2.27.5:
- **Before fix**: 150K+ NaN values across 55 weight layers after step 1,
grad_norm clipped to 1.0 (corrupted)
- **After fix**: 0 NaN across all weight layers for 3+ steps, grad_norm
healthy at 0.08–0.27
@tjruwase @stas00
Signed-off-by: Ubuntu <ubuntu@ip-172-31-16-215.us-east-2.compute.internal>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>