d172dcf3 - [FSDP()][21/N] Refactor and fix `_cast_buffers()` (#87935)

This PR refactors and fixes `_cast_buffers()`.

**Before**
Buffers were not correctly cast back to their original dtypes for submodules when using buffer mixed precision.
- `_cast_buffers(recurse=False)` incorrectly casts all buffers, including those in submodules. This is because of this outer loop over `self.modules()`:
  https://github.com/pytorch/pytorch/blob/c40033be162db0f94d37e7ccbd2a89d67f8b8e47/torch/distributed/fsdp/fully_sharded_data_parallel.py#L700
- There was a unit test that checked that buffers were cast as expected (`test_mixed_precision_e2e_full_shard()`). The unit test _coincidentally_ passed because all modules shared the same buffer name `"buffer"`. In `_cast_buffers()`, the `dict` mapping buffer name to original dtype is populated lazily (during `_lazy_init()`), but its keys are unprefixed:
  https://github.com/pytorch/pytorch/blob/c40033be162db0f94d37e7ccbd2a89d67f8b8e47/torch/distributed/fsdp/fully_sharded_data_parallel.py#L712-L717
- Thus, even though (1) `_cast_buffers(recurse=False)` was only called on the root and (2) `self._buffer_name_to_orig_dtype` had unprefixed names as keys, the unit test still passed because (1) `_cast_buffers()` still looped over all buffers despite `recurse=False` and (2) all submodules' buffers were named `"buffer"` and shared the same original and low-precision dtypes, so they happened to be cast correctly.

If we change each submodule to have its own distinct buffer name, the unit test fails. This PR makes that change to showcase the improvement it provides. (A minimal repro of the key collision is sketched below, after **Follow-Ups**.)

**After**
This PR separates `_cast_buffers()` into three methods: `_get_buffers_and_dtypes_for_computation()`, `_get_buffers_and_dtypes_for_checkpoint()`, and `_cast_buffers_to_dtype_and_device()`. This separates the two use cases (casting for computation and casting for checkpointing) and their corresponding code paths. In addition, the signature of `_cast_buffers_to_dtype_and_device()` makes it explicit which buffers are being cast and to what dtype.

Both `_get_...()` functions assume that they are called on the root only for now. This coincides with the construction of `_buffer_name_to_orig_dtype` in the FSDP constructor, which loops over all submodules. (This means that for non-root modules, their `_buffer_name_to_orig_dtype` is populated but not used.) The `dict`'s keys are clean because the cast back to the original dtype happens inside a `summon_full_params()` context, which cleans the names.

**Follow-Ups**
- We can try to move `_get_buffers_and_dtypes_for_checkpoint()` into `_state_dict_utils.py` in a follow-up.
- We may want to move to per-module buffer casting (i.e., not have the root module cast for all submodules).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87935
Approved by: https://github.com/mrshenli
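
To make the key-collision failure mode above concrete, here is a minimal, self-contained repro using plain `nn.Module` rather than the actual FSDP code or unit test; the module and buffer names are made up for illustration. It shows that keying a dtype map on unprefixed buffer names silently loses entries as soon as submodules reuse a buffer name, whereas the prefixed names from `named_buffers()` stay distinct.

```python
import torch
import torch.nn as nn


class Sub(nn.Module):
    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        # Every submodule reuses the same buffer name, as in the old unit test.
        self.register_buffer("buffer", torch.zeros(1, dtype=dtype))


class Root(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("buffer", torch.zeros(1, dtype=torch.float64))
        self.sub1 = Sub(torch.float32)
        self.sub2 = Sub(torch.float16)


root = Root()

# Unprefixed keys: iterating over `modules()` and recording each buffer under
# its local (unprefixed) name makes all three buffers collide on "buffer", so
# only one original dtype survives.
unprefixed = {}
for module in root.modules():
    for name, buf in module.named_buffers(recurse=False):
        unprefixed[name] = buf.dtype
print(unprefixed)  # {'buffer': torch.float16}

# Prefixed (clean) keys from `named_buffers()` on the root keep all entries.
prefixed = {name: buf.dtype for name, buf in root.named_buffers()}
print(prefixed)
# {'buffer': torch.float64, 'sub1.buffer': torch.float32, 'sub2.buffer': torch.float16}
```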
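
For reference, below is a minimal sketch of a standalone cast helper in the spirit of `_cast_buffers_to_dtype_and_device()`; the function name, signature, and `None`-dtype convention are assumptions for illustration, not the FSDP-internal implementation.

```python
from typing import List, Optional

import torch


def cast_buffers_to_dtype_and_device(
    buffers: List[torch.Tensor],
    buffer_dtypes: List[Optional[torch.dtype]],
    device: torch.device,
) -> None:
    """Casts each buffer in place to its paired dtype on ``device``.

    A ``None`` dtype means "keep the buffer's current dtype". Hypothetical
    sketch only, under the assumptions stated above.
    """
    assert len(buffers) == len(buffer_dtypes), "mismatched buffers/dtypes"
    for buffer, dtype in zip(buffers, buffer_dtypes):
        # Assigning to `.data` mutates the existing tensor object, so the
        # owning module's reference to the buffer stays valid.
        buffer.data = buffer.data.to(device=device, dtype=dtype)
```

Passing the buffers and their target dtypes explicitly, as sketched here, is what makes it clear at the call site exactly which buffers are cast and to what dtype.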