[FSDP2] Added `reshard_after_forward` (#118017)
This PR adds the `reshard_after_forward: Union[bool, int]` arg and a `reshard()` method. The `reshard_after_forward` argument trades off communication and memory.
- `reshard_after_forward=True`: reshard parameters after forward; unshard (all-gather) in backward
- `reshard_after_forward=False`: no reshard of parameters after forward; no unshard (all-gather) in backward
- `reshard_after_forward: int`: reshard parameters to a smaller world size after forward; unshard (all-gather) over the smaller world size in backward
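The memory side of the tradeoff can be sketched as follows. This is an illustrative helper, not part of the FSDP2 API; it assumes a model with `numel` total parameter elements sharded evenly over `world_size` ranks, and the name `post_forward_numel_per_rank` is hypothetical:

```python
def post_forward_numel_per_rank(numel: int, world_size: int,
                                reshard_after_forward) -> int:
    """Per-rank parameter elements resident after forward (toy model)."""
    if reshard_after_forward is True:
        # Reshard back over the full world size: smallest memory footprint,
        # but backward must all-gather over all ranks.
        return numel // world_size
    if reshard_after_forward is False:
        # Keep parameters unsharded: each rank holds the full parameters,
        # so backward needs no all-gather.
        return numel
    # int case: reshard to a smaller world size (e.g. 8 = one node), so the
    # backward all-gather stays within that smaller group.
    return numel // reshard_after_forward

numel, world_size = 1024, 64
print(post_forward_numel_per_rank(numel, world_size, True))   # 16
print(post_forward_numel_per_rank(numel, world_size, False))  # 1024
print(post_forward_numel_per_rank(numel, world_size, 8))      # 128
```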
In comparison with DeepSpeed and existing FSDP:
- `reshard_after_forward=True` == `FULL_SHARD` == ZeRO-3
- `reshard_after_forward=False` == `SHARD_GRAD_OP` == ZeRO-2
- `reshard_after_forward=8` == ZeRO++
ZeRO-1 is `reshard_after_forward=False` without gradient reduction (to be implemented in a later PR). If we need gradient reduction on an iteration, then ZeRO-2 supersedes ZeRO-1.
We prefer a simple state transition between `SHARDED` / `SHARDED_POST_FORWARD` and `UNSHARDED`, where the state directly defines which tensors are registered to the module. In particular, we _do not_ have a state where the sharded parameters are registered but the unsharded parameters are still in GPU memory. This greatly simplifies our state transitions, but it means that parameters may be non-intuitively registered to the module (e.g. if only the root does not reshard after forward, then the root will be the only module without sharded parameters registered). To address this, we introduce a simple `reshard()` method that can force-reshard the parameters. This makes sense to me because the typical case does not care about the registered parameters after forward (in fact, for existing FSDP with `use_orig_params=False`, the unsharded parameters are still registered after forward and are dangling tensors without storage).
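The state machine above can be sketched as a toy model. This is not the FSDP2 implementation (the class and method names here are illustrative); it only shows that the state alone determines what is registered, with no mixed "sharded registered + unsharded resident" state, and that `reshard()` forces a transition back to `SHARDED`:

```python
from enum import Enum, auto

class ShardedState(Enum):
    SHARDED = auto()
    SHARDED_POST_FORWARD = auto()  # resharded to a smaller world size
    UNSHARDED = auto()

class ParamGroup:
    """Toy model of the per-module sharded-state machine (names hypothetical)."""
    def __init__(self, reshard_after_forward):
        self.reshard_after_forward = reshard_after_forward
        self.state = ShardedState.SHARDED

    def unshard(self):
        # all-gather; unsharded parameters become the registered ones
        self.state = ShardedState.UNSHARDED

    def post_forward(self):
        # Transition taken at the end of forward, driven by the arg.
        if self.reshard_after_forward is True:
            self.state = ShardedState.SHARDED
        elif self.reshard_after_forward is False:
            self.state = ShardedState.UNSHARDED
        else:  # int: reshard to a smaller world size
            self.state = ShardedState.SHARDED_POST_FORWARD

    def reshard(self):
        # Force-reshard regardless of `reshard_after_forward`.
        self.state = ShardedState.SHARDED

g = ParamGroup(reshard_after_forward=False)
g.unshard()
g.post_forward()
print(g.state)  # ShardedState.UNSHARDED (non-intuitive registration case)
g.reshard()
print(g.state)  # ShardedState.SHARDED
```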
I plan to expose a complementary `unshard(async_op: bool = True)` method in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118017
Approved by: https://github.com/weifengpy, https://github.com/wanchaol