🚨🚨[core] Completely rewrite the masking logic for all attentions #37866
Cyrilvallez changed the title from "Refactor mask" to "[core] Completely rewrite the masking logic for all attentions" 1 year ago
Cyrilvallez force pushed from 7b31e00d to 5c14d588 1 year ago
Cyrilvallez force pushed from a195aa13 to 53ca556d 1 year ago
Cyrilvallez force pushed from 53ca556d to ce42aa74 1 year ago
Cyrilvallez marked this pull request as ready for review 1 year ago
gante commented on 2025-05-16
vasqu commented on 2025-05-16
Cyrilvallez force pushed from 0b6bbe59 to 7fc4f910 1 year ago
Cyrilvallez force pushed from 28e232cb to 5170e9d1 1 year ago
Cyrilvallez changed the title from "[core] Completely rewrite the masking logic for all attentions" to "🚨🚨[core] Completely rewrite the masking logic for all attentions" 1 year ago
Cyrilvallez force pushed from 13f9d2ea to 6a961044 1 year ago
Cyrilvallez force pushed from dee568cc to 4a2e906f 1 year ago
start
e083d5c9
start having a clean 4d mask primitive
e1d43c47
Update mask_utils.py
59a69c4a
Update mask_utils.py
8aa61b0b
switch name
ee7bafdd
Update masking_utils.py
bfcc5d84
add a new AttentionMask tensor class
f92757a5
fix import
932c17b0
nits
1227356b
fixes
542d0543
use full and quadrants
99235bb8
general sdpa mask for all caches
6d16c6b2
style
c98bc686
start some tests
f3027fed
tests with sliding, chunked
7eeed31a
add styling
ddd60596
test hybrid
9397d176
Update masking_utils.py
bb6ea159
small temp fixes
d7c4fa76
Update modeling_gemma2.py
3ea388c9
compile compatible
21652329
Update masking_utils.py
07fda4f7
improve
eed63837
start making it more general
15485c30
Update masking_utils.py
ce4080b9
generate
039e4440
make it work with flex style primitives!
8b99dde4
Update masking_utils.py
14f01634
Update masking_utils.py
6bb87425
Update masking_utils.py
1c16acb5
improve
8f75abba
Update cache_utils.py
387cdb46
Update masking_utils.py
346bfa99
simplify - starting to look good!
b826d915
Update masking_utils.py
83f20d37
name
6fa1d35e
Update masking_utils.py
05777fbd
style
6ab437e9
Update masking_utils.py
1d6c9009
Update masking_utils.py
b5e1ebd2
Update masking_utils.py
26a428cc
Update masking_utils.py
71162e49
small fix for flex
28d5a190
flex compile
85320339
FA2
f3c8e7cd
Update masking_utils.py
d0c0b40e
Escape for TGI/vLLM!
bbc6bece
Update masking_utils.py
ada16276
Update masking_utils.py
1620f1e2
Update masking_utils.py
8f3a2a03
General case without cache
d5b3285c
rename
ce227282
full test on llama4
7bf1352b
small fix for FA2 guard with chunk
05529fff
Update modeling_gemma2.py
5afd8983
post rebase cleanup
67bea3f4
FA2 supports static cache!
7bb501d1
Update modeling_flash_attention_utils.py
9bb62192
Update flex_attention.py
f4849ab0
Update masking_utils.py
44f9b657
Update masking_utils.py
dc52eb36
Update utils.py
d2f645db
override for export
07bd06c3
Update executorch.py
f6735f2d
Update executorch.py
f2d8a543
Update executorch.py
ee0afddb
Update executorch.py
73549e57
Update masking_utils.py
7031bc40
Update masking_utils.py
552b586e
output attentions
d2fb4de2
style
628dcd89
Update masking_utils.py
27fb93f0
Update executorch.py
a0915177
Add docstring
cbf1144f
Add license and put mask visualizer at the end
59eb3cc4
Update test_modeling_common.py
85ab5da6
fix broken test
daf5bee7
Update test_modeling_gemma.py
201da65c
Update test_modeling_gemma2.py
f06c7cd2
fixes
f3906750
make it better
e26eb845
generalize to other test models
bca422c3
fix
a66697e1
Update masking_utils.py
7ea8db75
fix
1d3751fb
do not check mask equivalence if layer types are different
df439179
executorch
095746a4
Update modeling_gemma2.py
770422c5
Update masking_utils.py
0b5a8176
use layer_idx instead
cf5212c9
adjust
e28d6638
Update masking_utils.py
53e9f472
test
8e2bdd11
fix imports
558c47e0
Update modeling_gemma2.py
df49780a
other test models
a87f7ddb
Update modeling_llama4.py
8426b341
Update masking_utils.py
413d446f
improve
7f0f9898
simplify
3ed17a2c
Update masking_utils.py
f23236d7
typos
0ffff1da
typo
09d32dfd
fix
e20ebabc
Update masking_utils.py
d273325e
default DynamicCache
5ae049cd
remove default cache
326bacf3
simplify
d58eaab7
Update masking_utils.py
02a91802
Update masking_utils.py
d67de199
Update masking_utils.py
3831ccc8
Update masking_utils.py
4b54f187
simplify
6edf1166
Update masking_utils.py
18614a59
Update masking_utils.py
bd931a07
Update masking_utils.py
93f8d82c
export
711ab9b1
Update executorch.py
58f198e9
Update executorch.py
9c69ae5e
Update flex_attention.py
4e405160
Update executorch.py
6a28a34a
upstream to modular gemma 1 & 2
c70bf3c6
Update modular_mistral.py
3a972d47
switch names
7ca132df
use dict
34a55c51
put it in the Layer directly
5c89d722
update copy model source for mask functions
e6891b63
apply so many modular (hopefully 1 shot)
ac021700
use explicit dicts to make style happy
59e11ab6
protect import
27041e08
check docstring
0cf18e21
better default in hybrid caches
47158dfd
qwens
022c4a9b
Update modular_qwen2.py
94896dc9
simplify core logic!
9bbe1cbd
Update executorch.py
0844a491
qwen3 moe
dbbecde5
Update masking_utils.py
a3502631
Update masking_utils.py
09b01489
simplify a lot sdpa causal skip
fcd21a40
Update masking_utils.py
8cb637f7
post-rebase
481f0862
gemma3 finally
91c87f8b
style
9bda8648
check it before
d24309f2
gemma3
8e153a14
More general with newer torch
ebc7f9d0
align gemma3
31008ba8
Update utils.py
3c385ea6
Update utils.py
b206cd5a
Update masking_utils.py
b0850bfa
Update test_modeling_common.py
79eac774
Update flex_attention.py
29a6bc25
Update flex_attention.py
bb2dda0a
Update flex_attention.py
1b85bbbc
test
f76df19b
executorch
3ff39084
Update test_modeling_common.py
fd8a6a21
Update masking_utils.py
84db8eea
Update masking_utils.py
83ba79f9
Update masking_utils.py
acbe4be9
Update masking_utils.py
b0333def
Update executorch.py
3c483348
Update test_modeling_common.py
cfd06948
fix copies
01810426
device
ad5fb366
sdpa can be used without mask -> pass the torchscript tests in this case
b477c1ef
Use enum for check
3b71b7bc
revert enum and add check instead
1a05ca1c
remove broken test
2029cfa2
cohere2
28d62da9
some doc & reorganize the Interface
9d7bd3a4
Update tensor_parallel.py
343ab956
Cyrilvallez force pushed from ffdd1424 to 343ab956 1 year ago
Update tensor_parallel.py
78a21ea3
doc and dummy
4c87caa7
Update test_modeling_paligemma2.py
1f21213f
Update modeling_falcon_h1.py
e3530673
Update masking_utils.py
7979ac66
executorch patch
ba6501c5
style
269969ef
CIs
75ccf7ab
use register in executorch
7bcd55f6
final comments!
9245fcda
Cyrilvallez deleted the refactor-mask branch 1 year ago