Adding elementwise kernel also operating on index (#28175)
Summary:
This PR adds `gpu_kernel_with_index` as an addition to the element-wise kernel templates. It allows a kernel to operate not only on input tensor values, but also on each value's index (viewed as 1d, i.e. from 0 to numel) within the lambda.
The direct use case here is to replace `thrust::tabulate` as used in range/arange/linspace. Benefits are:
- `thrust::tabulate` causes an additional, unnecessary synchronization on the CPU.
- It now works with TensorIterator, so the output no longer needs to be contiguous and a memcpy is saved.
It can also potentially be reused later to add new functions to PyTorch, wherever both value and index are needed (for example, unifying tril/triu into a TensorIterator element-wise kernel, or other patterns).
Known issues:
https://github.com/pytorch/pytorch/pull/23586 is needed to make the non-contiguous case work properly, since memory overlap needs to be checked. Currently a non-contiguous tensor falls into TOO_HARD. I could write a proper check in this file, but I figured using the existing method is better. jjsjann123
It does not work beyond 32-bit indexing, but thrust was erroring on those cases too. We could split the tensor in the caller to enable this. Since the index changes after a split, it is easier for the caller to pass a different lambda per chunk than for the template to handle it generically.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28175
Differential Revision: D18708649
Pulled By: ngimel
fbshipit-source-id: 382081c96f266ae7b61095fc1f2af41c6b210fa9