xla
887d3446 - [Pallas] Improve FlashAttention segment_ids test case (#7034)

Comment changes are shownComment changes are hidden
Commit
1 year ago
[Pallas] Improve FlashAttention segment_ids test case (#7034) Summary: Make the test case more useful. The original test case will just create a mask that select everything, lol. Test Plan: PJRT_DEVICE=TPU python test/test_pallas.py
Author
Parents
  • test
    • File
      test_pallas.py