pytorch
cac4aa71 - Provide option to pass module instance to _load_state_dict_pre_hooks. (#62070)

Commit
3 years ago
Provide option to pass module instance to _load_state_dict_pre_hooks. (#62070) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62070 We have a custom Tensor: https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67, which doesn't show up in state_dict for the module. This was resolved by using the _register_state_dict_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196 to parse and add custom tensors to state_dict. However, the problem is during load time _register_load_state_dict_pre_hook: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272, does not pass in the module instance and as a result, a ShardedTensor in the state_dict cannot be appropriately added to a module at load time. To resolve this issue, in this PR I've enhanced this hook to support two variations, one which passes in the module instance (for the problem described above) and one is the previous version for BC reasons. ghstack-source-id: 134541391 Test Plan: 1) unit tests 2) waitforbuildbot Reviewed By: jbschlosser Differential Revision: D29867142 fbshipit-source-id: bcb136ff51eedd0b508cfb419e8b8a6b7d95539c
Author
Parents
Loading