[FSDP][8/N] Simplify addr padding internals (#97796)
This is a follow-up to the last PR to greatly simplify the approach. This should be much cleaner.
**Details**
Let `N` denote the number of original parameters flattened into a given flat parameter with `M` extra padding tensors.
- `_numels_with_padding`: length `N + M`
- `_is_padding_mask`: length `N + M`
- `_numels`, `_param_infos`, `_shapes`, `_fqns`, `_param_extensions`: length `N`
`_shard_param_indices` and `_shard_param_offsets` were used to determine (1) if a given original parameter is in the local shard and if so, then (2) what is its offset in the _sharded_ flat parameter, and (3) how many numel are in the _sharded_ flat parameter.
This PR reworks how to achieve (1), (2), and (3) to allow for simplifying the previously mentioned data structures. In particular, it saves one extra tuple `_shard_param_infos: Tuple[_ShardParamInfo, ...]` of length `N` where each `_ShardParamInfo` entry gives exactly the needed info. For example, the offset into the sharded flat parameter is now pre-computed, so we do not need to do `offset = 0; offset += numel_in_shard` over a `for` loop each time now.
For optimizer state dict, `FSDPParamInfo.param_indices` now maps to the indexes with respect to the length `N` data structures, not the length `N + M` ones. The only purpose of `param_indices` is to be able to index into `flat_param._shard_param_infos[i]` to get the contained info to flatten the unsharded original parameter optimizer state and extract the part in the local shard.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97796
Approved by: https://github.com/rohan-varma