[Pallas] Support XLA_USE_BF16 (#6817)
Summary:
XLA_USE_BF16=1 makes all internal XLA tensors use BF16, but the torch.Tensor wrappers still report torch.float. To address this, we need to construct the JAX tracers with the actual underlying dtype so that the correct Mosaic kernel is produced.
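
A minimal sketch of the idea, assuming a hypothetical helper `_jax_dtype_for` (illustrative only, not the actual torch_xla implementation): when XLA_USE_BF16=1, a tensor reporting torch.float32 is really stored as BF16, so the tracer's abstract value must be built with bfloat16.

```python
import os

import jax
import jax.numpy as jnp
import torch

# Hypothetical helper: map a torch.Tensor's reported dtype to the dtype the
# backing XLA tensor actually uses, honoring XLA_USE_BF16.
def _jax_dtype_for(t: torch.Tensor):
    # Under XLA_USE_BF16=1 the XLA tensor is stored as BF16 even though the
    # torch.Tensor wrapper still reports torch.float32, so the JAX tracer
    # must carry bfloat16 for the kernel to be lowered correctly.
    if os.environ.get("XLA_USE_BF16") == "1" and t.dtype == torch.float32:
        return jnp.bfloat16
    return {
        torch.float32: jnp.float32,
        torch.bfloat16: jnp.bfloat16,
        torch.float16: jnp.float16,
    }[t.dtype]

# The abstract value handed to Pallas tracing then carries the real dtype:
t = torch.ones(8, 128)  # reports float32 even when XLA stores BF16
spec = jax.ShapeDtypeStruct(t.shape, _jax_dtype_for(t))
```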
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper_bf16