move Tensor comparisons back to C (#48018)
Summary:
It seems that the machinery to handle comparison method in C rather than Python already exists, unless I'm missing something. (There is a wrapper for `TypeError_to_NotImplemented_`, and Python code gen handles `__torch_function__` which are the two things `_wrap_type_error_to_not_implemented` is doing) The performance change is quite stark:
```
import torch
from torch.utils.benchmark import Timer
global_dict = {
"x": torch.ones((2, 2)),
"y_scalar": torch.ones((1,)),
"y_tensor": torch.ones((2, 1)),
}
for stmt in ("x == 1", "x == y_scalar", "x == y_tensor"):
print(Timer(stmt, globals=global_dict).blocked_autorange(min_run_time=5), "\n")
```
### Before:
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f3d1289dc10>
x == 1
Median: 12.86 us
IQR: 0.65 us (12.55 to 13.20)
387 measurements, 1000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f3d1289d1d0>
x == y_scalar
Median: 6.03 us
IQR: 0.33 us (5.91 to 6.24)
820 measurements, 1000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f3d2b9e2050>
x == y_tensor
Median: 6.34 us
IQR: 0.33 us (6.16 to 6.49)
790 measurements, 1000 runs per measurement, 1 thread
```
### After:
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbdba2a16d0>
x == 1
Median: 6.88 us
IQR: 0.40 us (6.74 to 7.14)
716 measurements, 1000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbdd2e07ed0>
x == y_scalar
Median: 2.98 us
IQR: 0.19 us (2.89 to 3.08)
167 measurements, 10000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fbdd33e4510>
x == y_tensor
Median: 3.03 us
IQR: 0.13 us (2.97 to 3.10)
154 measurements, 10000 runs per measurement, 1 thread
```
There's still a fair bit of work left. Equivalent NumPy is about 6x faster than the new overhead, and PyTorch 0.4 is about 1.25 us across the board. (No scalar cliff.) But it's a start.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48018
Reviewed By: gchanan
Differential Revision: D25026257
Pulled By: robieta
fbshipit-source-id: 093b06a1277df25b4b7cc0d4e585b558937b10a1