jax
904f34a9 - Correctly handle bfloat16 input and output from jax2tf functions. (#3652)

Commit
5 years ago
Correctly handle bfloat16 input and output from jax2tf functions. (#3652) TF and JAX have different NumPy dtypes for bfloat16 so we need to be careful to use the right version. I think there are a few other cases in jax2tf where we should be using `to_tf_dtype` rather than passing `v.dtype` directly into tf ops (e.g. I am a bit surprised to only update convert_element_type_p), however I think a follow up adding tests for those cases would be best.
Author
Parents
Loading