[FSDP] Introduce `FlatParamHandle` (#79652)
**Overview**
This PR introduces `FlatParamHandle` to enable non-recursive FSDP wrapping. The class absorbs the unflattening/flattening logic from `FlattenParamsWrapper` but does not require wrapping a particular `nn.Module`.
## Discussion
### Introducing `FlatParamHandle`
There is flexibility in the design space for how to allocate attributes and methods between `FlatParameter` and a wrapping class like `FlatParamHandle` or `FlattenParamsWrapper`. Several points in the design space provide the same functionality, so the choice of allocation is arguably stylistic; even so, preference should go to the cleaner design.
The foremost consideration is that a `FlatParameter`'s metadata should be initialized once, while its data may be reloaded via checkpointing. This motivates decoupling the metadata initialization from the `FlatParameter` constructor, which should instead handle only the parameter data. Thus, we have both a `FlatParamHandle` that manages a `FlatParameter` and the `FlatParameter` itself.
```
class FlatParamHandle:
    def __init__(self, module: nn.Module, params: Sequence[nn.Parameter]):
        # Calls `_init_flat_param()`
        ...

    def _init_flat_param(self, module: nn.Module, params: Sequence[nn.Parameter]):
        # Calls `flatten_params()` and initializes metadata
        ...

    @staticmethod
    def flatten_params(params: Sequence[torch.Tensor], requires_grad: bool) -> FlatParameter:
        # Also may be used for checkpoint reloading
        ...

class FlatParameter(nn.Parameter):
    # Constructor is not overridden
    ...
```
Under this separation, with `FlatParameter` serving solely as a data container, we keep the methods that manipulate a `FlatParameter` on the `FlatParamHandle`. Because `FlatParameter`'s constructor is not overridden, we should be able to replace it with another tensor type (e.g. `ShardedTensor`) with minimal changes.
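For illustration, here is a minimal sketch of what `flatten_params()` might do, with padding, alignment, and dtype handling omitted (not the actual implementation):
```
import torch
import torch.nn as nn
from typing import Sequence

class FlatParameter(nn.Parameter):
    # Solely a data container: the constructor is not overridden.
    pass

def flatten_params(params: Sequence[torch.Tensor], requires_grad: bool) -> FlatParameter:
    # Flatten each parameter to 1D and concatenate into one contiguous tensor.
    with torch.no_grad():
        flat_data = torch.cat([p.detach().reshape(-1) for p in params], dim=0)
    # Only the data lives on the FlatParameter; numels, shapes, prefixed names,
    # and shard metadata are computed and stored by the handle.
    return FlatParameter(flat_data, requires_grad=requires_grad)
```
Keeping construction this thin is what makes swapping in a different base tensor type plausible with minimal changes.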
### Compatibility with `FlattenParamsWrapper`
To ensure backward compatibility, `FlattenParamsWrapper` now holds a `FlatParamHandle`, and its existing logic simply routes to the handle (sketched after the list below).
A `FullyShardedDataParallel` instance holds references to all of its handles.
- For the recursive-wrapping paradigm, there is at most one handle, which comes from the instance's `FlattenParamsWrapper` if that wrapper manages parameters.
- For the non-recursive-wrapping paradigm, there may be multiple handles, all owned by the single (root) `FullyShardedDataParallel` instance.
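As a rough sketch (not the actual class), the routing pattern looks like the following; the `flat_param` property, the attribute names, and the unflattening into the wrapped module (elided here) are illustrative:
```
import torch.nn as nn

class FlattenParamsWrapper(nn.Module):
    # Thin shim kept for backward compatibility; the real logic lives in the handle.
    def __init__(self, module: nn.Module, params):
        super().__init__()
        self.module = module
        # One handle per wrapped module (or none, if it manages no parameters).
        self.handle = FlatParamHandle(module, params) if params else None

    @property
    def flat_param(self):
        # Route the old `flat_param` access through the handle.
        return self.handle.flat_param if self.handle is not None else None

    def forward(self, *args, **kwargs):
        # Unflattening the FlatParameter back into the wrapped module's
        # original parameters is also delegated to the handle (elided).
        return self.module(*args, **kwargs)
```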
## For Reviewers
### `FlatParameter` Construction
In the existing implementation, a `FlatParameter`'s metadata was partially initialized in its constructor (e.g. `_param_numels`, `_param_shapes`) and partially initialized by the owning `FlattenParamsWrapper` (e.g. `_param_infos`, `_shared_param_infos`). The latter part was needed because it requires module information. With this PR, the metadata initialization is consolidated in `FlatParamHandle`.
- During model construction, a `FlatParameter` should be initialized via the handle constructor `FlatParamHandle(params, module)`.
- During sharded checkpoint loading, a `FlatParameter` should be initialized via the static method `FlatParamHandle.flatten_params(new_params)`.
- The checkpointing implementation is responsible for checking that `new_params` used to construct the `FlatParameter` data to load is consistent with the existing `FlatParameter`'s metadata.
These are the only two cases of `FlatParameter` construction right now, so there is no real functionality regression from not recomputing some of the metadata in the `FlatParameter` constructor. `nn.Module.load_state_dict()` is implemented with in-place `copy_()`, so the newly loaded `FlatParameter`'s metadata *should* match the existing `FlatParameter`'s metadata for correctness anyway. (I.e., we do not support reloading a `FlatParameter` with differing metadata into an existing `FlatParameter`.)
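A rough sketch of the checkpoint-reload path, assuming placeholder names `handle` and `new_params` (the `flat_param` attribute is likewise illustrative) and omitting the metadata consistency check:
```
import torch

# `new_params` holds the unflattened tensors read back from a sharded
# checkpoint, and `handle` is the existing FlatParamHandle.
new_flat_param = FlatParamHandle.flatten_params(new_params, requires_grad=True)

# Loading copies in place, so the existing FlatParameter object (and therefore
# its metadata) is preserved; only its data changes.
with torch.no_grad():
    handle.flat_param.copy_(new_flat_param)
```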
### BC Breaking
- `ShardMetadata` -> `FlatParamShardMetadata` to avoid name conflict with `ShardedTensor`
- `metadata()` -> removed (unused)
- `FlatParameter` attributes
- `_param_numels` -> `_numels`
- `_param_shapes` -> `_shapes`
- `_param_names` -> `_prefixed_param_names`
- `full_numel` -> `_unsharded_size.numel()`
- `_param_indice_in_shard` -> `_shard_indices`
- `_sharded_param_offsets` -> `_shard_param_offsets`
- `num_padded` -> `_shard_numel_padded`
- `param_offsets` -> not saved; directly constructed in `_get_flat_param_offsets()` and used once
- `FlattenParamsWrapper` `param_list` argument -> `params` for consistency with `FlatParameter`
## Follow-Ups
- The current `FlatParameter`'s `data` represents either the sharded unflattened parameter, unsharded unflattened parameter, or reduced-precision sharded unflattened parameter, depending dynamically on the runtime context. When its `data` represents one quantity, the other quantities are still saved as attributes on the `FlatParameter` (e.g. `_local_shard`, `_full_param_padded`, `_mp_shard`). `FullyShardedDataParallel` directly manipulates the `data`.
We should investigate the tradeoffs of keeping those attributes on the `FlatParameter` versus moving them to the `FlatParamHandle`. The motivation for the latter is to define a clean interface for `FullyShardedDataParallel` to manage parameter data, in preparation for generalizing to multiple parameter groups, to managing non-`FlatParameter`s, and to supporting non-CUDA devices. (More explicitly, `FullyShardedDataParallel`'s parameter *variables* would be rebound to different `Tensor` variables, none of which owns another, instead of the parameter variables' *data* being set to different `Tensor` variables all owned by the `FlatParameter`; the data management would be folded into the handle and hidden from `FullyShardedDataParallel`.) See the sketch after this list for the status quo.
- We should investigate if we can coalesce the remaining logic in `FlattenParamsWrapper` into `FullyShardedDataParallel` and remove `FlattenParamsWrapper`.
- We may want to move the mixed precision logic to the handle instead of the `FullyShardedDataParallel` instance to enable per-`FlatParameter` mixed precision instead of per-`FullyShardedDataParallel`. Otherwise, the non-recursive wrapping path is bound to all-or-nothing mixed precision.
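To make the first follow-up concrete, a rough illustration (not actual FSDP code) of how `flat_param.data` is rebound today; the real code uses narrowed and reshaped views of these buffers:
```
# The same FlatParameter object owns every view of its data.
flat_param.data = flat_param._full_param_padded  # unsharded, for forward/backward
flat_param.data = flat_param._mp_shard           # reduced-precision shard
flat_param.data = flat_param._local_shard        # sharded, between computations
```
Under the proposed alternative, the handle would hand out distinct tensor variables and `FullyShardedDataParallel` would never rebind `.data` itself.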
Differential Revision: [D37250558](https://our.internmc.facebook.com/intern/diff/D37250558)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79652
Approved by: https://github.com/zhaojuanmao, https://github.com/fegin, https://github.com/rohan-varma