compute common dtype based on inputs only (#25593)
Summary:
Currently we compute common dtype for TensorIterator based on all inputs and outputs. It can be a problem when dtype of the outputs should be different from dtype of inputs. (Example torch.eq)
We also have `dont_compute_common_dtype` method that allows us to avoid a computation of a common dtype for all inputs and outputs.
This PR will give the ability to compute common dtype based only on inputs using `compute_common_dtype_only_for_inputs`. Also it will provide a simple method `input_dtype(int arg=0) that will give the ability to dispatch based on input's dtype.
```
AT_DISPATCH_ALL_TYPES(iter.input_dtype(), ...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25593
Differential Revision: D17286352
Pulled By: ifedan
fbshipit-source-id: a94fb608acd2763120992fe85b8dfd02ff21f9ba