Correctly handle bfloat16 input and output from jax2tf functions. (#3652)
TF and JAX have different NumPy dtypes for bfloat16 so we need to be careful to
use the right version. I think there are a few other cases in jax2tf where we
should be using `to_tf_dtype` rather than passing `v.dtype` directly into tf
ops (e.g. I am a bit surprised to only update convert_element_type_p), however I
think a follow up adding tests for those cases would be best.