[FSDP] Optimize away intermediate `div_` for HSDP (#106034)
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.
$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$
Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.
### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `divs_` into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}}$$
This PR goes with the 1st approach prioritizing performance because (1) it matches the existing FSDP behavior and (2) it avoids a memor-bandwidth bound `div_` kernel that blocks all-reduce launch.
### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.
Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma