pytorch
58703461 - Port index_copy from TH to ATen (#52203)

Port index_copy from TH to ATen (#52203)

Summary: The `TensorIterator`-based design is similar to that in https://github.com/pytorch/pytorch/pull/50578

Resolves https://github.com/pytorch/pytorch/issues/24670
Resolves https://github.com/pytorch/pytorch/issues/24523

Timings:

<details>
<summary>Script</summary>

```python
from IPython import get_ipython
import torch

torch.manual_seed(13)
torch.set_num_threads(1)

ipython = get_ipython()

cpu = torch.device('cpu')
cuda = torch.device('cuda')

def run_test(ndims, size, index_len, device):
    print(f"ndims: {ndims}, tensor_size: {size}, index_len: {index_len}, device: {device}")
    x = torch.rand(*([size] * ndims), device=device)
    index = torch.randint(size, (index_len,), dtype=torch.long, device=device)
    for d in range(ndims):
        shape_t = [size] * d + [index_len] + [size] * (ndims - d - 1)
        t = torch.rand(*shape_t, device=device)
        command = "x.index_copy(d, index, t)"
        if device == cuda:
            command = command + "; torch.cuda.synchronize()"
        ipython.magic(f"timeit {command}")
    print()

run_test(3, 700, 10, cpu)
run_test(3, 700, 100, cpu)
run_test(3, 700, 700, cpu)
run_test(2, 10000, 10000, cpu)
run_test(3, 700, 10, cuda)
run_test(3, 700, 100, cuda)
run_test(3, 700, 700, cuda)
run_test(2, 10000, 10000, cuda)
```
</details>

<details>
<summary>CPU ATen</summary>

```
ndims: 3, tensor_size: 700, index_len: 10, device: cpu
327 ms ± 309 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
329 ms ± 456 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
378 ms ± 1.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

ndims: 3, tensor_size: 700, index_len: 100, device: cpu
348 ms ± 1.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
359 ms ± 330 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
526 ms ± 686 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

ndims: 3, tensor_size: 700, index_len: 700, device: cpu
560 ms ± 19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
552 ms ± 2.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
932 ms ± 2.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

ndims: 2, tensor_size: 10000, index_len: 10000, device: cpu
163 ms ± 5.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
302 ms ± 5.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
</details>

<details>
<summary>CUDA ATen</summary>

```
ndims: 3, tensor_size: 700, index_len: 10, device: cuda
9.63 ms ± 441 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.65 ms ± 230 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.4 ms ± 881 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)

ndims: 3, tensor_size: 700, index_len: 100, device: cuda
10.8 ms ± 1.51 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
11 ms ± 417 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
21.2 ms ± 18.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

ndims: 3, tensor_size: 700, index_len: 700, device: cuda
19 ms ± 4.42 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
17.8 ms ± 493 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
25.8 ms ± 1.22 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

ndims: 2, tensor_size: 10000, index_len: 10000, device: cuda
5.59 ms ± 109 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
10 ms ± 25.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
</details>

<details>
<summary>CPU TH</summary>

```
ndims: 3, tensor_size: 700, index_len: 10, device: cpu
333 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
327 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
366 ms ± 753 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

ndims: 3, tensor_size: 700, index_len: 100, device: cpu
336 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
345 ms ± 914 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
884 ms ± 4.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

ndims: 3, tensor_size: 700, index_len: 700, device: cpu
441 ms ± 3.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
514 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
7.46 s ± 6.46 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

ndims: 2, tensor_size: 10000, index_len: 10000, device: cpu
141 ms ± 233 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.13 s ± 855 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
</details>

<details>
<summary>CUDA TH</summary>

```
ndims: 3, tensor_size: 700, index_len: 10, device: cuda
9.64 ms ± 390 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.68 ms ± 3.26 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
13.9 ms ± 928 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)

ndims: 3, tensor_size: 700, index_len: 100, device: cuda
11.6 ms ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.1 ms ± 3.72 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
30.3 ms ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

ndims: 3, tensor_size: 700, index_len: 700, device: cuda
27.2 ms ± 19.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
30.6 ms ± 43.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
146 ms ± 204 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

ndims: 2, tensor_size: 10000, index_len: 10000, device: cuda
6.5 ms ± 3.99 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
64.7 ms ± 55.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
</details>

According to these timings, the ATen port performs comparably to TH in most configurations and substantially better in the slowest ones, e.g. 932 ms vs. 7.46 s on CPU and 25.8 ms vs. 146 ms on CUDA for index_len 700 along the last dimension.

cc: nikitaved

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52203

Reviewed By: jbschlosser

Differential Revision: D27066572

Pulled By: mruberry

fbshipit-source-id: 6101e461cf731afa3db042a383b723d3d6bfdc26
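For context on the operation being ported: `x.index_copy(dim, index, t)` returns a copy of `x` in which, along dimension `dim`, the slice at position `index[i]` is replaced by slice `i` of `t`. A minimal NumPy sketch of those semantics (an illustration only, not the ATen kernel or the benchmarked code):

```python
import numpy as np

def index_copy(x, dim, index, t):
    """Sketch of torch.Tensor.index_copy semantics: copy slice i of t
    into the slice of x at position index[i] along dimension dim."""
    out = x.copy()
    for i, idx in enumerate(index):
        src = [slice(None)] * t.ndim
        dst = [slice(None)] * out.ndim
        src[dim] = i    # take slice i of the source tensor
        dst[dim] = idx  # write it at position index[i] of the output
        out[tuple(dst)] = t[tuple(src)]
    return out

x = np.zeros((3, 4))
t = np.arange(8, dtype=float).reshape(2, 4)
index = [2, 0]
y = index_copy(x, 0, index, t)
# y[2] is t[0], y[0] is t[1], and y[1] remains zero.
```

Note that, as in the benchmark script above, the out-of-place `index_copy` leaves the original tensor untouched; PyTorch also exposes an in-place variant, `index_copy_`.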