66f66fac - [jiterator] Reduce templating in jitted_gpu_kernel_impl (#80103)

Previously, a new `jitted_gpu_kernel_impl` was instantiated for every combination of kernel and data types. This adds a new intermediate, `jitted_gpu_kernel_generic`, which is templated only on the arity of the input function, so the compiler is free to reuse this code between different kernels. As an example, `UnaryOperators.cu` will only need to compile one version.

This is achieved by:
1. Hoisting static variables out of the `launch_` functions and into `JittedKernelVariantCache`, stored in `jitted_gpu_kernel_impl`, which is templated on the kernel name and dtypes.
2. Moving arguments describing the kernel's static properties (e.g. `name` and `f_inputs_type`) into runtime variables, which are packaged into a new `jit::KernelDescriptor` struct.
3. Changing `extra_args` from a tuple to `c10::ArrayRef<void*>`.

We can expect benefits in both binary size and compile times. On my build, I see an 11 MB reduction in binary size for `libtorch_cuda.so`, and this saving scales linearly with the number of jiterated kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80103
Approved by: https://github.com/ngimel
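The core trick, moving compile-time template parameters into runtime data so one instantiation serves many kernels, can be illustrated in isolation. The following is a minimal standalone C++ sketch, not the actual PyTorch code: `jitted_gpu_kernel_generic` and `KernelDescriptor` are named in the commit, but the fields and signatures shown here are assumptions, `std::vector<void*>` stands in for `c10::ArrayRef<void*>`, and the kernel launch is simulated with a print.

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Runtime description of a kernel's formerly-static properties.
// `name` and `f_inputs_type` come from the commit text; the real
// jit::KernelDescriptor may carry more fields than shown here.
struct KernelDescriptor {
  std::string name;           // kernel name, previously a template argument
  std::string f_inputs_type;  // input dtype, previously a type parameter
};

// Before: one instantiation per (kernel, dtype) combination.
// After: templated only on arity, so a single instantiation is shared
// by every kernel with the same number of inputs.
template <int arity>
void jitted_gpu_kernel_generic(const KernelDescriptor& desc,
                               const std::vector<void*>& extra_args) {
  // Static properties are now consulted at runtime rather than
  // baked into the generated code.
  std::printf("launch %s<arity=%d> on %s with %zu extra args\n",
              desc.name.c_str(), arity, desc.f_inputs_type.c_str(),
              extra_args.size());
}

int main() {
  // Two different unary kernels with different dtypes reuse the same
  // arity-1 instantiation; no per-kernel copy of the body is emitted.
  jitted_gpu_kernel_generic<1>({"exp_kernel", "float"}, {});
  jitted_gpu_kernel_generic<1>({"erf_kernel", "double"}, {});
}
```

Per the commit, the caching of step 1 is what keeps this sound: `jitted_gpu_kernel_impl` remains templated on the kernel name and dtypes so it can own the per-variant `JittedKernelVariantCache`, while the heavy launch logic lives in the shared generic function. The binary-size win comes from emitting that launch machinery once per arity instead of once per kernel/dtype pair.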