[SPMD] Expedite the allreduce call before doing comm_fusion (#98922)
The order of the allreduce calls and the order of the gradients may differ, which can interfere with the benefit of comm_fusion. This PR reorders the graph so that each allreduce call happens right after its last input.
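
As a rough illustration of the reordering idea (not the code in this PR, which operates on the traced graph), the sketch below uses a toy node list in plain Python. `Node` and `expedite_allreduce` are hypothetical names made up for this example, and it assumes every allreduce has at least one input that appears earlier in the list.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    # Hypothetical stand-in for a graph node: a name, an op kind,
    # and the names of the nodes it consumes.
    name: str
    op: str
    inputs: tuple = ()


def expedite_allreduce(nodes):
    """Return a new order in which each allreduce directly follows its last input."""
    allreduces = [n for n in nodes if n.op == "allreduce"]
    # Start from the original order with the allreduce nodes taken out.
    result = [n for n in nodes if n.op != "allreduce"]
    for ar in allreduces:
        # Index of the latest-scheduled input of this allreduce.
        last_input = max(i for i, n in enumerate(result) if n.name in ar.inputs)
        # Place the allreduce immediately after that input.
        result.insert(last_input + 1, ar)
    return result


# Gradients are produced as c, b, a, but the allreduce calls were
# traced at the end of the graph in the order a, b, c.
nodes = [
    Node("grad_c", "compute"),
    Node("grad_b", "compute"),
    Node("grad_a", "compute"),
    Node("ar_a", "allreduce", ("grad_a",)),
    Node("ar_b", "allreduce", ("grad_b",)),
    Node("ar_c", "allreduce", ("grad_c",)),
]
reordered = [n.name for n in expedite_allreduce(nodes)]
```

With this input, each allreduce ends up next to the gradient it consumes, so the allreduce order matches the gradient order before any fusion step looks at the graph.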
Differential Revision: [D44900738](https://our.internmc.facebook.com/intern/diff/D44900738/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98922
Approved by: https://github.com/mrshenli