662a8cf7 - [FSDP][8/N] Simplify addr padding internals (#97796)

[FSDP][8/N] Simplify addr padding internals (#97796)

This is a follow-up to the last PR that greatly simplifies the approach. This should be much cleaner.

**Details**

Let `N` denote the number of original parameters flattened into a given flat parameter and `M` the number of extra padding tensors.
- `_numels_with_padding`: length `N + M`
- `_is_padding_mask`: length `N + M`
- `_numels`, `_param_infos`, `_shapes`, `_fqns`, `_param_extensions`: length `N`

`_shard_param_indices` and `_shard_param_offsets` were used to determine (1) whether a given original parameter is in the local shard and, if so, (2) its offset in the _sharded_ flat parameter and (3) how many of its numels are in the _sharded_ flat parameter. This PR reworks how (1), (2), and (3) are computed, which allows simplifying the previously mentioned data structures. In particular, it saves one extra tuple `_shard_param_infos: Tuple[_ShardParamInfo, ...]` of length `N`, where each `_ShardParamInfo` entry gives exactly the needed info. For example, the offset into the sharded flat parameter is now pre-computed, so we no longer need to accumulate `offset = 0; offset += numel_in_shard` in a `for` loop on every use (see the sketch below).

For the optimizer state dict, `FSDPParamInfo.param_indices` now maps to indices with respect to the length-`N` data structures, not the length-`N + M` ones. The only purpose of `param_indices` is to index into `flat_param._shard_param_infos[i]` to get the info needed to flatten the unsharded original parameter's optimizer state and extract the part in the local shard.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97796
Approved by: https://github.com/rohan-varma
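The following standalone Python sketch illustrates the idea described above: pre-computing, once, a per-parameter record that answers (1), (2), and (3), instead of re-running an offset-accumulation loop on every use. The field names (`in_shard`, `offset_in_shard`, `numel_in_shard`), the helper `compute_shard_param_infos`, and the example numbers are illustrative assumptions modeled on the PR description, not the exact upstream implementation.

```python
from typing import NamedTuple, Optional, Tuple


class ShardParamInfo(NamedTuple):
    """Per-original-parameter shard info (illustrative; mirrors the role of
    `_ShardParamInfo` in the PR description, not the exact upstream fields)."""
    in_shard: bool                  # (1) is any part of this param in the local shard?
    offset_in_shard: Optional[int]  # (2) start offset in the sharded flat parameter
    numel_in_shard: Optional[int]   # (3) how many numels fall in the local shard


def compute_shard_param_infos(
    numels_with_padding: Tuple[int, ...],  # length N + M (params and padding tensors)
    is_padding_mask: Tuple[bool, ...],     # length N + M
    shard_start: int,                      # this rank's start index in the unsharded flat param
    shard_numel: int,                      # numel of this rank's shard
) -> Tuple[ShardParamInfo, ...]:
    """Pre-compute the per-parameter shard info (length N) so callers never
    need to redo `offset = 0; offset += numel_in_shard` themselves."""
    shard_end = shard_start + shard_numel  # exclusive
    infos = []
    global_offset = 0  # offset into the *unsharded* flat parameter, padding included
    for numel, is_padding in zip(numels_with_padding, is_padding_mask):
        start, end = global_offset, global_offset + numel
        global_offset = end
        if is_padding:
            continue  # padding tensors get no entry, so len(infos) == N, not N + M
        overlap_start = max(start, shard_start)
        overlap_end = min(end, shard_end)
        if overlap_start >= overlap_end:
            infos.append(ShardParamInfo(False, None, None))
        else:
            infos.append(
                ShardParamInfo(
                    True,
                    overlap_start - shard_start,   # offset in the sharded flat parameter
                    overlap_end - overlap_start,   # numels of this param in the local shard
                )
            )
    return tuple(infos)


if __name__ == "__main__":
    # Two original params of numel 6 and 10, with one padding tensor of numel 4 between them.
    numels_with_padding = (6, 4, 10)
    is_padding_mask = (False, True, False)
    # A rank owning elements [8, 16) of the 20-element unsharded flat parameter.
    infos = compute_shard_param_infos(numels_with_padding, is_padding_mask, 8, 8)
    print(infos)
    # (ShardParamInfo(in_shard=False, offset_in_shard=None, numel_in_shard=None),
    #  ShardParamInfo(in_shard=True, offset_in_shard=2, numel_in_shard=6))
```

With a tuple like this cached on the flat parameter, both sharded state extraction and optimizer-state flattening reduce to a single indexed lookup per original parameter, which is the simplification the PR describes.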