Sparse CSR CUDA: Add torch.sparse.sampled_addmm (#68007)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68007
This PR adds a new function to the sparse module.
`sampled_addmm` computes α*(A @ B) * spy(C) + β*C, where C is a sparse CSR matrix and A, B are dense (strided) matrices.
This function is currently restricted to single 2D matrices, it doesn't support batched input.
cc nikitaved pearu cpuhrsch IvanYashchuk
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D32435799
Pulled By: cpuhrsch
fbshipit-source-id: b1ffac795080aef3fa05eaeeded03402bc097392