Fix argmin/max bug (#38946)
Summary:
Fix https://github.com/pytorch/pytorch/issues/38922
# Reproduction
- This is correct
```py
>>> torch.zeros(1, 32767).argmax(dim=0)
tensor([0, 0, 0, ..., 0, 0, 0])
```
- But this is not
```py
>>> torch.zeros(1, 32768).argmax(dim=0)
tensor([ 0, 0, 0, ..., 31141, 31141, 31141])
```
- The bug only occurs when the size of the reduced dimension is 1
```py
>>> torch.zeros(2, 327680).argmax(dim=0)
tensor([1, 1, 1, ..., 1, 1, 1])
>>> torch.zeros(3, 327680).argmax(dim=0)
tensor([2, 2, 2, ..., 2, 2, 2])
```
- The incorrect indices depend on the sizes of the remaining dimensions
```py
>>> torch.zeros(1, 327680).argmax(dim=0)
tensor([ 0, 0, 0, ..., 311296, 311296, 311296])
```
```py
>>> torch.zeros(1, 32768, 10).argmax(dim=0)
tensor([[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
...,
[311296, 311296, 311296, ..., 311296, 311296, 311296],
[311296, 311296, 311296, ..., 311296, 311296, 311296],
[311296, 311296, 311296, ..., 311296, 311296, 311296]])
```
# Reason
- `resize_outputs_` is set to `false` in `reduce_op`, but the size-1 reduced dimension is still coalesced away during `TensorIterator::build()`, so the reduction effectively runs over the wrong elements
https://github.com/pytorch/pytorch/blob/899a075b25300460614169394588b22937a900f4/aten/src/ATen/native/TensorIterator.cpp#L703-L715
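The mechanism can be illustrated with a simplified, hypothetical sketch of adjacent-dimension coalescing (not PyTorch's actual implementation): a size-1 dimension merges unconditionally with its neighbor, so a reduction over a dim of size 1 loses track of which dimension it was reducing.

```python
def coalesce(shape, strides):
    """Merge adjacent dims that form one contiguous run.

    Hypothetical simplification of TensorIterator's coalescing logic,
    for illustration only.
    """
    out_shape, out_strides = [shape[0]], [strides[0]]
    for size, stride in zip(shape[1:], strides[1:]):
        if out_shape[-1] == 1:
            # A size-1 dim merges unconditionally -- this is how the
            # reduced dim of torch.zeros(1, 32768).argmax(dim=0) can
            # silently disappear into its neighbor.
            out_shape[-1], out_strides[-1] = size, stride
        elif size == 1:
            continue  # size-1 inner dim folds into the current run
        elif out_strides[-1] == stride * size:
            # One contiguous run: merge the two dims.
            out_shape[-1] *= size
            out_strides[-1] = stride
        else:
            out_shape.append(size)
            out_strides.append(stride)
    return out_shape, out_strides

# Contiguous shape (1, 32768) has strides (32768, 1); after coalescing
# the reduced dim is gone entirely:
print(coalesce([1, 32768], [32768, 1]))  # -> ([32768], [1])
```

With the reduced dim coalesced into a single flat dimension, the argmax kernel computes indices over the wrong extent, which matches the garbage indices seen above.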
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38946
Differential Revision: D21751663
Pulled By: ngimel
fbshipit-source-id: 6d55e4bb783423b4c2df09cd3e8b87147efcbfdb