pytorch
0787d781 - Fix compatibility problem with LSTMs and torch.save (#57558)

Commit
3 years ago
Fix compatibility problem with LSTMs and torch.save (#57558) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57558 Fixes #53359 If someone directly saves an nn.LSTM in PyTorch 1.7 and then loads it in PyTorch 1.8, it errors out with the following: ``` (In PyTorch 1.7) import torch model = torch.nn.LSTM(2, 3) torch.save(model, 'lstm17.pt') (In PyTorch 1.8) model = torch.load('lstm17.pt') AttributeError: 'LSTM' object has no attribute 'proj_size' ``` Although we do not officially support this (directly saving modules via torch.save), it used to work and the fix is very simple. This PR adds an extra line to `__setstate__`: if the state we are passed does not have a `proj_size` attribute, we assume it was saved from PyTorch 1.7 and older and set `proj_size` equal to 0. Test Plan: I wrote a test that tests `__setstate__`. But also, Run the following: ``` (In PyTorch 1.7) import torch x = torch.ones(32, 5, 2) model = torch.nn.LSTM(2, 3) torch.save(model, 'lstm17.pt') y17 = model(x) (Using this PR) model = torch.load('lstm17.pt') x = torch.ones(32, 5, 2) y18 = model(x) ``` and finally compare y17 and y18. Reviewed By: mrshenli Differential Revision: D28198477 Pulled By: zou3519 fbshipit-source-id: e107d1ebdda23a195a1c3574de32a444eeb16191
Author
Parents
Loading