xla
[DO NOT REVIEW YET] Extend paged attention
#8237
Open

[DO NOT REVIEW YET] Extend paged attention #8237

vanbasten23 wants to merge 39 commits into master from xiowei/extend_paged_attention
vanbasten23
vanbasten23 add the reference test.
b1a26e76
vanbasten23 moved the reference impl to the custom_kernel module
f712b344
vanbasten23 add dimention annotation to the riginal paged_attention
cf6dcf54
vanbasten23 add the new api extended_paged_attention
43453dab
vanbasten23 create a new extended_paged_attention api with a flag controlling if …
bf71c8ce
vanbasten23 Create a test that call both non-kernel extended_paged_attention and …
da7150ba
vanbasten23 incorporate Woosuk's ref impl and overhaul the test and the service c…
f485878e
vanbasten23 add the original paged_attention to the torch_xla and made sure torch…
7df596e5
vanbasten23 modified the hardcode number in the test test_extended_paged_attentio…
3eb8e338
vanbasten23 The kernel caller is set up so we are ready to write the kernel.
830388d8
vanbasten23 use the original paged attention in torch_xla folder
83528de0
vanbasten23 add my comments to the original paged attention kernel
fbca5cf0
vanbasten23 Added more comments to the original paged attention kernel.
ba30b9bd
vanbasten23 added reference extended paged attention impl and the test for the or…
43c2bf00
vanbasten23 Implementing kernel v0
58fe2576
vanbasten23 vanbasten23 force pushed from b4454182 to 7be26ae4 1 year ago
vanbasten23 vanbasten23 force pushed from 7be26ae4 to dd6dd997 1 year ago
vanbasten23 vanbasten23 force pushed from dd6dd997 to c8a012ba 1 year ago
vanbasten23 finished implementing the v0. Also add a test that use 1 query token …
669d5981
vanbasten23 vanbasten23 force pushed from c8a012ba to 669d5981 1 year ago
vanbasten23 Something wrong with the test. Now the test test_extended_paged_atten…
290ab57b
vanbasten23 added a few more tests.
118fba5c
vanbasten23 Account for the query index for causal
54f0af18
vanbasten23 revised v0 implementation. Add partly finished v1 impl. Also added mo…
3d9e3596
vanbasten23 Set up the grid, inspec, outspec, outshape for the kernel v1.
d282b2ea
vanbasten23 created a first version of v1.
069ca31e
vanbasten23 replaced the second forloop from looping over num_q_head to num_kv_heads
8ce1bb3a
vanbasten23 add flash_mqa code and test for experiments
2e839cb0
vanbasten23 upload everything
6645e7b5
vanbasten23 fixed some syntax error
fc0b345d
vanbasten23 fixed some syntax error. Now it hits a runtime error
afb97ae9
vanbasten23 fix an error
5a6ff8fd
vanbasten23 added the causal mask
d6b994a8
vanbasten23 fixed the blocker issue that pltpu.repeat(acc_scale, acc_scale_repeat…
0cca1105
vanbasten23 test the case where both num_block_q and num_block_kv are 1.
c33da032
vanbasten23 Start using dma but it fails.
e8ccd04e
vanbasten23 fixed a bug
92672e44
vanbasten23 added a jax ref impl which will be used in google3
9834f060
vanbasten23 most basic test passed.
35a3c55e
vanbasten23 vanbasten23 force pushed from e5535472 to 35a3c55e 1 year ago
vanbasten23 refined test. Also replace reshape with permute_dims. Now many tests …
bb79ead4
vanbasten23 when we write to o_ref, don't check @pl.when(kv_blk_idx == num_kv_blk…
8305949d
vanbasten23 add some comment to torch_xla/experimental/pallas_kernels/flash_mqa.py
5d28a68b
vanbasten23 all test passed. finally
0c49d5aa

Login to write a write a comment.

Login via GitHub

Reviewers
No reviews
Assignees
No one assigned
Labels
Milestone