pytorch
d1beda53 - Sparse CSR CUDA: add batched support for torch.sparse.sampled_addmm

Commit
2 years ago
Sparse CSR CUDA: add batched support for torch.sparse.sampled_addmm This PR adds a forloop around cuSPARSE calls to support batched inputs. cuSPARSE function itself doesn't support batched inputs yet. `mat1` and `mat2` must have the same batch shape. It's allowed to pass `self` as a single matrix when `mat1` and `mat2` are batched. Pull Request resolved: https://github.com/pytorch/pytorch/pull/77243 Approved by: https://github.com/cpuhrsch
Author
Committer
Parents
Loading