CUDA Loops: move address computation into policy, make policy.load load all arguments (#33720)
Summary:
So that in the future we can make policy accept an offset calculator in its constructor for the support of non-contiguous tensors.
The `elementwise_kernel_helper` is now very general and it can handle any cases:
```C++
template<typename func_t, typename policy_t>
__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
using traits = function_traits<func_t>;
using return_t = typename traits::result_type;
using args_t = typename traits::ArgsTuple;
int idx = blockIdx.x;
return_t results[thread_work_size];
cuda9::workaround::enable_default_constructor<args_t> args_[thread_work_size];
args_t *args = reinterpret_cast<args_t *>(&args_);
// load
policy.load(args, idx);
// compute
#pragma unroll
for (int i = 0; i < thread_work_size; i++) {
if (policy.check_inbounds(i)) {
results[i] = c10::guts::apply(f, args[i]);
}
}
// store
policy.store(results, idx);
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33720
Differential Revision: D20459652
Pulled By: ngimel
fbshipit-source-id: aa8b122e0e8c6e08ab354785e04753ff778882e2