pytorch
eead5990 - Extend CSR constructor to support batched indices and values

Commit
2 years ago
Extend CSR constructor to support batched indices and values This is the first portion of changes required to enable Batched CSR format described in https://github.com/pytorch/pytorch/issues/60854#batched-CSR-computation. Currently, only the same batch shape for indices and values is allowed. In the future, we could enable "broadcasting" of indices and batched values, as done in xFormers (https://github.com/facebookresearch/xformers/blob/dd96b8d8beda5308fb433c1ef3ff04b7f178c263/xformers/components/attention/_sputnik_sparse.py#L441). This PR adds possibility to construct a batched CSR matrix with `torch.sparse_csr_tensor` and this batched CSR can be converted to a dense tensor with a `.to_dense()` call. Pull Request resolved: https://github.com/pytorch/pytorch/pull/74542 Approved by: https://github.com/cpuhrsch
Author
Committer
Parents
Loading