80790ece - [einsum] Call view instead of sum to remediate MPS regression (#87135)

Fixes #87010.

It turns out that `squeeze` is much faster than `sum`, and `view` is faster than `squeeze`, so we should default to `view` whenever possible.

Benchmarking results show that, on MPS, the following code now takes **29.89ms instead of the current 1466ms, almost a 50x speedup**:

```
q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
torch.einsum('b i d, b j d -> b i j', q, k).max().item()
```

And a regular einsum will now take **0.506ms instead of 2.76ms**:

```
q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float)
torch.einsum('b i d, b j d -> b i j', q, k)
```

Special thanks to @soulitzer for helping me experiment and figure out how to squash the remaining 5x regression caused by `squeeze` being slower than `view`!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87135
Approved by: https://github.com/soulitzer, https://github.com/malfet, https://github.com/albanD
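The equivalence the fix relies on can be sketched in a few lines. Summing over an axis of size 1 simply drops that axis, so a metadata-only reshape produces the same result without launching a reduction kernel. This sketch uses NumPy for illustration (an assumption for portability; the commit itself changes the internal `torch.einsum` path, where `view` plays the role of `reshape` below):

```python
import numpy as np

# An einsum over 'b i d, b j d -> b i j' may internally produce a
# broadcasted product with a size-1 axis that then gets reduced away.
# Reducing a size-1 axis is equivalent to just removing it.
x = np.random.rand(16, 64, 1, 40)

via_sum = x.sum(axis=2)               # reduction kernel over a size-1 axis
via_squeeze = x.squeeze(axis=2)       # drops the size-1 axis, no reduction
via_reshape = x.reshape(16, 64, 40)   # pure metadata change, no data movement

# All three produce identical results when the reduced axis has size 1.
assert np.array_equal(via_sum, via_squeeze)
assert np.array_equal(via_sum, via_reshape)
```

The ordering in the commit (`view` preferred over `squeeze`, `squeeze` preferred over `sum`) reflects that a reshape with known output sizes avoids even the shape-inspection work `squeeze` does, which is what closed the remaining 5x gap on MPS.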