pytorch
700b39a3 - Sparse CSR CUDA: add `torch.addmm` with all inputs sparse (#63511)

Commit
3 years ago
Sparse CSR CUDA: add `torch.addmm` with all inputs sparse (#63511) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63511 This PR adds `torch.addmm(c, a, b)` variant with `c, a, b` all being CSR tensors. The underlying cuSPARSE function works only with 32-bit indices, and in the current implementation the result tensor has 32-bit indices. Input tensors can have both 64-bit and 32-bit indices tensors. cc nikitaved pearu cpuhrsch IvanYashchuk ngimel Test Plan: Imported from OSS Reviewed By: eellison Differential Revision: D31809838 Pulled By: cpuhrsch fbshipit-source-id: 97005dba27d8adcae445eb756bcbd7271061e9b5
Author
Parents
Loading