pytorch
b2311192 - [NN module] speed up _load_from_state_dict (#85743)

Commit
2 years ago
[NN module] speed up _load_from_state_dict (#85743) Fixes #61398 The original implementation is very slow when the state_dict.keys() is long. This PR only passes relevant keys to the child module. existing test passes: `pytest test/test_nn.py -k state_dict` I couldn't figure out a good way to write a new test for this new behavior. Had a new snippet, but it will be flaky if integrated into the main CI because it's a timing based check. But I can verify that the test took 30s to run, after this PR it only takes 0.5s. ```python def test_load_state_dict_large(self): # construct a module with 4 levels of module, 10 linear each, leads to 10k items in the dictionary import copy import time base_module = nn.Linear(1,1) model = base_module for level in range(4): model = nn.Sequential(*[copy.deepcopy(model) for _ in range(10)]) state_dict = model.state_dict() self.assertEqual(len(state_dict), 20000) st = time.time() model.load_state_dict(state_dict, strict=True) strict_load_time = time.time() - st # it took 0.5 seconds to self.assertLess(strict_load_time, 10) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/85743 Approved by: https://github.com/albanD
Author
Committer
Parents
Loading