[jit] Adding vectorized load/store support for JIT generated CUDA kernel (#36555)
Summary:
JIT pointwise kernel currently does not do vectorized load/store, which may lead to not optimal performance in shorter data types, like half and int8.
In this PR, a fixed length of 4 elements per load/store is added for supported tensor shape, implemented as a runtime check inside kernel.
Supported tensor shape:
- all input/output data point are aligned to 4*sizeof(dtype)
- last dimension contiguous(stride 1) and size is multiple of 4
- all other dimension have stride that is multiple of 4
All test_jit* passed, and here is performance result on a simple `ax+by+c` fusion
result before PR:
```
torch.float32 kernel time: 0.748 ms.
torch.float16 kernel time: 0.423 ms.
torch.int8 kernel time: 0.268 ms.
```
result after PR:
```
torch.float32 kernel time: 0.733 ms.
torch.float16 kernel time: 0.363 ms.
torch.int8 kernel time: 0.191 ms.
```
test code:
```
import torch
import time
# disable profiling to test all data types
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch.jit.script
def axpby(x, y):
return x * 2 - y * 3 + 1
for test_dtype in [torch.float32, torch.float16, torch.int8]:
a = torch.randn(12345,4096, device="cuda").to(test_dtype)
b = torch.randn(12345,4096, device="cuda").to(test_dtype)
# warm up
for _ in range(100):
c = axpby(a,b)
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
c = axpby(a,b)
torch.cuda.synchronize()
end = time.time()
print("{} kernel time: {:.3f} ms.".format(test_dtype, end-start))
```
Generated code:
[log_with_generated_code.txt](https://github.com/pytorch/pytorch/files/4472813/log_with_generated_code.txt)
Additional note:
double type is disabled from vectorized code path.
We can later improve it with dynamic vectorization length support and less in-kernel check when we can use tensor shape information in codegen. For now, this implementation is following cache through TensorDesc mechanism, which does not have enough compile time information.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36555
Differential Revision: D21142762
Pulled By: ngimel
fbshipit-source-id: 1cfdc5807a944c4670b040dc2d2dfa480377e7d7