Provide option to pass module instance to _load_state_dict_pre_hooks. (#62070)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62070
We have a custom Tensor:
https://github.com/pytorch/pytorch/blob/master/torch/distributed/_sharded_tensor/api.py#L67,
which doesn't show up in the module's state_dict. This was resolved by
using _register_state_dict_hook:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1196
to parse and add custom tensors to state_dict.
However, at load time the hook registered via _register_load_state_dict_pre_hook:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L1272
does not receive the module instance. As a result, a ShardedTensor in the
state_dict cannot be appropriately added to a module at load time.
To resolve this issue, this PR enhances the hook to support two
variations: one that passes in the module instance (for the problem described
above) and one that keeps the previous signature for backward compatibility (BC).
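
To illustrate the two variations, here is a minimal pure-Python sketch. It
does not use torch and is not the code in this PR: TinyModule,
register_pre_hook, the with_module flag name, and the two-argument hook
signature are all hypothetical stand-ins chosen for illustration, and binding
the instance with functools.partial is just one possible way to do it.

```python
import functools
from collections import OrderedDict


class TinyModule:
    """Hypothetical stand-in for a module; models only pre-hook registration."""

    def __init__(self):
        self._pre_hooks = OrderedDict()
        self.loaded = {}

    def register_pre_hook(self, hook, with_module=False):
        # When with_module is True, bind this instance as the hook's first
        # argument; otherwise keep the previous (module-less) call form.
        if with_module:
            hook = functools.partial(hook, self)
        self._pre_hooks[len(self._pre_hooks)] = hook

    def load_state_dict(self, state_dict, prefix=""):
        # Run every pre-hook before the state_dict is consumed.
        for hook in self._pre_hooks.values():
            hook(state_dict, prefix)
        self.loaded = dict(state_dict)


def legacy_hook(state_dict, prefix):
    # Previous form: the hook can edit state_dict but cannot see the module.
    state_dict[prefix + "seen_by_legacy"] = True


def module_hook(module, state_dict, prefix):
    # New form: the hook can attach a custom entry directly to the module.
    key = prefix + "sharded"
    if key in state_dict:
        module.sharded = state_dict.pop(key)


m = TinyModule()
m.register_pre_hook(legacy_hook)
m.register_pre_hook(module_hook, with_module=True)
m.load_state_dict({"sharded": [1, 2, 3], "weight": 0.5})
```

In this sketch the existing hook keeps working unchanged, while the
module-aware hook moves the custom entry out of state_dict and onto the
instance before regular loading proceeds.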
ghstack-source-id: 134541391
Test Plan:
1) unit tests
2) waitforbuildbot
Reviewed By: jbschlosser
Differential Revision: D29867142
fbshipit-source-id: bcb136ff51eedd0b508cfb419e8b8a6b7d95539c