pytorch
ac6ec0ef - [ROCM] fix bug in #60313 (#61073)

Commit
3 years ago
[ROCM] fix bug in #60313 (#61073) Summary: This PR fixes a bug in https://github.com/pytorch/pytorch/issues/60313. Where the tensors generated by _generate_valid_rocfft_input are on the cpu instead of the gpu. This was due to using numpy to generate tensors and converting it to pytorch using torch.from_numpy. This leads to the generated tensors staying on the cpu. We now generate the tensors using pytorch itself which carries over the device type of the input tensors to the generated tensor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/61073 Reviewed By: H-Huang Differential Revision: D29668418 Pulled By: malfet fbshipit-source-id: ce2025c26d079c15603a89b9bf7878f48d73155e
Author
Parents
Loading