[Bugfix] Resolve Rank index out of range during BWD when sp_size < world_size in Ulysses (#7809)
### Description
This PR addresses Issue #7672.
When `sequence_parallel_size` is smaller than `world_size` (e.g., `sp_size=2`
on 4 GPUs) with PyTorch < 2.3, using
`torch.distributed.nn.functional.all_gather` for loss aggregation triggers
an `IndexError: tuple index out of range` during the backward pass. This
is caused by a known PyTorch issue in which the backward hook indexes the
gathered outputs with the global rank instead of the group rank.
### Solution
1. Regression Test & Workaround: Updated the regression test
`TestUlyssesLossBackward` to implement a weighted all-reduce pattern.
   - Before: `all_gather` -> manual sum (vulnerable to the rank-indexing
mismatch on older PyTorch).
   - After: `all_reduce(weighted_loss) / all_reduce(total_weight)` (robust,
and supports weighted averaging).
2. Runtime Warning: Added a version check (`required_torch_version`) in
`DeepSpeedEngine`. It now logs a warning if sequence parallelism is
enabled on PyTorch < 2.3, with a link to the workaround test case.
3. Documentation: Updated `ulysses-alst-sequence-parallelism.md` with a
note on legacy PyTorch versions and the recommended workaround.
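The weighted all-reduce pattern in step 1 can be sketched as follows. This is a single-process illustration: `mock_all_reduce` and the example loss/weight values are illustrative stand-ins (not from the PR) for `torch.distributed.all_reduce` with `ReduceOp.SUM` over the sequence-parallel group, so the arithmetic of the workaround is visible without a multi-GPU setup:

```python
# Single-process sketch of the weighted all-reduce loss aggregation.
# In the real test, each "rank" holds one (loss, weight) pair and
# dist.all_reduce(tensor, op=ReduceOp.SUM) sums across the sp group.

def mock_all_reduce(per_rank_values):
    """Stand-in for dist.all_reduce with ReduceOp.SUM across the group."""
    return sum(per_rank_values)

def aggregate_loss(losses, weights):
    # Before: all_gather(losses) + manual sum -- the backward hook then
    # indexes the gathered tuple with the *global* rank on PyTorch < 2.3,
    # which is out of range when sp_size < world_size.
    # After: two all_reduce calls -- no per-rank indexing at all.
    weighted_sum = mock_all_reduce([l * w for l, w in zip(losses, weights)])
    total_weight = mock_all_reduce(weights)
    return weighted_sum / total_weight

# e.g. sp_size=2: rank 0 has loss 2.0 over 3 tokens, rank 1 has loss 4.0
# over 1 token -> token-weighted average (2*3 + 4*1) / (3 + 1) = 2.5
print(aggregate_loss([2.0, 4.0], [3.0, 1.0]))  # 2.5
```

Because the result is a pure sum-then-divide, every rank computes the identical average and autograd never needs to index into a per-rank tuple, which is what makes the pattern safe on older PyTorch.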
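The runtime warning in step 2 amounts to a minimum-version guard. A minimal sketch of the idea follows; the parsing helper `meets_min_torch`, the `maybe_warn` wrapper, and the message text are hypothetical simplifications, not the actual `required_torch_version` implementation in DeepSpeed:

```python
def meets_min_torch(torch_version: str, required=(2, 3)) -> bool:
    """Return True if a torch version string is >= required (major, minor).

    Strips local-version suffixes such as "+cu118" before parsing, and
    compares integer tuples so "2.10" correctly exceeds "2.3".
    """
    parts = torch_version.split("+")[0].split(".")
    return (int(parts[0]), int(parts[1])) >= required

def maybe_warn(torch_version: str, sp_enabled: bool, log=print) -> None:
    # Illustrative guard in the spirit of the DeepSpeedEngine check.
    if sp_enabled and not meets_min_torch(torch_version):
        log("Sequence parallelism on PyTorch < 2.3 may hit a known "
            "all_gather backward-pass bug; see TestUlyssesLossBackward "
            "for a weighted all-reduce workaround.")
```

Comparing integer tuples rather than raw strings matters here: a naive `"2.10" >= "2.3"` string comparison is False, which would wrongly warn on newer releases.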
### Verification
Added and verified the regression test
`tests/unit/sequence_parallelism/test_ulysses.py`, which now validates the
weighted averaging logic.
**1. Reproduction (Before Fix)**
Confirmed the `IndexError` crash on ranks 2 and 3 with `sp_size=2` on a 4-GPU setup.
<img width="1370" height="860" alt="Screenshot 2026-01-23 at 23 53 42"
src="https://github.com/user-attachments/assets/f4005c02-ff6c-46ea-a1a7-caac2093128b"
/>
**2. Verification (After Fix)**
Verified the fix using the regression test logic on 4x RTX A6000. The
backward pass now completes successfully on all ranks without error.
<img width="1192" height="605" alt="Screenshot 2026-01-23 at 23 52 54"
src="https://github.com/user-attachments/assets/c14cd093-67b7-42b0-ae15-65555c129082"
/>
---------
Signed-off-by: vensen <vensenmu@gmail.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>