Low precision support for jiterator (#70157)
Summary:
This adds support for bfloat16 and fp16 types for jiterator by adding at::Half and at::BFloat16 classes to the jiterator code template. The only methods defined in those classes are construction from float and implicit conversion to float. Mathematical operations on them never need to be defined, because jiterator is written in a way to implicitly upcast the inputs to the functor, so all math has to be performed on float only (e.g. compute part of the kernel would always be written as
```
out[j] = i0<float>(arg0[j]);
```
It also adds support for casting to complex outputs, by adding a similar templated class c10::complex<T>. Originally I planned to only support float -> complex complex for it, but to compile fetch_and_cast function we also need complex -> float conversion. We can avoid it by compiling fetch_and_cast for a different subset of types, but I'm not doing it in this PR. Thus, technically, we can compile a kernel that would accept complex inputs and produce wrong results, but we are guarding against it by static asserting that none of the functor datatype are complex, and runtime-checking that none of the inputs are complex.
Adding bfloat16, half and complex support allows us to remove special handling for type promotion tests for gcd.
i0 (that supports half and bfloat16 inputs) is moved to use jiterator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70157
Reviewed By: mruberry
Differential Revision: D33221645
Pulled By: ngimel
fbshipit-source-id: 9cfe8aba3498a0604c4ea62c217292ea06c826b1