pytorch
e29d8477 - Added CUDA support for torch.orgqr (#51348)

Commit
3 years ago
Added CUDA support for torch.orgqr (#51348) Summary: This PR adds support for CUDA inputs for `torch.orgqr`. CUDA implementation is based on both [cuSOLVER](https://docs.nvidia.com/cuda/cusolver/index.html#cuSolverDN-lt-t-gt-orgqr) and MAGMA. cuSOLVER doesn't have a specialized routine for the batched case. While MAGMA doesn't have a specialized GPU native (without CPU sync) `orgqr`. But MAGMA has implemented (and not documented) the batched GPU native version of `larft` function (for small inputs of size <= 32), which together with `larfb` operation form `orgqr` (see the call graph [here at the end of the page](http://www.netlib.org/lapack/explore-html/da/dba/group__double_o_t_h_e_rcomputational_ga14b45f7374dc8654073aa06879c1c459.html)). So now there are two main codepaths for CUDA inputs (if both MAGMA and cuSOLVER are available): * if `batchsize > 1` and `tau.shape[-1] <= 32` then MAGMA based function is called * else [cuSOLVER's `orgqr`](https://docs.nvidia.com/cuda/cusolver/index.html#cuSolverDN-lt-t-gt-orgqr) is used. If MAGMA is not available then only cuSOLVER is used and vice versa. Documentation updates and possibly a new name for this function will be in a follow-up PR. Ref. https://github.com/pytorch/pytorch/issues/50104 Pull Request resolved: https://github.com/pytorch/pytorch/pull/51348 Reviewed By: ngimel Differential Revision: D26727918 Pulled By: mruberry fbshipit-source-id: 1c4d15fa76ba624e341a69a32337a9a16cc01013
Author
Parents
Loading