pytorch
d1cecd9c - Add assign kwarg to module.load_state_dict (#102212)

Add assign kwarg to module.load_state_dict (#102212)

Fixes #64601 and #98906

Adds an `assign` argument to `load_state_dict` that loads params/buffers by assignment instead of doing `param.copy_(param_from_state_dict)`. Primarily intended to remove the need for the `.to_empty()` in

```python
with torch.device('meta'):
    m = SomeModule()
m.to_empty()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict)
```

so we can instead do

```python
with torch.device('meta'):
    m = SomeModule()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict, assign=True)
```

**A remaining problem with this PR when the model is initialized on meta: what happens to nonpersistent buffers and params whose keys are missing from the state dict?** If `load_state_dict(state_dict, strict=False, assign=True)` is called and the `state_dict` is missing some keys, the corresponding params and nonpersistent buffers would still be on `meta` and would need to be manually initialized. However, we don't currently offer an API that would initialize them. One option would be to materialize them as empty tensors, but that might not be semantically correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102212
Approved by: https://github.com/albanD
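To illustrate why assignment matters for meta-initialized modules, here is a minimal conceptual sketch in plain Python (no torch). The `Param`, `TinyModule`, and their methods are hypothetical stand-ins, not the real `torch.nn.Module` machinery: a "meta" param has no storage, so the default in-place `copy_` path fails, while the assignment path simply rebinds the attribute to the loaded tensor.

```python
# Conceptual sketch (plain Python, no torch) of copy-based vs.
# assignment-based state-dict loading. All names are illustrative.

class Param:
    def __init__(self, data, device):
        self.data = data      # list standing in for tensor storage
        self.device = device  # 'meta' params have no real storage

    def copy_(self, other):
        # In-place copy cannot work without storage, as on 'meta'.
        if self.device == "meta":
            raise RuntimeError("cannot copy into a meta tensor")
        self.data[:] = other.data

class TinyModule:
    def __init__(self, device="cpu"):
        self._params = {"weight": Param([0.0, 0.0], device)}

    def load_state_dict(self, state_dict, assign=False):
        for name, incoming in state_dict.items():
            if assign:
                # assign=True: rebind to the loaded param, so a
                # meta-initialized module picks up real storage.
                self._params[name] = incoming
            else:
                # default: in-place copy into preexisting storage
                self._params[name].copy_(incoming)

loaded = {"weight": Param([1.5, 2.5], "cpu")}

m = TinyModule(device="meta")
copy_failed = False
try:
    m.load_state_dict(loaded)            # copy_ path fails on meta
except RuntimeError:
    copy_failed = True

m.load_state_dict(loaded, assign=True)   # assignment path succeeds
print(copy_failed, m._params["weight"].data)
```

This also makes the open question in the message concrete: with `assign=True`, any param whose key is absent from `state_dict` is never rebound and stays on `meta` with no storage.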