[DO NOT REVIEW YET] Extend paged attention #8237
add the reference test.
b1a26e76
moved the reference impl to the custom_kernel module
f712b344
add dimention annotation to the riginal paged_attention
cf6dcf54
add the new api extended_paged_attention
43453dab
create a new extended_paged_attention api with a flag controlling if …
bf71c8ce
Create a test that call both non-kernel extended_paged_attention and …
da7150ba
incorporate Woosuk's ref impl and overhaul the test and the service c…
f485878e
add the original paged_attention to the torch_xla and made sure torch…
7df596e5
modified the hardcode number in the test test_extended_paged_attentio…
3eb8e338
The kernel caller is set up so we are ready to write the kernel.
830388d8
use the original paged attention in torch_xla folder
83528de0
add my comments to the original paged attention kernel
fbca5cf0
Added more comments to the original paged attention kernel.
ba30b9bd
added reference extended paged attention impl and the test for the or…
43c2bf00
Implementing kernel v0
58fe2576
vanbasten23
force pushed
from
b4454182
to
7be26ae4
1 year ago
vanbasten23
force pushed
from
7be26ae4
to
dd6dd997
1 year ago
vanbasten23
force pushed
from
dd6dd997
to
c8a012ba
1 year ago
finished implementing the v0. Also add a test that use 1 query token …
669d5981
vanbasten23
force pushed
from
c8a012ba
to
669d5981
1 year ago
Something wrong with the test. Now the test test_extended_paged_atten…
290ab57b
added a few more tests.
118fba5c
Account for the query index for causal
54f0af18
revised v0 implementation. Add partly finished v1 impl. Also added mo…
3d9e3596
Set up the grid, inspec, outspec, outshape for the kernel v1.
d282b2ea
created a first version of v1.
069ca31e
replaced the second forloop from looping over num_q_head to num_kv_heads
8ce1bb3a
add flash_mqa code and test for experiments
2e839cb0
upload everything
6645e7b5
fixed some syntax error
fc0b345d
fixed some syntax error. Now it hits a runtime error
afb97ae9
fix an error
5a6ff8fd
added the causal mask
d6b994a8
fixed the blocker issue that pltpu.repeat(acc_scale, acc_scale_repeat…
0cca1105
test the case where both num_block_q and num_block_kv are 1.
c33da032
Start using dma but it fails.
e8ccd04e
fixed a bug
92672e44
added a jax ref impl which will be used in google3
9834f060
most basic test passed.
35a3c55e
vanbasten23
force pushed
from
e5535472
to
35a3c55e
1 year ago
refined test. Also replace reshape with permute_dims. Now many tests …
bb79ead4
when we write to o_ref, don't check @pl.when(kv_blk_idx == num_kv_blk…
8305949d
add some comment to torch_xla/experimental/pallas_kernels/flash_mqa.py
5d28a68b
all test passed. finally
0c49d5aa
Assignees
No one assigned
Login to write a write a comment.
Login via GitHub