[FSDP2] Added gradient accumulation w/o reduction (#118298)
This PR adds a way to do gradient accumulation without reduction collectives (i.e. reduce-scatter for FSDP and reduce-scatter/all-reduce for HSDP, though HSDP is not yet implemented). Since the `no_sync()` context manager has received some negative feedback, we instead define a method on the module that sets whether the module requires gradient synchronization, where this method can either recurse into submodules or apply only to the given module.
```
# Before with `no_sync()`:
with fsdp_model.no_sync() if not is_last_microbatch else contextlib.nullcontext():
    ...  # forward/backward

# After with a setter:
fsdp_model.set_requires_gradient_sync(is_last_microbatch)
...  # forward/backward
```
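For concreteness, a minimal gradient-accumulation loop using the setter could look like the sketch below; the `microbatches`, `targets`, `loss_fn`, and `optim` names are placeholders, and the exact signature should be checked against the released API:
```
# Sketch: accumulate gradients locally and reduce-scatter only on the last microbatch
num_microbatches = len(microbatches)
for i, (inp, target) in enumerate(zip(microbatches, targets)):
    is_last_microbatch = i == num_microbatches - 1
    fsdp_model.set_requires_gradient_sync(is_last_microbatch)
    loss = loss_fn(fsdp_model(inp), target)
    loss.backward()
optim.step()
optim.zero_grad()
```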
Allowing the method to recurse or not also gives some flexibility. For example, some large modules can still reduce-scatter every microbatch, while some smaller modules can skip the reduce-scatter to save communication bandwidth:
```
fsdp_modules_to_reduce_scatter: Set[nn.Module] = ...
for module in fsdp_model.modules():
    if isinstance(module, FSDP) and module not in fsdp_modules_to_reduce_scatter:
        # Only affect this module, not its nested FSDP modules
        module.set_requires_gradient_sync(is_last_microbatch, recurse=False)
# Forward/backward
```
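Note that each call above applies only to that FSDP module (no recursion), so the modules in `fsdp_modules_to_reduce_scatter` keep their default behavior and continue to reduce-scatter every microbatch.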
(Separately, we may expose a helper for `return [module for module in model.modules() if isinstance(module, FSDP)]`.)
---
To show the spirit of this API choice, I also included `set_requires_all_reduce`, which gives us the ability to reduce-scatter but not all-reduce for HSDP (an optimization originally from the MiCS paper). If we want to flexibly support heterogeneous sharding, where FSDP is applied to some modules and HSDP to others in the same model, then having a module-level method with the option to not recurse makes sense to me.
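As a rough sketch of the HSDP case (assuming `set_requires_all_reduce` takes a bool like `set_requires_gradient_sync`, and that `hsdp_model` is an HSDP-applied root module; this is illustrative rather than the exact shipped behavior):
```
# Sketch: MiCS-style gradient accumulation for HSDP
# Reduce-scatter within the shard group every microbatch, but only
# all-reduce across replica groups on the last microbatch.
hsdp_model.set_requires_gradient_sync(True)  # keep reduce-scatter enabled
hsdp_model.set_requires_all_reduce(is_last_microbatch)
# Forward/backward
```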
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118298
Approved by: https://github.com/wconstab, https://github.com/wanchaol
ghstack dependencies: #119550, #118136, #118223, #118755, #119825