CUDA-strided-complex Binary and Unary Op support (#30295)
Summary:
In-tree changes to PyTorch to support complex numbers are being submitted here.
Out-of-tree support for CUDA complex numbers is here: [pytorch-cuda-strided-complex extension](https://gitlab.com/pytorch-complex/pytorch-cuda-strided-complex)
Changes so far:
- [x] Added complex support for `torch.empty` and `torch.fill()`
- [x] Added complex support for the copy kernels
- The `static_cast_with_inter_type` template function is specialized for the following cases:
- `dest_t = thrust::complex<dest_value_t>`, `src_t = std::complex<src_value_t>`
- `dest_t = std::complex<dest_value_t>`, `src_t = thrust::complex<src_value_t>`
- This handles the compile-time case where `dest_value_t=double` and `src_value_t=float`.
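The cross-library conversion above can be sketched as follows. This is illustrative only, not the PR's actual code: `fake_thrust_complex` is a local stand-in for `thrust::complex` (Thrust requires the CUDA toolkit), and the template bodies are minimal sketches of the specializations described.

```cpp
#include <complex>

// Stand-in for thrust::complex<T>; illustrative only, since Thrust
// itself requires the CUDA toolkit.
template <typename T>
struct fake_thrust_complex {
  T re, im;
};

// Generic element-wise cast (sketch of the template described above).
template <typename dest_t, typename src_t>
struct static_cast_with_inter_type {
  static dest_t apply(src_t v) { return static_cast<dest_t>(v); }
};

// Specialization: dest_t = thrust-like complex, src_t = std::complex.
// Converting the real and imaginary parts separately handles the
// dest_value_t=double, src_value_t=float case at compile time.
template <typename dest_value_t, typename src_value_t>
struct static_cast_with_inter_type<fake_thrust_complex<dest_value_t>,
                                   std::complex<src_value_t>> {
  static fake_thrust_complex<dest_value_t> apply(std::complex<src_value_t> v) {
    return {static_cast<dest_value_t>(v.real()),
            static_cast<dest_value_t>(v.imag())};
  }
};
```

The mirror specialization (`dest_t = std::complex`, `src_t` = thrust-like complex) follows the same pattern with the roles swapped.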
- [x] Added complex support for BinaryOp kernels
- `using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;` converts `std::complex<T>` ScalarTypes to Thrust types and is a no-op for other ScalarTypes
- The operator is performed using complex number support defined in `thrust/complex.h`
- This could be extended to work with ROCm by using `rocm/complex.h`
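The `ztype_cuda` mapping described above can be sketched as a type trait. The names mirror the PR, but the body is a hypothetical reconstruction, and the Thrust type is again stood in by a local struct:

```cpp
#include <complex>
#include <type_traits>

// Stand-in for thrust::complex<T> (Thrust needs the CUDA toolkit).
template <typename T>
struct fake_thrust_complex { T re, im; };

// Sketch of the ztype_cuda trait: complex ScalarTypes map to the
// Thrust-side type so the operator can use thrust/complex.h math...
template <typename scalar_t>
struct ztype_cuda {
  using thrust_t = scalar_t;  // ...and it is a no-op for everything else
};

template <typename T>
struct ztype_cuda<std::complex<T>> {
  using thrust_t = fake_thrust_complex<T>;
};
```

Because non-complex ScalarTypes pass through unchanged, the same kernel template serves real and complex dtypes.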
- [x] Added complex support for UnaryOp kernels
- Added CUDA support for `angle()`, `real()`, `imag()`, `conj()`
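On device these unary ops delegate to `thrust/complex.h`; the same math, expressed with host-side `std::complex` helpers, looks roughly like this (function names are illustrative, not PyTorch's):

```cpp
#include <cmath>
#include <complex>

// angle(): the phase of z, i.e. atan2(imag, real)
inline double angle(std::complex<double> z) {
  return std::atan2(z.imag(), z.real());
}

// real()/imag(): component extraction
inline double real_part(std::complex<double> z) { return z.real(); }
inline double imag_part(std::complex<double> z) { return z.imag(); }

// conj(): negate the imaginary part
inline std::complex<double> conjugate(std::complex<double> z) {
  return {z.real(), -z.imag()};
}
```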
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30295
Differential Revision: D18781954
Pulled By: ezyang
fbshipit-source-id: 25d204c0b8143ee27fda345a5d6a82f095da92a7