pytorch
682c0d26 - Use segment/scatter_reduce to support masked reductions on sparse CSR tensors (mean, amax, amin) (fp only) (#78918)

Commit
2 years ago
Use segment/scatter_reduce to support masked reductions on sparse CSR tensors (mean, amax, amin) (fp only) (#78918) Follows design [here](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp#L804-L837) and [here](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp#L885-L928) from SparseCsrTensorMath.cpp (which has already been used to implement sum/prod) but use `segment_reduce`/`scatter_reduce` for reduction step Pull Request resolved: https://github.com/pytorch/pytorch/pull/78918 Approved by: https://github.com/cpuhrsch
Committer
Parents
Loading