pytorch
b9e95158 - [MPS] Fix LSTM backward and forward pass (#95137)

Commit
1 year ago
[MPS] Fix LSTM backward and forward pass (#95137) Fixes #91694 Fixes #92615 Several transpositions were missing for backward graph in case of `batch_first=True`. The #91694 is not reproduced with `batch_first=False`. After fixing transpose issue, I finally thought that now I can use LSTM freely in my project. And then I got horrific results on train. Seems related to #92615. After that I decided to fix LSTM's backward step completely. I collected all my findings in this thread — seems like I succeeded Funny enough, backward tests were completely disabled before and were not passing: ```python @unittest.skipIf(True, "Backward of lstm returns wrong result") def test_lstm_2(self, device="mps", dtype=torch.float32): ``` UPD: forward pass of multi-layer version also was wrong due to the incorrect `initState, initCell` slices. Tests were passing because states were inited with zeros. *Accidentally* fixed this too Pull Request resolved: https://github.com/pytorch/pytorch/pull/95137 Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/soulitzer
Author
Committer
Parents
Loading