pytorch
0ec3c5bc - [MPS] Reduce ops multi axes support (#91734)

Commit
2 years ago
[MPS] Reduce ops multi axes support (#91734) Currently, most of the reduction ops are flattening the input tensor to 1D to perform the operation. This change removes the flattening of the tensors / the unranked placeholders and adds support for multi axes in all the reduction ops. - Fixes reduction ops with correctness and shape issues. - Fixes masked.argmax / masked.argmin. In case of passing inf to argmax / argmin, MPS will return nan as index for these numbers. Casting this nan to Long will make it -1. This change avoids negative values by clamping them to 0 (matching CPU results). TestConsistency issues fixed: ``` std var amax amin sum prod mean count_nonzero masked.amax masked.amin masked.mean masked.prod masked.std masked.sum ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/91734 Approved by: https://github.com/kulinseth
Author
Committer
Parents
Loading