pytorch
ada65fdd - [complex32] fft support (cuda only) (#74857)

[complex32] fft support (cuda only) (#74857)

`half` and `complex32` support for `torch.fft.{fft, fft2, fftn, hfft, hfft2, hfftn, ifft, ifft2, ifftn, ihfft, ihfft2, ihfftn, irfft, irfft2, irfftn, rfft, rfft2, rfftn}`

* We only add support for `CUDA`, as `cuFFT` supports these precisions.
* We still error out on `CPU` and `ROCm`, as their respective backends don't support these precisions.

For `cuFFT`, the constraints for these precisions are:

* Minimum GPU architecture is SM_53
* Sizes are restricted to powers of two only
* Strides on the real part of real-to-complex and complex-to-real transforms are not supported
* More than one GPU is not supported
* Transforms spanning more than 4 billion elements are not supported

Ref: https://docs.nvidia.com/cuda/cufft/#half-precision-transforms

TODO:
* [x] Update docs about the restrictions
* [x] Check the correct way to check for a `hip` device. (It seems like `device.is_cuda()` is true for HIP as well.) (Thanks @peterbell10)

Ref for the second point in the TODO: https://github.com/pytorch/pytorch/blob/e424e7d214c02a79aab84d2d096b5b014f73ccd7/aten/src/ATen/native/SpectralOps.cpp#L31

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74857
Approved by: https://github.com/anjali411, https://github.com/peterbell10
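Among the cuFFT constraints listed above, the power-of-two size restriction is the one callers are most likely to trip over. A minimal sketch of a size check enforcing that rule is below; the helper names (`is_power_of_two`, `check_half_fft_sizes`) are hypothetical illustrations, not PyTorch's actual internal validation code.

```python
def is_power_of_two(n: int) -> bool:
    """True if n is a positive power of two (1, 2, 4, 8, ...)."""
    # A power of two has exactly one bit set, so n & (n - 1) clears it to 0.
    return n > 0 and (n & (n - 1)) == 0


def check_half_fft_sizes(shape, dims):
    """Raise if any transformed dimension violates cuFFT's
    power-of-two restriction for half/complex32 transforms."""
    for d in dims:
        if not is_power_of_two(shape[d]):
            raise ValueError(
                f"cuFFT half-precision transforms require power-of-two "
                f"sizes, but dim {d} has size {shape[d]}"
            )


# Example: a (8, 1024) input transformed along dim 1 is fine; size 1000 is not.
check_half_fft_sizes((8, 1024), (1,))
```

With a check like this, an unsupported size fails fast with a clear message instead of surfacing as a cuFFT plan-creation error deep inside the backend.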