Allow vectorized GPU loop to have different argument types (#33222)
Summary:
Although `where` is currently the only user of GPU loops whose arguments have different dtypes, it is an unnatural restriction to require all arguments to have the same dtype. Allowing arguments with different dtypes also makes it possible for me to clean up legacy code, by reusing the current code to implement the unrolled GPU loop for non-contiguous tensors.
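For concreteness, here is a minimal sketch (hypothetical, not the actual `gpu_kernel` functor) of the kind of op `where` needs, where the condition argument has a different dtype from the value arguments:

```cpp
// Hypothetical mixed-dtype op in the style of `where`: the condition is
// bool while the values are float, a signature that a homogeneous
// `arg_t args[arity]` array cannot represent.
struct WhereOp {
  float operator()(bool cond, float a, float b) const {
    return cond ? a : b;
  }
};
```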
The stack storage of `elementwise_kernel_helper` is changed from `arg_t args[nt][arity]` to `traits::ArgsTuple args[nt]`. Due to this change, we can no longer access elements with `operator[]`; we must use `std::get` instead. As a result, we can no longer unroll the loop over arity with `#pragma unroll`; we have to create a `static_unroll` helper that uses template metaprogramming to do the same job.
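A minimal sketch of the idea (simplified from the real helper; `load_arg` is a hypothetical stand-in for the actual per-index policies): since `std::get<i>` requires `i` to be a compile-time constant, which a runtime loop index cannot provide, the unrolling is done by recursive template instantiation.

```cpp
#include <tuple>

// Compile-time loop: applies func<0>, func<1>, ..., func<end - 1> in order.
// A runtime loop index cannot be fed to std::get, so the unrolling is done
// by recursive template instantiation instead of `#pragma unroll`.
template <template <int> class func, int end, int current = 0>
struct static_unroll {
  template <typename... Args>
  static inline void with_args(Args&... args) {
    func<current>::apply(args...);
    static_unroll<func, end, current + 1>::with_args(args...);
  }
};

template <template <int> class func, int end>
struct static_unroll<func, end, end> {  // base case: past the last index
  template <typename... Args>
  static inline void with_args(Args&...) {}
};

// Hypothetical per-index functor: load the i-th operand, with its own type,
// into the i-th slot of the args tuple (data[0] is assumed to be the output).
template <int i>
struct load_arg {
  template <typename ArgsTuple>
  static void apply(ArgsTuple& args, char** data, int idx) {
    using arg_t = std::tuple_element_t<i, ArgsTuple>;
    std::get<i>(args) = reinterpret_cast<arg_t*>(data[i + 1])[idx];
  }
};

// Usage: fill a (bool, float, float) tuple for one element.
//   std::tuple<bool, float, float> args;
//   static_unroll<load_arg, 3>::with_args(args, data, idx);
```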
A good side effect of this change is that `invoke_with_array` is no longer needed and can be replaced with the already existing `c10::guts::apply`. We don't need the `namespace arg_type` workaround either. This makes the code cleaner.
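The effect is the same as in this minimal sketch built on `std::apply` (of which `c10::guts::apply` is a C++14 backport); `where_like` and the values here are hypothetical:

```cpp
#include <tuple>

// With the per-thread arguments held in a tuple, the elementwise op is
// invoked directly on the tuple; no invoke_with_array (which required a
// homogeneous argument array) is involved.
float where_like(bool cond, float a, float b) { return cond ? a : b; }

int main() {
  std::tuple<bool, float, float> args{true, 1.0f, 2.0f};
  float out = std::apply(where_like, args);  // c10::guts::apply in the PR
  return out == 1.0f ? 0 : 1;
}
```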
The same approach might also work for ROCm loops, but I didn't change anything on ROCm in this PR, because I don't want a potential compilation error or perf regression to delay it. After this gets merged, I will try the approach on ROCm and, if it applies trivially (i.e., a mindless copy-paste introduces no unexpected compilation errors or perf regressions), send a separate PR to make the two code paths less divergent.
Assembly (https://github.com/zasdfgbnm/things/blob/master/2020Q1/disassembly-elementwise-vec.ipynb#33222):
**Symbol:**
```
void at::native::modern::elementwise_kernel<4, 64, 4, at::native::add_kernel_cuda(at::TensorIterator&, c10::Scalar)::{lambda()#1}::operator()() const::{lambda()#4}::operator()() const::{lambda(float, float)#1}, at::detail::Array<char*, 3> >(int, at::native::add_kernel_cuda(at::TensorIterator&, c10::Scalar)::{lambda()#1}::operator()() const::{lambda()#4}::operator()() const::{lambda(float, float)#1}, at::detail::Array<char*, 3>)
```
**ASM:**
```
.section .text._ZN2at6native6modern18elementwise_kernelILi4ELi64ELi4EZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_NS_6detail5ArrayIPcLi3EEEEEviT2_T3_,"ax",progbits
.sectioninfo @"SHI_REGISTERS=20"
.align 128
.global _ZN2at6native6modern18elementwise_kernelILi4ELi64ELi4EZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_NS_6detail5ArrayIPcLi3EEEEEviT2_T3_
.type _ZN2at6native6modern18elementwise_kernelILi4ELi64ELi4EZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_NS_6detail5ArrayIPcLi3EEEEEviT2_T3_,function
.size _ZN2at6native6modern18elementwise_kernelILi4ELi64ELi4EZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_NS_6detail5ArrayIPcLi3EEEEEviT2_T3_,(.L_40520 - _ZN2at6native6modern18elementwise_kernelILi4ELi64ELi4EZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_NS_6detail5ArrayIPcLi3EEEEEviT2_T3_)
.other _ZN2at6native6modern18elementwise_kernelILi4ELi64ELi4EZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_NS_6detail5ArrayIPcLi3EEEEEviT2_T3_,@"STO_CUDA_ENTRY STV_DEFAULT"
_ZN2at6native6modern18elementwise_kernelILi4ELi64ELi4EZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_NS_6detail5ArrayIPcLi3EEEEEviT2_T3_:
.text._ZN2at6native6modern18elementwise_kernelILi4ELi64ELi4EZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_NS_6detail5ArrayIPcLi3EEEEEviT2_T3_:
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh", line 253
/*0000*/ IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ;
/*0010*/ @!PT SHFL.IDX PT, RZ, RZ, RZ, RZ ;
/*0020*/ S2R R9, SR_CTAID.X ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 39
/*0030*/ S2R R0, SR_TID.X ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh", line 253
/*0040*/ IMAD.SHL.U32 R9, R9, 0x100, RZ ;
/*0050*/ IADD3 R5, -R9, c[0x0][0x160], RZ ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh", line 227
/*0060*/ SHF.R.S32.HI R17, RZ, 0x1f, R9 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh", line 255
/*0070*/ ISETP.GE.AND P0, PT, R5, 0x100, PT ;
/*0080*/ @!P0 BRA `(.L_2919) ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh", line 227
/*0090*/ IMAD.SHL.U32 R12, R9.reuse, 0x4, RZ ;
/*00a0*/ SHF.L.U64.HI R17, R9, 0x2, R17 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh", line 229
/*00b0*/ IADD3 R8, P0, R12.reuse, c[0x0][0x188], RZ ;
/*00c0*/ IADD3 R2, P1, R12, c[0x0][0x190], RZ ;
/*00d0*/ IADD3.X R9, R17.reuse, c[0x0][0x18c], RZ, P0, !PT ;
/*00e0*/ IADD3.X R3, R17, c[0x0][0x194], RZ, P1, !PT ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 82
/*00f0*/ IMAD.WIDE R8, R0, 0x10, R8 ;
/*0100*/ IMAD.WIDE R2, R0, 0x10, R2 ;
/*0110*/ LDG.E.128.SYS R8, [R8] ;
/*0120*/ LDG.E.128.SYS R4, [R2] ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh", line 227
/*0130*/ IADD3 R12, P0, R12, c[0x0][0x180], RZ ;
/*0140*/ IADD3.X R13, R17, c[0x0][0x184], RZ, P0, !PT ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 102
/*0150*/ IMAD.WIDE R12, R0, 0x10, R12 ;
//## File "/usr/include/c++/8/tuple", line 1315
/*0160*/ FFMA R7, R7, c[0x0][0x168], R11 ;
/*0170*/ FFMA R6, R6, c[0x0][0x168], R10 ;
/*0180*/ FFMA R5, R5, c[0x0][0x168], R9 ;
/*0190*/ FFMA R4, R4, c[0x0][0x168], R8 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 102
/*01a0*/ STG.E.128.SYS [R12], R4 ;
/*01b0*/ EXIT ;
.L_2919:
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 42
/*01c0*/ ISETP.GE.AND P0, PT, R0, R5, PT ;
/*01d0*/ BMOV.32.CLEAR RZ, B0 ;
/*01e0*/ BSSY B0, `(.L_2920) ;
/*01f0*/ IMAD.MOV.U32 R4, RZ, RZ, RZ ;
/*0200*/ CS2R R6, SRZ ;
/*0210*/ IMAD.MOV.U32 R8, RZ, RZ, RZ ;
/*0220*/ IMAD.MOV.U32 R10, RZ, RZ, RZ ;
/*0230*/ @P0 BRA `(.L_2921) ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 45
/*0240*/ IADD3 R3, P1, R9, R0, RZ ;
/*0250*/ LEA.HI.X.SX32 R6, R0, R17, 0x1, P1 ;
/*0260*/ LEA R2, P1, R3, c[0x0][0x188], 0x2 ;
/*0270*/ LEA.HI.X R3, R3, c[0x0][0x18c], R6, 0x2, P1 ;
/*0280*/ LDG.E.SYS R10, [R2] ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 46
/*0290*/ IADD3 R6, R0, 0x40, RZ ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 42
/*02a0*/ ISETP.GE.AND P1, PT, R6, R5, PT ;
/*02b0*/ @P1 BRA `(.L_2922) ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 45
/*02c0*/ LDG.E.SYS R6, [R2+0x100] ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 46
/*02d0*/ IADD3 R8, R0, 0x80, RZ ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 42
/*02e0*/ ISETP.GE.AND P1, PT, R8, R5, PT ;
/*02f0*/ @P1 BRA `(.L_2923) ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 46
/*0300*/ IADD3 R8, R0, 0xc0, RZ ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 42
/*0310*/ ISETP.GE.AND P1, PT, R8, R5, PT ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 45
/*0320*/ LDG.E.SYS R8, [R2+0x200] ;
/*0330*/ @!P1 LDG.E.SYS R7, [R2+0x300] ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 102
/*0340*/ @P1 IMAD.MOV.U32 R7, RZ, RZ, RZ ;
/*0350*/ BRA `(.L_2921) ;
.L_2923:
/*0360*/ IMAD.MOV.U32 R7, RZ, RZ, RZ ;
/*0370*/ IMAD.MOV.U32 R8, RZ, RZ, RZ ;
/*0380*/ BRA `(.L_2921) ;
.L_2922:
/*0390*/ CS2R R6, SRZ ;
/*03a0*/ IMAD.MOV.U32 R8, RZ, RZ, RZ ;
.L_2921:
/*03b0*/ BSYNC B0 ;
.L_2920:
/*03c0*/ BMOV.32.CLEAR RZ, B0 ;
/*03d0*/ BSSY B0, `(.L_2924) ;
/*03e0*/ @P0 BRA `(.L_2925) ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 45
/*03f0*/ IADD3 R3, P1, R9, R0, RZ ;
/*0400*/ LEA.HI.X.SX32 R12, R0, R17, 0x1, P1 ;
/*0410*/ LEA R2, P1, R3, c[0x0][0x190], 0x2 ;
/*0420*/ LEA.HI.X R3, R3, c[0x0][0x194], R12, 0x2, P1 ;
/*0430*/ LDG.E.SYS R11, [R2] ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 46
/*0440*/ IADD3 R12, R0, 0x40, RZ ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 42
/*0450*/ ISETP.GE.AND P1, PT, R12, R5, PT ;
/*0460*/ @P1 BRA `(.L_2926) ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 45
/*0470*/ LDG.E.SYS R13, [R2+0x100] ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 46
/*0480*/ IADD3 R12, R0, 0x80, RZ ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 42
/*0490*/ ISETP.GE.AND P1, PT, R12, R5, PT ;
/*04a0*/ @P1 BRA `(.L_2927) ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 45
/*04b0*/ LDG.E.SYS R15, [R2+0x200] ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 46
/*04c0*/ IADD3 R12, R0, 0xc0, RZ ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 42
/*04d0*/ ISETP.GE.AND P1, PT, R12, R5, PT ;
/*04e0*/ @P1 BRA `(.L_2928) ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 45
/*04f0*/ LDG.E.SYS R4, [R2+0x300] ;
/*0500*/ BRA `(.L_2928) ;
.L_2927:
/*0510*/ IMAD.MOV.U32 R15, RZ, RZ, RZ ;
/*0520*/ BRA `(.L_2928) ;
.L_2926:
/*0530*/ IMAD.MOV.U32 R15, RZ, RZ, RZ ;
/*0540*/ IMAD.MOV.U32 R13, RZ, RZ, RZ ;
/*0550*/ BRA `(.L_2928) ;
.L_2925:
/*0560*/ IMAD.MOV.U32 R15, RZ, RZ, RZ ;
/*0570*/ IMAD.MOV.U32 R13, RZ, RZ, RZ ;
/*0580*/ IMAD.MOV.U32 R11, RZ, RZ, RZ ;
.L_2928:
/*0590*/ BSYNC B0 ;
.L_2924:
//## File "/usr/include/c++/8/tuple", line 1315
/*05a0*/ @P0 EXIT ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 58
/*05b0*/ IADD3 R9, P0, R9, R0, RZ ;
//## File "/usr/include/c++/8/tuple", line 1315
/*05c0*/ FFMA R11, R11, c[0x0][0x168], R10 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 59
/*05d0*/ IADD3 R14, R0, 0x40, RZ ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 58
/*05e0*/ LEA.HI.X.SX32 R12, R0, R17, 0x1, P0 ;
/*05f0*/ LEA R2, P0, R9.reuse, c[0x0][0x180], 0x2 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 55
/*0600*/ ISETP.GE.AND P1, PT, R14, R5, PT ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 58
/*0610*/ LEA.HI.X R3, R9, c[0x0][0x184], R12, 0x2, P0 ;
/*0620*/ STG.E.SYS [R2], R11 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 55
/*0630*/ @P1 EXIT ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 59
/*0640*/ IADD3 R10, R0, 0x80, RZ ;
//## File "/usr/include/c++/8/tuple", line 1315
/*0650*/ FFMA R13, R13, c[0x0][0x168], R6 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 55
/*0660*/ ISETP.GE.AND P0, PT, R10, R5, PT ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 58
/*0670*/ STG.E.SYS [R2+0x100], R13 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 55
/*0680*/ @P0 EXIT ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 59
/*0690*/ IADD3 R0, R0, 0xc0, RZ ;
//## File "/usr/include/c++/8/tuple", line 1315
/*06a0*/ FFMA R15, R15, c[0x0][0x168], R8 ;
/*06b0*/ FFMA R7, R4, c[0x0][0x168], R7 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 55
/*06c0*/ ISETP.GE.AND P0, PT, R0, R5, PT ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 58
/*06d0*/ STG.E.SYS [R2+0x200], R15 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 55
/*06e0*/ @P0 EXIT ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/MemoryAccess.cuh", line 58
/*06f0*/ STG.E.SYS [R2+0x300], R7 ;
//## File "/home/xgao/pytorch/aten/src/ATen/native/cuda/CUDALoops.cuh", line 260
/*0700*/ EXIT ;
.L_2929:
/*0710*/ BRA `(.L_2929);
/*0720*/ NOP;
/*0730*/ NOP;
/*0740*/ NOP;
/*0750*/ NOP;
/*0760*/ NOP;
/*0770*/ NOP;
.L_40520:
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33222
Differential Revision: D19964089
Pulled By: ngimel
fbshipit-source-id: a1e8e62d1ebcc67fb49f00d87c02bcdd13194024