pytorch
a5c65b86 - Fixed einsum compatibility/performance issues (#46398)

Commit
4 years ago
Fixed einsum compatibility/performance issues (#46398) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46398 This PR makes torch.einsum compatible with numpy.einsum except for the sublist input option as requested here https://github.com/pytorch/pytorch/issues/21412. It also fixed 2 performance issues linked below and adds a check for reducing to torch.dot instead of torch.bmm which is faster in some cases. fixes #45854, #37628, #30194, #15671 fixes #41467 with benchmark below ```python import torch from torch.utils.benchmark import Timer a = torch.randn(10000, 100, 101, device='cuda') b = torch.randn(10000, 101, 3, device='cuda') c = torch.randn(10000, 100, 1, device='cuda') d = torch.randn(10000, 100, 1, 3, device='cuda') print(Timer( stmt='torch.einsum("bij,bjf->bif", a, b)', globals={'a': a, 'b': b} ).blocked_autorange()) print() print(Timer( stmt='torch.einsum("bic,bicf->bif", c, d)', globals={'c': c, 'd': d} ).blocked_autorange()) ``` ``` <torch.utils.benchmark.utils.common.Measurement object at 0x7fa37c413850> torch.einsum("bij,bjf->bif", a, b) Median: 4.53 ms IQR: 0.00 ms (4.53 to 4.53) 45 measurements, 1 runs per measurement, 1 thread <torch.utils.benchmark.utils.common.Measurement object at 0x7fa37c413700> torch.einsum("bic,bicf->bif", c, d) Median: 63.86 us IQR: 1.52 us (63.22 to 64.73) 4 measurements, 1000 runs per measurement, 1 thread ``` fixes #32591 with benchmark below ```python import torch from torch.utils.benchmark import Timer a = torch.rand(1, 1, 16, 2, 16, 2, 16, 2, 2, 2, 2, device="cuda") b = torch.rand(729, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, device="cuda") print(Timer( stmt='(a * b).sum(dim = (-3, -2, -1))', globals={'a': a, 'b': b} ).blocked_autorange()) print() print(Timer( stmt='torch.einsum("...ijk, ...ijk -> ...", a, b)', globals={'a': a, 'b': b} ).blocked_autorange()) ``` ``` <torch.utils.benchmark.utils.common.Measurement object at 0x7efe0de28850> (a * b).sum(dim = (-3, -2, -1)) Median: 17.86 ms 2 measurements, 10 runs per measurement, 1 thread <torch.utils.benchmark.utils.common.Measurement object at 0x7efe0de286a0> torch.einsum("...ijk, ...ijk -> ...", a, b) Median: 296.11 us IQR: 1.38 us (295.42 to 296.81) 662 measurements, 1 runs per measurement, 1 thread ``` TODO - [x] add support for ellipsis broadcasting - [x] fix corner case issues with sumproduct_pair - [x] update docs and add more comments - [x] add tests for error cases Test Plan: Imported from OSS Reviewed By: malfet Differential Revision: D24860367 Pulled By: heitorschueroff fbshipit-source-id: 31110ee598fd598a43acccf07929b67daee160f9
Parents
Loading