769eca6f - Basic Validation for FSDP `state_dict` transformations of modules with persistent buffers (#93396)

Fixes #93391

Thank you to the PyTorch Distributed team for your invaluable contributions to the PyTorch ecosystem; your work is immensely impressive and inspiring!

As mentioned in #93391, while preparing the downstream package I maintain ([finetuning-scheduler](https://github.com/speediedan/finetuning-scheduler)) to support PyTorch 2.0's version of FSDP, I noticed that modules with multiple persistent buffers were not having their state properly transformed when saving `state_dict`s. The post-state_dict hook codepath shared by the `FULL_STATE_DICT` and `SHARDED_STATE_DICT` `_state_dict_type`s ([`_common_unshard_post_state_dict_hook`](https://github.com/pytorch/pytorch/blob/332d55d3df5ef22e47d3df73fa785f7ca4802169/torch/distributed/fsdp/_state_dict_utils.py#L158)) was inadvertently referencing a local variable (`buffer`) left over from a [prior transformation](https://github.com/pytorch/pytorch/blob/332d55d3df5ef22e47d3df73fa785f7ca4802169/torch/distributed/fsdp/_state_dict_utils.py#L231), instead of the `buffers` variable that should have been referenced in the iteration context:

https://github.com/pytorch/pytorch/blob/332d55d3df5ef22e47d3df73fa785f7ca4802169/torch/distributed/fsdp/_state_dict_utils.py#L251-L253

Modules with a single persistent buffer, or without mixed precision enabled, are unaffected. With multiple buffers and mixed precision enabled, however, the issue appears stochastically, in proportion to the ratio of persistent buffers that have compatible dimensions (the value of the last buffer visited in the ``buffer_names`` ``Set`` is copied to all buffers, and the ``Set`` iteration order will of course vary):

```bash
File ".../pytorch/torch/nn/modules/module.py", line 2028, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for FullyShardedDataParallel:
	size mismatch for _fsdp_wrapped_module.1._fsdp_wrapped_module.running_mean: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([10]).
```

To address this issue and improve coverage against similar regressions, this PR fixes the aforementioned typo and adds a set of basic tests that validate `state_dict` saving and loading for modules with persistent buffers in various contexts.

Adding another model, along with buffer-specific logic, to adapt [`test_basic_save_and_load_state_dict`](https://github.com/pytorch/pytorch/blob/76b683b0087cf90bb201e9acabec05a85e683ab2/test/distributed/fsdp/test_fsdp_state_dict.py#L439) for this coverage would have increased that test's complexity to an undesirable degree. Instead of adding that complexity to the existing test, I've added a new test, ``test_buffers_save_and_load_state_dict``, that performs basic validation of ``state_dict`` saving and loading with mixed precision, ``state_dict_type``, and CPU offloading parameterization. Certainly let me know if you'd prefer that I extend the existing basic ``state_dict`` test with the persistent-buffers model instead; I'm happy to do so, I just thought this was cleaner.
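Editor's note: to make the failure mode concrete, here is a minimal, self-contained sketch of the bug pattern described above (not the actual PyTorch source); the buffer names and shapes are illustrative only. A loop over buffer names reuses a stale `buffer` variable left over from an earlier cast, so every entry in the resulting `state_dict` ends up holding the last buffer visited:

```python
import torch

# Hypothetical module buffers (names and shapes are illustrative only).
buffers = {
    "running_mean": torch.zeros(10),
    "running_var": torch.ones(10),
    "num_batches_tracked": torch.tensor(0),
}

# An earlier transformation (e.g. a mixed-precision cast) leaves `buffer`
# bound to whichever buffer happened to be visited last.
for buffer in buffers.values():
    buffer = buffer.to(torch.float32)

state_dict = {}

# Buggy pattern: every name is assigned the stale `buffer`, so all entries
# share the value (and shape) of the last buffer iterated above.
for name in buffers:
    state_dict[name] = buffer.clone()

# Corrected pattern: look up the buffer that actually corresponds to `name`.
for name, buf in buffers.items():
    state_dict[name] = buf.clone()

assert all(state_dict[n].shape == buffers[n].shape for n in buffers)
```

With the buggy loop, a size mismatch like the `running_mean` error shown above surfaces at `load_state_dict` time whenever the copied buffer's shape differs from the expected one.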
Also, doubling the number of tests with a ``use_orig_params`` parameterization, or by testing additional non-default buffer mixed-precision data types, seemed computationally imprudent, but let me know if you'd like me to add those tests as well. The only other notable test change is that I've refactored ``TestFSDPStateDict._compare_models`` to accommodate both ``buffers`` and ``parameters`` comparisons without code duplication.

Thanks again to the PyTorch Distributed team for your exceptional contributions. I've got more work to do adapting my package for 2.0's FSDP, but it's been a delight so far thanks to your superlative work!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93396
Approved by: https://github.com/rohan-varma, https://github.com/awgu, https://github.com/fegin
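Editor's note: for context, here is a minimal sketch of the kind of buffer round-trip check described above. It uses a hypothetical toy model (`BufferModel`) and helper (`check_buffer_round_trip`), and omits the FSDP wrapping, mixed-precision, and CPU-offload parameterization that the actual ``test_buffers_save_and_load_state_dict`` exercises:

```python
import torch
import torch.nn as nn


class BufferModel(nn.Module):
    """Toy model whose BatchNorm layer carries persistent buffers."""

    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(10, 10)
        self.bn = nn.BatchNorm1d(10)  # running_mean / running_var buffers

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(self.lin(x))


def check_buffer_round_trip(model: nn.Module) -> None:
    """Save a state_dict, reload it, and verify every buffer survives intact."""
    saved = {k: v.clone() for k, v in model.state_dict().items()}
    model.load_state_dict(saved)
    for name, buf in model.named_buffers():
        assert saved[name].shape == buf.shape, f"shape mismatch for {name}"
        assert torch.equal(saved[name], buf), f"value mismatch for {name}"


model = BufferModel()
model(torch.randn(4, 10))  # one forward pass so the running stats are non-trivial
check_buffer_round_trip(model)
```

The bug described in this commit would make exactly this kind of check fail once the model is wrapped in FSDP with multiple persistent buffers and mixed precision enabled.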