Moving sign function to ATen (#22861)
Summary:
This PR linked to https://github.com/pytorch/pytorch/issues/22806 moving sign function to ATen.
sign(x) supports bool, and vectorized operation on CPU.
sign(NaN) is defined to return 0.
sign(bool) is a no-op, the resulting tensor will holds the same values than the input one.
- [x] CPU Backend
- [x] CUDA Backend
- [x] Bring support for bool dtype
- [x] Bring support for Half dtype
- [x] Add test for NaN
- [x] Add test for bool dtype
- [x] Delete legacy implementation in THTensorMoreMath.cpp
Performances:
```python
timeit -s 'import torch; x = torch.randn((1000, 1000))' -n 1000 'torch.sign(x)'
timeit -s 'import torch; x = torch.randn((1000, 1000), device="cuda")' -n 1000 'torch.sign(x); torch.cuda.synchronize()'
```
| device | before | after |
| :-------------: | :-------------: | :-----: |
| CPU | 1.24 msec | 33.9 usec |
| GPU | 680 usec | 7.13 usec |
| CPU (1 thread) | 0.82 msec | 0.73 msec |
| GPU (1 thread) | 16.1 used | 15.9 usec |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22861
Differential Revision: D16503452
Pulled By: VitalyFedyunin
fbshipit-source-id: a87ce7fff139642ef4ed791f15873074ad0d53af