Use int64_t for indexing in `multi_tensor_apply` (#101760)
Fixes #101449
In the future, it would be better to either imitate the combination of `TensorIterator::can_use_32bit_indexing` and `TensorIterator::with_32bit_indexing`, or choose the index type based on `Tensor::numel`.
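As a rough illustration of the second option, the sketch below picks an index type from the element counts of a tensor list. It is plain Python rather than the actual CUDA code, and `fits_in_int32`, `pick_index_type`, and `wrap_int32` are hypothetical helpers made up for this example, not PyTorch APIs.

```python
INT32_MAX = 2**31 - 1


def fits_in_int32(numels):
    """Hypothetical check: every tensor's element count fits in a signed
    32-bit integer (a conservative stand-in for a numel-based test)."""
    return all(n <= INT32_MAX for n in numels)


def pick_index_type(numels):
    """Hypothetical dispatch: prefer the cheaper 32-bit index when it is safe."""
    return "int32_t" if fits_in_int32(numels) else "int64_t"


def wrap_int32(value):
    """Emulate two's-complement wraparound of a signed 32-bit integer."""
    return (value + 2**31) % 2**32 - 2**31


small = [32 * 32] * 100   # same shapes as the benchmark script
large = [2**31 + 8]       # one tensor with more than INT32_MAX elements

print(pick_index_type(small))   # int32_t
print(pick_index_type(large))   # int64_t
print(wrap_int32(2**31 + 4))    # -2147483644: an offset that no longer fits
```

The last line shows why a 32-bit offset is unsafe for the large case: the value wraps to a negative number instead of addressing the intended element.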
---
I used `nsys nvprof` to get a rough measurement of the effect of `int64_t` indexing:
```python
import torch

# Two parameter groups of 100 small tensors each.
params = [
    {"params": [torch.randn(32, 32, device="cuda") for _ in range(100)]},
    {"params": [torch.randn(32, 32, device="cuda") for _ in range(100)]},
]
# Matching gradients, one list per parameter group.
grads = [
    [torch.randn(32, 32, device="cuda") for _ in range(100)],
    [torch.randn(32, 32, device="cuda") for _ in range(100)],
]
optimizer = torch.optim.Adam(params, fused=True)

for _ in range(100):
    # Attach the precomputed gradients, then take one fused optimizer step.
    for i, param_groups in enumerate(params):
        for p, g in zip(param_groups["params"], grads[i]):
            p.grad = g
    optimizer.step()
    optimizer.zero_grad()
```
Environment
```
Collecting environment information...
PyTorch version: 2.1.0a0+gitf994d0b
Is debug build: False
CUDA used to build PyTorch: 12.1
Python version: 3.10.9 (main, May 17 2023, 00:46:40) [GCC 11.3.0] (64-bit runtime)
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
```
---
Kernel time relative to current main (this PR / main, so values above 1 mean slower):
- `multi_tensor_apply_kernel<at::native::<unnamed>::FusedOptimizerTensor` -> 1.02x
- `multi_tensor_apply_kernel<at::native::<unnamed>::TensorListMetadata<(in…` -> 1.04x
Current main branch:
```
Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
--------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
    64.9          5787610        600    9646.0    9632.0      9503      9888         52.9  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::FusedOptimizerTensorLi…
...
     8.1           720575        200    3602.9    3584.0      3551      4320         63.4  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::TensorListMetadata<(in…
```
This PR:
```
Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
--------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
    65.0          5876847        600    9794.7    9792.0      9632     10080         58.1  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::FusedOptimizerTensorLi…
...
     8.3           748313        200    3741.6    3744.0      3711      4479         60.0  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::TensorListMetadata<(in…
```
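For reference, the ratios quoted above can be reproduced from the "Total Time (ns)" column of the two tables. This is only a small arithmetic sketch; the dictionary keys reuse the truncated kernel names as they appear in the `nsys` output.

```python
# Total kernel time in ns, copied from the nsys tables: (current main, this PR).
totals_ns = {
    "FusedOptimizerTensorLi…": (5787610, 5876847),
    "TensorListMetadata<(in…": (720575, 748313),
}


def slowdown(main_ns, pr_ns):
    """Ratio of this PR's total kernel time to current main's."""
    return pr_ns / main_ns


ratios = {name: slowdown(*pair) for name, pair in totals_ns.items()}
for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f}x")
```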
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101760
Approved by: https://github.com/ngimel