pytorch
b9f099ed - Make TensorIterator stop promoting types by copying (#28427)

Commit
5 years ago
Make TensorIterator stop promoting types by copying (#28427) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28427 Fixes: https://github.com/pytorch/pytorch/issues/26401 This PR fixes the issue by using the newly added dynamic cast inside `TensorIterator` so that instead of converting the type at the beginning (which generates extra kernel launches), the `TensorIterator` do a load-cast-compute-store for each element while looping. So there is only one read and one write of memory. **nvprof:** ```python import torch _100M = 100 * 1024 ** 2 r = torch.randn(_100M, dtype=torch.float32, device='cuda') d = torch.randn(_100M, dtype=torch.float64, device='cuda') torch.cuda.synchronize() torch.cuda.profiler.start() r.add_(d) torch.cuda.profiler.stop() torch.cuda.synchronize() ``` ``` ==11407== NVPROF is profiling process 11407, command: /home/xgao/anaconda3/bin/python simple.py ==11407== Profiling application: /home/xgao/anaconda3/bin/python simple.py ==11407== Profiling result: Type Time(%) Time Calls Avg Min Max Name GPU activities: 100.00% 2.0611ms 1 2.0611ms 2.0611ms 2.0611ms _ZN2at6native18elementwise_kernelILi512ELi1EZNS0_15gpu_kernel_implIZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE1_clEvEUlddE_EEvS4_RKT_EUliE_EEviT1_ API calls: 100.00% 1.05006s 1 1.05006s 1.05006s 1.05006s cudaLaunchKernel 0.00% 2.7740us 2 1.3870us 673ns 2.1010us cudaGetDevice 0.00% 2.3730us 1 2.3730us 2.3730us 2.3730us cudaSetDevice 0.00% 830ns 1 830ns 830ns 830ns cudaGetLastError ``` **benchmark** ```python import torch print(torch.__version__) print(torch.version.git_version) _100M = 100 * 1024 ** 2 r = torch.randn(_100M, dtype=torch.float32, device='cuda') d = torch.randn(_100M, dtype=torch.float64, device='cuda') torch.cuda.synchronize() %timeit r.add_(d); torch.cuda.synchronize() ``` original ``` 1.4.0a0+7d277b0 7d277b0670eb1f9098a7e098e93b20453e8b5c9f 6.83 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 100 loops each) ``` after ``` 1.4.0a0+f0f2f65 f0f2f654cba9b8c569f0bcd583732bbc891f80b2 2.08 ms ± 139 ns per loop (mean ± std. dev. of 7 runs, 100 loops each) ``` For more benchmark, see: https://github.com/pytorch/pytorch/pull/28344 Test Plan: Imported from OSS Differential Revision: D18170997 Pulled By: ezyang fbshipit-source-id: 9c82c1c89583f3e6202c5d790b9b73ad9f960fad
Author
Parents
Loading