[inductor] Fix argmin/max with duplicate values (#99920)
Fixes #99879
This adds `minimum_with_index` helper functions to compute the minimum
value and its index simultaneously, preferring the smaller index,
which is required to match eager in case of duplicates.
I also replace the mask-and-sum hack with a `tl.reduce` using
the previously mentioned helper. This additionally fixes the indices
being added together in the case of duplicates.
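To make the difference concrete, here is a minimal scalar sketch in plain Python (not Triton). The function bodies are illustrative models rather than the code in this PR: `argmin_reduce` and `argmin_mask_and_sum` are hypothetical names, the mask-and-sum version is a simplified guess at the old behaviour, and NaN handling is ignored.

```python
import math

def minimum_with_index(a_value, a_index, b_value, b_index):
    """Scalar model of the combine step: keep the smaller value,
    and on equal values keep the smaller index."""
    if a_value < b_value:
        return a_value, a_index
    if a_value == b_value and a_index < b_index:
        return a_value, a_index
    return b_value, b_index

def argmin_reduce(values):
    # Start from (+inf, INT64_MAX), mirroring the kernel's initial values.
    best_value, best_index = math.inf, 2**63 - 1
    for index, value in enumerate(values):
        best_value, best_index = minimum_with_index(
            best_value, best_index, value, index
        )
    return best_index

def argmin_mask_and_sum(values):
    # Simplified model of the old hack: mask positions equal to the
    # minimum, then sum their indices. Only correct without duplicates.
    smallest = min(values)
    return sum(i for i, v in enumerate(values) if v == smallest)

row = [3.0, 1.0, 2.0, 1.0]   # the minimum 1.0 appears at indices 1 and 3
print(argmin_reduce(row))        # 1, the first occurrence
print(argmin_mask_and_sum(row))  # 4, i.e. 1 + 3
```

With a duplicated minimum the summed result is not a valid answer at all, whereas the pairwise reduction returns the first occurrence.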
As an example, this is the kernel generated for `torch.argmin(x, 1)`:
```python
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1028 # dynamic_shapes=False
    rnumel = 1028 # dynamic_shapes=False
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp1 = tl.full([XBLOCK, RBLOCK], float("inf"), tl.float32)
    _tmp1_index = tl.full([XBLOCK, RBLOCK], 9223372036854775807, tl.int64)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (1028*x0)), rmask & xmask, eviction_policy='evict_last', other=0)
        _tmp1_next, _tmp1_index_next = triton_helpers.minimum_with_index(
            _tmp1, _tmp1_index, tmp0, rindex
        )
        _tmp1 = tl.where(rmask & xmask, _tmp1_next, _tmp1)
        _tmp1_index = tl.where(rmask & xmask, _tmp1_index_next, _tmp1_index)
    _, tmp1_tmp = triton_helpers.min_with_index(_tmp1, _tmp1_index, 1)
    tmp1 = tmp1_tmp[:, None]
    tl.store(out_ptr0 + x0, tmp1, xmask)
```
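The two-stage shape of this kernel (a per-lane running minimum across `RBLOCK`-sized chunks, then one reduction across lanes) can be modelled in plain Python for a single row. This is only a sketch of the control flow under simplifying assumptions: `combine` and `looped_argmin` are hypothetical names, and NaN handling is left out.

```python
import math
from functools import reduce

INT64_MAX = 2**63 - 1

def combine(a, b):
    """Pick the smaller (value, index) pair; ties go to the smaller index."""
    if a[0] < b[0] or (a[0] == b[0] and a[1] < b[1]):
        return a
    return b

def looped_argmin(row, rblock):
    # One accumulator per lane, initialised like _tmp1 / _tmp1_index.
    lanes = [(math.inf, INT64_MAX)] * rblock
    for roffset in range(0, len(row), rblock):
        for lane in range(rblock):
            rindex = roffset + lane
            if rindex < len(row):  # plays the role of rmask
                lanes[lane] = combine(lanes[lane], (row[rindex], rindex))
    # Final reduction across lanes, like min_with_index(..., 1).
    return reduce(combine, lanes)[1]

row = [5.0, 2.0, 7.0, 2.0, 9.0, 2.0]  # minimum 2.0 at indices 1, 3 and 5
print(looped_argmin(row, 4))  # 1
```

Because the same tie-breaking rule is applied in both stages, the result should not depend on how the row is split into chunks.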
Or for a persistent reduction, it generates:
```python
tmp0 = tl.load(in_ptr0 + (r1 + (1024*x0)), rmask & xmask, other=0)
tmp2 = tl.where(rmask & xmask, tmp0, float("inf"))
tmp3 = tl.broadcast_to(rindex, tmp2.shape)
_, tmp4_tmp = triton_helpers.min_with_index(tmp2, tmp3, 1)
tmp4 = tmp4_tmp[:, None]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99920
Approved by: https://github.com/ngimel