pytorch
401109a2 - Use int64_t for indexing in `multi_tensor_apply` (#101760)

1 year ago
Use int64_t for indexing in `multi_tensor_apply` (#101760)

Fixes #101449

I found it better to either imitate the combo of `TensorIterator::can_use_32bit_indexing` and `TensorIterator::with_32bit_indexing` or adroitly choose the index type depending on `Tensor::numel` in the future.

---

Used `nsys nvprof` to casually see the effect of `int64_t` indexing:

```python
import torch

params = [
    {"params": [torch.randn(32, 32, device="cuda") for _ in range(100)]},
    {"params": [torch.randn(32, 32, device="cuda") for _ in range(100)]},
]
grads = [
    [torch.randn(32, 32, device="cuda") for _ in range(100)],
    [torch.randn(32, 32, device="cuda") for _ in range(100)],
]
optimizer = torch.optim.Adam(params, fused=True)

for _ in range(100):
    for i, param_groups in enumerate(params):
        for p, g in zip(param_groups["params"], grads[i]):
            p.grad = g
    optimizer.step()
    optimizer.zero_grad()
```

Environment

```
Collecting environment information...
PyTorch version: 2.1.0a0+gitf994d0b
Is debug build: False
CUDA used to build PyTorch: 12.1
Python version: 3.10.9 (main, May 17 2023, 00:46:40) [GCC 11.3.0] (64-bit runtime)
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
```

---

- `multi_tensor_apply_kernel<at::native::<unnamed>::FusedOptimizerTensor` -> 1.02x
- `multi_tensor_apply_kernel<at::native::<unnamed>::TensorListMetadata<(in…` -> 1.04x

Current main branch:

```
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     64.9          5787610        600    9646.0    9632.0      9503      9888         52.9  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::FusedOptimizerTensorLi…
...
      8.1           720575        200    3602.9    3584.0      3551      4320         63.4  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::TensorListMetadata<(in…
```

this PR:

```
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     65.0          5876847        600    9794.7    9792.0      9632     10080         58.1  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::FusedOptimizerTensorLi…
...
      8.3           748313        200    3741.6    3744.0      3711      4479         60.0  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::TensorListMetadata<(in…
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101760
Approved by: https://github.com/ngimel
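
---

The 1.02x and 1.04x figures quoted above appear to be the ratio of average kernel time with this PR to average kernel time on main, so values above 1 mean the `int64_t` version is slightly slower. The sketch below recomputes them from the `Avg (ns)` column of the two tables. The dictionary keys and the `slowdown` helper are names made up for this illustration and are not part of the PR.

```python
# Average kernel times in nanoseconds, copied from the "Avg (ns)" column
# of the two nsys tables (main branch vs. this PR).
# The keys are shorthand labels chosen for this sketch, not real kernel names.
avg_ns = {
    "fused_optimizer": {"main": 9646.0, "pr": 9794.7},
    "tensor_list_metadata": {"main": 3602.9, "pr": 3741.6},
}


def slowdown(main_ns: float, pr_ns: float) -> float:
    """Return the PR's average time relative to main (>1 means slower)."""
    return pr_ns / main_ns


# Round to two decimals, matching the precision used in the commit message.
ratios = {
    name: round(slowdown(times["main"], times["pr"]), 2)
    for name, times in avg_ns.items()
}

for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f}x")
```

Under that reading, the rounded ratios match the quoted figures: about 1.5% more time per launch for the fused optimizer kernel and about 3.8% for the other `multi_tensor_apply_kernel` instantiation, on this one small workload of 32x32 tensors.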