[FSDP] Enable async all-reduce for HSDP (#106080)
**Overview**
This PR runs the HSDP all-reduce asynchronously so that it can overlap with both all-gather and reduce-scatter, which can yield slight end-to-end speedups when the sharding process group is fully intra-node. Previously, the all-reduce was serialized with the reduce-scatter, so it could only overlap with one all-gather.
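To illustrate why decoupling the all-reduce from the reduce-scatter can shorten the communication critical path, here is a toy timeline model. It is only a sketch and not the FSDP implementation: the function name, the two-queue abstraction, and the example durations are hypothetical, and it ignores compute, all-gather, and any bandwidth contention between collectives.

```python
def last_all_reduce_end_ms(rs_ms, ar_ms, num_layers, async_all_reduce):
    """Return the time (ms) at which the final all-reduce completes.

    Toy model: each layer issues one reduce-scatter (RS) followed by one
    all-reduce (AR) that depends on that RS. Compute, all-gather, and
    bandwidth contention are deliberately ignored.
    """
    rs_queue_free = 0.0  # time at which the queue running RS is next free
    ar_queue_free = 0.0  # time at which the queue running AR is next free
    for _ in range(num_layers):
        rs_end = rs_queue_free + rs_ms
        if async_all_reduce:
            # AR has its own queue: it waits only for its RS and the prior AR,
            # so the next RS can start as soon as this RS finishes.
            ar_end = max(rs_end, ar_queue_free) + ar_ms
            rs_queue_free = rs_end
        else:
            # AR shares the queue with RS, so the next RS waits behind it.
            ar_end = rs_end + ar_ms
            rs_queue_free = ar_end
        ar_queue_free = ar_end
    return ar_queue_free


# Hypothetical durations, chosen only for illustration.
serialized = last_all_reduce_end_ms(5, 50, 6, async_all_reduce=False)
overlapped = last_all_reduce_end_ms(5, 50, 6, async_all_reduce=True)
print(serialized, overlapped)
```

In this model the saving is bounded by the total reduce-scatter time, which is in line with the "slight" speedup described above when the all-reduce dominates.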
For some clusters (e.g. our AWS cluster), `NCCL_CROSS_NIC=1` improves inter-node all-reduce times when overlapped with intra-node all-gather/reduce-scatter.
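One way to try this setting from Python is sketched below. The helper name is hypothetical, and the sketch assumes the variable is set in every rank's process before any NCCL communicator is created; exporting it in the launch script is an equally valid approach. Whether it helps depends on the cluster's network topology.

```python
import os


def enable_nccl_cross_nic() -> None:
    """Set NCCL_CROSS_NIC=1 for this process (hypothetical helper)."""
    # NCCL reads its environment variables when communicators are created,
    # so this needs to run before the process group is initialized.
    os.environ["NCCL_CROSS_NIC"] = "1"


enable_nccl_cross_nic()
# After this point, initialize torch.distributed with the NCCL backend as usual.
```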
**Experiment**
<details>
<summary> Example 'before' trace </summary>
<img width="559" alt="hsdp_32gpus_old" src="https://github.com/pytorch/pytorch/assets/31054793/15222b6f-2b64-4e0b-a212-597335f05ba5">
</details>
<details>
<summary> Example 'after' trace </summary>
<img width="524" alt="hsdp_32gpus_new" src="https://github.com/pytorch/pytorch/assets/31054793/94f63a1d-4255-4035-9e6e-9e10733f4e44">
</details>
For a transformer with 6 encoder layers and 6 decoder layers, `d_model=8192`, and `nhead=64`, run on 4 nodes (32 A100 GPUs with 40 GB each) on AWS, the end-to-end iteration times are as follows (AG = all-gather, RS = reduce-scatter, AR = all-reduce; bandwidth is reported as algorithmic bandwidth):
- Reference FSDP:
    - **1160 ms / iteration**
    - ~23 ms / encoder AG/RS --> 24.46 GB/s bandwidth
    - ~40 ms / decoder AG/RS --> 26.5 GB/s bandwidth
    - 50 GB/s theoretical inter-node bandwidth
- Baseline 8-way HSDP (only overlap AR with AG) -- intra-node AG/RS, inter-node AR:
    - **665 ms / iteration**
    - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
    - ~5 ms / decoder AG/RS --> 212 GB/s bandwidth
    - ~30 ms / encoder AR --> 2.34 GB/s bandwidth
    - ~55 ms / decoder AR --> 2.65 GB/s bandwidth
    - 300 GB/s theoretical intra-node bandwidth
- New 8-way HSDP (overlap AR with AG and RS) -- intra-node AG/RS, inter-node AR:
    - **597 ms / iteration**
    - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
    - ~6.2 ms / decoder AG/RS --> 170.97 GB/s bandwidth (slower)
    - ~23 ms / encoder AR (non-overlapped) --> 3.057 GB/s bandwidth (faster)
    - ~49 ms / decoder AR (non-overlapped) --> 2.70 GB/s bandwidth (faster)
    - ~100 ms / decoder AR (overlapped) --> 1.325 GB/s bandwidth (slower)
    - Overlapping with reduce-scatter reduces all-reduce bandwidth utilization even though the all-reduce is inter-node and reduce-scatter is intra-node!
- New 8-way HSDP (overlap AR with AG and RS) with `NCCL_CROSS_NIC=1`:
    - **556 ms / iteration**
    - Speedup comes from faster overlapped AR
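The algorithmic bandwidth figures above can be cross-checked with a short calculation. This is a sketch under assumptions: the payload sizes are not stated here and are back-computed from the reference FSDP numbers, and the all-reduce payload is assumed to be 1/8 of the AG/RS payload, which is what the reported numbers imply for 8-way sharding.

```python
def algo_bandwidth_gb_per_s(payload_gb: float, time_ms: float) -> float:
    """Algorithmic bandwidth: payload size divided by collective duration."""
    return payload_gb / (time_ms / 1000.0)


# Assumed payloads, back-computed from the reference FSDP timings
# (bandwidth * time); they are not given explicitly in the text.
ENCODER_GB = 0.5625
DECODER_GB = 1.06
SHARD_FACTOR = 8  # assumption: AR moves 1/8 of the AG/RS payload

checks = {
    "FSDP encoder AG/RS (23 ms)": algo_bandwidth_gb_per_s(ENCODER_GB, 23),
    "FSDP decoder AG/RS (40 ms)": algo_bandwidth_gb_per_s(DECODER_GB, 40),
    "HSDP encoder AG/RS (3 ms)": algo_bandwidth_gb_per_s(ENCODER_GB, 3),
    "HSDP decoder AG/RS (6.2 ms)": algo_bandwidth_gb_per_s(DECODER_GB, 6.2),
    "HSDP encoder AR (23 ms)": algo_bandwidth_gb_per_s(ENCODER_GB / SHARD_FACTOR, 23),
    "HSDP decoder AR (100 ms)": algo_bandwidth_gb_per_s(DECODER_GB / SHARD_FACTOR, 100),
}
for name, bandwidth in checks.items():
    print(f"{name}: {bandwidth:.2f} GB/s")
```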
Thus, for this particular workload, the async all-reduce with `NCCL_CROSS_NIC=1` reduces iteration time by 16% compared to the existing HSDP (665 ms to 556 ms) and by 52% compared to FSDP (1160 ms to 556 ms). These speedups are pronounced because the workload is communication-bound, so any reduction in communication time translates directly into end-to-end speedup.
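For reference, the percentages follow from the iteration times listed above when read as relative reductions in iteration time, using the 556 ms run with `NCCL_CROSS_NIC=1` as the new time. The helper below is a small sketch of that arithmetic; its name is hypothetical.

```python
def iteration_time_reduction_pct(old_ms: float, new_ms: float) -> float:
    """Relative reduction in iteration time, as a percentage of the old time."""
    return 100.0 * (old_ms - new_ms) / old_ms


vs_baseline_hsdp = iteration_time_reduction_pct(665, 556)
vs_reference_fsdp = iteration_time_reduction_pct(1160, 556)
without_cross_nic = iteration_time_reduction_pct(665, 597)
print(f"{vs_baseline_hsdp:.1f}% {vs_reference_fsdp:.1f}% {without_cross_nic:.1f}%")
```

The third value shows the reduction from the async all-reduce alone, without `NCCL_CROSS_NIC=1`, relative to the baseline HSDP.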
**Unit Test**
This requires >= 4 GPUs:
```
python -m pytest test/distributed/fsdp/test_fsdp_hybrid_shard.py -k test_fsdp_hybrid_shard_parity
```
Differential Revision: [D47852456](https://our.internmc.facebook.com/intern/diff/D47852456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106080
Approved by: https://github.com/ezyang
ghstack dependencies: #106068