Fix dispatch of argmax/argmin. (#32961)
Summary:
The way we currently dispatch argmax/argmin to out-of-source devices is fragile and has caused issues, e.g. it doesn't work well when the input requires grad. See https://github.com/pytorch/xla/issues/1585.
Dispatching argmax/argmin at the device level resolves this.
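As a rough illustration of what dispatching at the device level means, here is a toy Python sketch. It is not PyTorch's dispatcher: `register`, `dispatch`, `_KERNELS`, and the `"xla"` stand-in kernel are all hypothetical names made up for this example. The only point is that the kernel is selected by the device key alone.

```python
# Toy model of per-device dispatch. Every name here is hypothetical and
# is NOT part of PyTorch's real dispatcher API.
from typing import Callable, Dict, List, Tuple

# Maps (op name, device type) to the kernel registered for that pair.
_KERNELS: Dict[Tuple[str, str], Callable[[List[float]], int]] = {}


def register(op: str, device: str):
    """Register the decorated function as the kernel for (op, device)."""
    def decorator(fn):
        _KERNELS[(op, device)] = fn
        return fn
    return decorator


@register("argmax", "cpu")
def argmax_cpu(values: List[float]) -> int:
    # Index of the first maximal element.
    return max(range(len(values)), key=values.__getitem__)


@register("argmax", "xla")
def argmax_xla(values: List[float]) -> int:
    # Stand-in for a kernel supplied by an out-of-source backend.
    return max(range(len(values)), key=values.__getitem__)


def dispatch(op: str, device: str, values: List[float]) -> int:
    # The lookup key is the device only; nothing about the input
    # (such as whether it requires grad) affects which kernel runs.
    try:
        kernel = _KERNELS[(op, device)]
    except KeyError:
        raise NotImplementedError(f"{op} is not implemented for device {device!r}")
    return kernel(values)
```

In this sketch a backend only has to register its own kernel under its device key, and a missing registration fails loudly instead of falling through to some other path.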
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32961
Differential Revision: D19726826
Pulled By: ailzhang
fbshipit-source-id: f7fb445fd8e7691524afcc47d24d8e6b0171d10c