[DDP][PT2D] Allreduce fusion fx pass using concat and all_reduce_coalesced (#113209)
Differential Revision: [D49858057](https://our.internmc.facebook.com/intern/diff/D49858057/)
**TL;DR**
This PR implements 2 different DDP all_reduce fusions in Inductor post_grad fx passes. The two fusions are 1) fusion with concat op and 2) fusion with all_reduce_coalesced. When DDP detects that Python reducer is being used, DDP will automatically turn on the fusion.
This PR does not invent any algorithm and simply reflects the bucket size users set to DDP.
**Implementation Details**
*Fusion with concat op*
The idea of this fusion is to use a concat op to concatenate all the gradients into one tensor and perform one `all_reduce`. After the `wait` op of the `all_reduce`, splitting and reshaping will also be perform to get the individual gradient.
Because DDP needs to perform gradient scaling, the benefit of using this fusion is that we could perform the gradient scaling over the the concatenated buffer.
*Fusion with `all_reduce_coalesced`*
The idea of this fusion is to use `all_reduce_coalesced` op to directly perform the `all_reduce` over multiple buffers. This avoid the copy overhead but may not achieve the best NCCL performance. In addition, because there are multiple buffers, we could not do one simple gradient scaling but have to rely on `foreach_div` to help the gradient scaling.
**Limitations**
Current fusions do not distinguish `all_reduce` generated by different DDP modules. This is okay if all DDP instances use the same PG and data type. The support of multiple DDP instances with different PG and data type will come in the later PRs.
**TODOs**
- [x] Implement DDP allreduce fusion algorithm for Inductor post_grad pass.
- [ ] Add unit tests to ensure the fusion doesn't DDP + TP.
- [ ] Group different PG and data type of `all_reduce`s.
- [ ] Mixed precision supports and tests
- [ ] Implement the fusions with Inductor IR.
- [ ] Add auto bucketing based on Inductor profiling.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113209
Approved by: https://github.com/yf225