Fixed einsum compatibility/performance issues (#46398)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46398
This PR makes torch.einsum compatible with numpy.einsum except for the sublist input option as requested here https://github.com/pytorch/pytorch/issues/21412. It also fixed 2 performance issues linked below and adds a check for reducing to torch.dot instead of torch.bmm which is faster in some cases.
fixes #45854, #37628, #30194, #15671
fixes #41467 with benchmark below
```python
import torch
from torch.utils.benchmark import Timer
a = torch.randn(10000, 100, 101, device='cuda')
b = torch.randn(10000, 101, 3, device='cuda')
c = torch.randn(10000, 100, 1, device='cuda')
d = torch.randn(10000, 100, 1, 3, device='cuda')
print(Timer(
stmt='torch.einsum("bij,bjf->bif", a, b)',
globals={'a': a, 'b': b}
).blocked_autorange())
print()
print(Timer(
stmt='torch.einsum("bic,bicf->bif", c, d)',
globals={'c': c, 'd': d}
).blocked_autorange())
```
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa37c413850>
torch.einsum("bij,bjf->bif", a, b)
Median: 4.53 ms
IQR: 0.00 ms (4.53 to 4.53)
45 measurements, 1 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa37c413700>
torch.einsum("bic,bicf->bif", c, d)
Median: 63.86 us
IQR: 1.52 us (63.22 to 64.73)
4 measurements, 1000 runs per measurement, 1 thread
```
fixes #32591 with benchmark below
```python
import torch
from torch.utils.benchmark import Timer
a = torch.rand(1, 1, 16, 2, 16, 2, 16, 2, 2, 2, 2, device="cuda")
b = torch.rand(729, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, device="cuda")
print(Timer(
stmt='(a * b).sum(dim = (-3, -2, -1))',
globals={'a': a, 'b': b}
).blocked_autorange())
print()
print(Timer(
stmt='torch.einsum("...ijk, ...ijk -> ...", a, b)',
globals={'a': a, 'b': b}
).blocked_autorange())
```
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7efe0de28850>
(a * b).sum(dim = (-3, -2, -1))
Median: 17.86 ms
2 measurements, 10 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7efe0de286a0>
torch.einsum("...ijk, ...ijk -> ...", a, b)
Median: 296.11 us
IQR: 1.38 us (295.42 to 296.81)
662 measurements, 1 runs per measurement, 1 thread
```
TODO
- [x] add support for ellipsis broadcasting
- [x] fix corner case issues with sumproduct_pair
- [x] update docs and add more comments
- [x] add tests for error cases
Test Plan: Imported from OSS
Reviewed By: malfet
Differential Revision: D24860367
Pulled By: heitorschueroff
fbshipit-source-id: 31110ee598fd598a43acccf07929b67daee160f9