60174888 - [MPS] LSTM fixes (#95388)

* [MPS] Fix LSTM backward and forward pass (#95137)

Fixes #91694
Fixes #92615

Several transpositions were missing from the backward graph in the case of `batch_first=True`. #91694 does not reproduce with `batch_first=False`.

After fixing the transpose issue, I thought I could finally use LSTM freely in my project. Then I got horrific results during training, which seems related to #92615. After that, I decided to fix LSTM's backward step completely. I collected all my findings in this thread; it seems I succeeded.

Funnily enough, the backward tests had been completely disabled and were not passing:

```python
@unittest.skipIf(True, "Backward of lstm returns wrong result")
def test_lstm_2(self, device="mps", dtype=torch.float32):
```

UPD: the forward pass of the multi-layer version was also wrong, due to incorrect `initState, initCell` slices. Tests were passing only because the states were initialized with zeros. *Accidentally* fixed this too.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95137
Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/soulitzer

* Update the allowlist for lstm_mps_backward
* More updates to the BC allowlist

---------

Co-authored-by: alexdremov <dremov.me@gmail.com>
Co-authored-by: albanD <desmaison.alban@gmail.com>
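For context, the two failure modes above (wrong gradients with `batch_first=True`, and a multi-layer forward pass that only looked correct because `h0`/`c0` were zeros) can be exercised with a CPU-vs-MPS parity check. The sketch below is not part of the commit; the helper name `check_lstm_parity` and the tolerances are illustrative, and it assumes a PyTorch build with MPS support.

```python
import copy

import torch
import torch.nn as nn


def check_lstm_parity(num_layers: int = 2, batch_first: bool = True) -> None:
    torch.manual_seed(0)
    cpu_lstm = nn.LSTM(input_size=8, hidden_size=16,
                       num_layers=num_layers, batch_first=batch_first)
    # Same weights on both devices.
    mps_lstm = copy.deepcopy(cpu_lstm).to("mps")

    # batch_first=True: input shape is (batch, seq, feature).
    x = torch.randn(4, 5, 8, requires_grad=True)
    # Non-zero initial states: the multi-layer forward bug was hidden
    # when h0/c0 were zeros, so use random states here.
    h0 = torch.randn(num_layers, 4, 16)
    c0 = torch.randn(num_layers, 4, 16)

    x_mps = x.detach().clone().to("mps").requires_grad_(True)

    out_cpu, _ = cpu_lstm(x, (h0, c0))
    out_mps, _ = mps_lstm(x_mps, (h0.to("mps"), c0.to("mps")))

    # Forward parity (catches the initState/initCell slicing bug).
    torch.testing.assert_close(out_mps.cpu(), out_cpu, rtol=1e-4, atol=1e-4)

    # Backward parity (the missing transpositions showed up in gradients).
    out_cpu.sum().backward()
    out_mps.sum().backward()
    torch.testing.assert_close(x_mps.grad.cpu(), x.grad, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    if torch.backends.mps.is_available():
        check_lstm_parity()
        print("MPS LSTM matches CPU reference")
```

Before this fix, a check like this would pass with `batch_first=False` or zero initial states but fail on the gradient comparison with `batch_first=True`, which is consistent with the disabled `test_lstm_2` shown above.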