pytorch
1c15cd48 - [FSDP][7/N] Add alignment padding for `use_orig_params=True` (#97667)

Commit
1 year ago
[FSDP][7/N] Add alignment padding for `use_orig_params=True` (#97667) This PR adds intra-`FlatParameter` 16-byte alignment padding to the `use_orig_params=True` code path to avoid clones in TorchInductor. **Approach** The `FlatParameter` maintains several data structures about its original parameters. Notably, the data structures `_param_infos`, `_shapes`, `_numels`, and `_fqns` have the same length and index in the same way. This PR treats alignment padding _like_ an original parameter in that the padding gets flattened into the `FlatParameter`. Therefore, it must be reflected in the aforementioned data structures. However, given the way in which the data structures are used, we choose to do the following if the `i`th tensor flattened into the `FlatParameter` is padding: - `_numels[i]` is the numel of padding - `_param_infos[i] == _shapes[i] == _fqns[i] == None` This choice is because (1) we must record the padding numel to account for it (e.g. for views) and (2) we prefer to preserve the invariant that the data structures index in the same way over avoiding `None` entries. To ease the burden of other FSDP developers, we separate the parameter flattening logic: - `_init_flat_param_and_metadata()`: This should be called only once in the `FlatParamHandle` constructor. The `FlatParameter` metadata is assumed to be static thereafter. - `flatten_tensors()` / `flatten_tensors_into_flat_param()`: These can be used for optimizer and model state dict and can be called after construction time. This separation allows `_init_flat_param_and_metadata()` to contain the much heavier metadata logic, while keeping the latter methods to be much lighter. The only constraint is that the alignment padding logic must be kept consistent between the two, but this should be worth the simper interface. **Testing** - This PR directly modifies the `use_orig_params=True` code path, so all existing tests passing gives good signal. - Some existing unit tests had to be adjusted to account for the alignment padding. - This PR adds two tests in `test_fsdp_flatten_params.py` to explicitly test the sharding metadata with alignment for both parameter full precision and mixed precision since the latter requires possibly more padding elements due to the decreased per-element size. Pull Request resolved: https://github.com/pytorch/pytorch/pull/97667 Approved by: https://github.com/rohan-varma
Author
Committer
Parents
Loading