pytorch
06fbd5dc - [inductor] Fix argmin/max with duplicate values (#100573)

Commit
1 year ago
[inductor] Fix argmin/max with duplicate values (#100573) Fixes #99879 This adds `minimum_with_index` helper functions to compute the minimum value and index simultaneously, with a preference for the smaller index which is required to match eager in case of duplicates. I also remove the mask-and-sum hack with a `tl.reduce` using the previously mentioned helper. This additionally fixes the indices being added together in the case of duplicates. Pull Request resolved: https://github.com/pytorch/pytorch/pull/100573 Approved by: https://github.com/ngimel
Author
Committer
Parents
Loading