[Pallas] Make a FlashAttention Wrapper (#6785)
Summary:
This pull request introduces a FlashAttention wrapper that aims to:
1. Override some default settings for the best performance out of the box.
2. Ease the UX so that users don't need to do all the custom_kernel paperwork themselves (see the usage sketch below).
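
A minimal usage sketch, assuming the wrapper is exposed as `flash_attention` under `torch_xla.experimental.custom_kernel` and runs on a TPU device; the shapes here are illustrative, not prescriptive:

```python
import torch
import torch_xla.core.xla_model as xm
# Assumed import path for the new wrapper; the exact location may differ.
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
# (batch, num_heads, seq_len, head_dim)
q = torch.randn(3, 2, 128, 4, device=device)
k = torch.randn(3, 2, 128, 4, device=device)
v = torch.randn(3, 2, 128, 4, device=device)

# The wrapper handles the custom-kernel plumbing and picks sensible
# defaults, so a plain call is all that's needed.
o = flash_attention(q, k, v)
```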
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper