TensorIterator::binary_op input-output overlap check (#24058)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/8212
This fix is based on the idea that in-place ops(e.g. add_(...)) and out ops(e.g. tensor.add(..., out=...)) must check that the output tensor does not partially overlap with any of it's input tensors. Otherwise the result of such op is unexpected to the user. Since TensorIterator is a common backend for such ops and it's already used to check output self-overlapping, this fix is implemented in the same place.
MemOverlapStatus enum class is introduced to model two tensors overlapped state:
- TOO_HARD if at least one of them is not contiguous
- FULL if both are contiguous and share exactly the same memory array [data(), data() + numel() *itemsize()]
- PARTIAL is both are contiguous but underlying memory is shared partially, in other words memory arrays overlap but not identical.
- NO if both are contiguous but have independent non overlapping memory arrays
Performance test of clone/addcmul_/addcdiv_ with check_mem_overlaps:
a = torch.empty(10000000, device='cpu')
b = torch.randn(10000000, device='cpu')
timeit a.copy_(b)
master: 10.3 ms ± 429 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
branch: 10.2 ms ± 946 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
a = torch.empty(10000000, device='cuda')
b = torch.randn(10000000, device='cuda')
timeit a.copy_(b)
master: 373 µs ± 97.9 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
branch: 373 µs ± 120 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
a = torch.randn(1000000, device='cpu')
b = torch.randn(1000000, device='cpu')
c = torch.randn(1000000, device='cpu')
timeit a.addcmul_(b, c)
master: 2.02 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
branch: 2.11 ms ± 200 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
a = torch.randn(1000000, device='cuda')
b = torch.randn(1000000, device='cuda')
c = torch.randn(1000000, device='cuda')
timeit a.addcmul_(b, c)
master: 72.6 µs ± 627 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
branch: 72.4 µs ± 18.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
a = torch.randn(1000000, device='cpu')
b = torch.randn(1000000, device='cpu')
c = torch.randn(1000000, device='cpu')
timeit a.addcdiv_(b, c)
master: 2.19 ms ± 583 µs per loop (mean ± std. dev. of 7 runs, 1000 loop each)
branch: 1.97 ms ± 125 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
a = torch.randn(1000000, device='cuda')
b = torch.randn(1000000, device='cuda')
c = torch.randn(1000000, device='cuda')
timeit a.addcdiv_(b, c)
master: 71.3 µs ± 1.98 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
branch: 71.7 µs ± 3.96 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
a = torch.empty(100, device='cpu')
b = torch.randn(100, device='cpu')
timeit a.copy_(b)
master: 12.1 µs ± 1.11 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
branch: 11.1 µs ± 61.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
a = torch.empty(100, device='cuda')
b = torch.randn(100, device='cuda')
timeit a.copy_(b)
master: 20.9 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
branch: 22.8 µs ± 2.63 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
a = torch.randn(100, device='cpu')
b = torch.randn(100, device='cpu')
c = torch.randn(100, device='cpu')
timeit a.addcmul_(b, c)
master: 24.1 µs ± 2.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
branch: 24 µs ± 91.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
a = torch.randn(100, device='cuda')
b = torch.randn(100, device='cuda')
c = torch.randn(100, device='cuda')
timeit a.addcmul_(b, c)
master: 34.5 µs ± 4.82 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
branch: 29.8 µs ± 496 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
a = torch.randn(100, device='cpu')
b = torch.randn(100, device='cpu')
c = torch.randn(100, device='cpu')
timeit a.addcdiv_(b, c)
master: 21.3 µs ± 210 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
branch: 23.8 µs ± 403 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
a = torch.randn(100, device='cuda')
b = torch.randn(100, device='cuda')
c = torch.randn(100, device='cuda')
timeit a.addcdiv_(b, c)
master: 30.3 µs ± 257 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
branch: 31.8 µs ± 214 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24058
Differential Revision: D16767892
Pulled By: pbelevich
fbshipit-source-id: 0cdaaa471d003a2886b1736f8985842226b8493a