pytorch
1ef99263 - Make c10::complex the C++ type for complex tensors (#37421)

Commit
5 years ago
Make c10::complex the C++ type for complex tensors (#37421) Summary: # Overview This PR changes the backing type of complex tensors in `ScalarType` from `std::complex` to `c10::complex`. Since `c10::complex` and `std::complex` are reinterpret-castable, we can freely use `std::complex *` to access `c10::complex` data and vice versa. The implementation of `c10::complex` is not complete yet, so we are reinterpret casting all complex data to `std::complex` during dispatch, and do all operations in `std::complex`. # `std::complex` and `c10::complex` interoperatability To use `std::complex *` to access `c10::complex` data, the following specializations are added: ```C++ template <> inline std::complex<float>* Tensor::data_ptr(); template <> inline std::complex<double>* Tensor::data_ptr(); template <> inline std::complex<float> Tensor::item(); template <> inline std::complex<double> Tensor::item(); ``` See [`aten/src/ATen/templates/TensorMethods.h`](https://github.com/pytorch/pytorch/pull/37274/files#diff-0e8bf6f5024b32c240a4c1f0b4d8fd71) And ```C++ template <> inline std::complex<float> Scalar::to(); template <> inline std::complex<double> Scalar::to(); ``` is added in [`c10/core/Scalar.h`](https://github.com/pytorch/pytorch/pull/37274/files#diff-aabe1c134055c8dcefad830c1c7ae957) # Dispatch Macros in [`Dispatch.h`](https://github.com/pytorch/pytorch/pull/37274/files#diff-737cfdab7707be924da409a98d46cb98) still using `std::complex` as its type. We will add macros such as `AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3` as needed during the migration and not in this PR. Note that `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3` is only used in copy kernel of CUDA, and this PR is already changing it to use `c10::complex` because CUDA copy kernel has to use its original dtype otherwise there will be funny casting of dtypes causing cuda unspecified launch error. When all the migration is done, the c10 version of macros will be removed, and the default version will have `std::complex` replaced by `c10::complex` by default. This design allows us to incrementally migrate from `std::complex` to `c10::complex`. # Note Note that the `std::complex` is not completely replaced by `c10::complex` in c10 yet, for example `c10::Scalar` is still using `std::complex`. This will be fixed in later PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/37421 Differential Revision: D21282161 Pulled By: anjali411 fbshipit-source-id: 635e309e8c8a807c2217723ad250b5ab5a20ce45
Author
Parents
Loading