Make TensorIterator stop promoting types by copying (#28427)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28427
Fixes: https://github.com/pytorch/pytorch/issues/26401
This PR fixes the issue by using the newly added dynamic cast inside
`TensorIterator`: instead of converting the operands to a common dtype up
front (which launches extra copy kernels), `TensorIterator` now performs a
load-cast-compute-store sequence for each element while looping, so each
element is read from and written to memory only once.
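The per-element behavior can be illustrated from Python; a minimal sketch (tensor size and values are arbitrary, chosen for illustration only):

```python
import torch

# Mixed-dtype in-place add: with dynamic casting, the float64 operand is
# cast inside the kernel loop instead of being materialized as a separate
# float32 copy beforehand. The observable result is the same either way.
r = torch.randn(8, dtype=torch.float32)
d = torch.randn(8, dtype=torch.float64)

# Reference result computed explicitly, element by element:
# load, cast to the compute dtype, add, then cast back on store.
expected = torch.empty_like(r)
for i in range(8):
    expected[i] = float(r[i]) + float(d[i])  # stored back as float32

r.add_(d)
assert torch.allclose(r, expected)
```

This only demonstrates the semantics; the kernel-count difference itself is visible in the nvprof traces below.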
**nvprof:**
```python
import torch
_100M = 100 * 1024 ** 2
r = torch.randn(_100M, dtype=torch.float32, device='cuda')
d = torch.randn(_100M, dtype=torch.float64, device='cuda')
torch.cuda.synchronize()
torch.cuda.profiler.start()
r.add_(d)
torch.cuda.profiler.stop()
torch.cuda.synchronize()
```
```
==11407== NVPROF is profiling process 11407, command: /home/xgao/anaconda3/bin/python simple.py
==11407== Profiling application: /home/xgao/anaconda3/bin/python simple.py
==11407== Profiling result:
            Type  Time(%)      Time  Calls       Avg       Min       Max  Name
 GPU activities:  100.00%  2.0611ms      1  2.0611ms  2.0611ms  2.0611ms  _ZN2at6native18elementwise_kernelILi512ELi1EZNS0_15gpu_kernel_implIZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE1_clEvEUlddE_EEvS4_RKT_EUliE_EEviT1_
      API calls:  100.00%  1.05006s      1  1.05006s  1.05006s  1.05006s  cudaLaunchKernel
                    0.00%  2.7740us      2  1.3870us     673ns  2.1010us  cudaGetDevice
                    0.00%  2.3730us      1  2.3730us  2.3730us  2.3730us  cudaSetDevice
                    0.00%     830ns      1     830ns     830ns     830ns  cudaGetLastError
```
**benchmark:**
```python
import torch
print(torch.__version__)
print(torch.version.git_version)
_100M = 100 * 1024 ** 2
r = torch.randn(_100M, dtype=torch.float32, device='cuda')
d = torch.randn(_100M, dtype=torch.float64, device='cuda')
torch.cuda.synchronize()
%timeit r.add_(d); torch.cuda.synchronize()
```
original
```
1.4.0a0+7d277b0
7d277b0670eb1f9098a7e098e93b20453e8b5c9f
6.83 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
after
```
1.4.0a0+f0f2f65
f0f2f654cba9b8c569f0bcd583732bbc891f80b2
2.08 ms ± 139 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
For more benchmark, see: https://github.com/pytorch/pytorch/pull/28344
Test Plan: Imported from OSS
Differential Revision: D18170997
Pulled By: ezyang
fbshipit-source-id: 9c82c1c89583f3e6202c5d790b9b73ad9f960fad