Enables bfloat16 x [float16, complex64, complex128] type promotion (#43324)
Summary:
Implements bfloat16 type promotion consistent with JAX (see https://jax.readthedocs.io/en/latest/type_promotion.html), addressing issue https://github.com/pytorch/pytorch/issues/43049.
- bfloat16 x float16 -> float32
- bfloat16 x complex64 -> complex64
- bfloat16 x complex128 -> complex128
Existing tests, after updates, are sufficient to validate the new behavior.
cc xuhdev
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43324
Reviewed By: albanD
Differential Revision: D23259823
Pulled By: mruberry
fbshipit-source-id: ca9c2c7d0325faced1f884f3c37edf8fa8c8b089