transformers
🚨🚨[core] Completely rewrite the masking logic for all attentions
#37866
Merged

Cyrilvallez merged 193 commits into main from refactor-mask
Cyrilvallez
github-actions marked this pull request as draft 1 year ago
github-actions
HuggingFaceDocBuilderDev
Cyrilvallez changed the title from "Refactor mask" to "[core] Completely rewrite the masking logic for all attentions" 1 year ago
Cyrilvallez force pushed from 7b31e00d to 5c14d588 1 year ago
Cyrilvallez force pushed from a195aa13 to 53ca556d 1 year ago
Cyrilvallez force pushed from 53ca556d to ce42aa74 1 year ago
Cyrilvallez marked this pull request as ready for review 1 year ago
ArthurZucker
ArthurZucker commented on 2025-05-13
ydshieh
ArthurZucker
ArthurZucker commented on 2025-05-14
ArthurZucker
ArthurZucker commented on 2025-05-16
gante
gante commented on 2025-05-16
vasqu
vasqu commented on 2025-05-16
Cyrilvallez force pushed from 0b6bbe59 to 7fc4f910 1 year ago
Cyrilvallez
ArthurZucker
ArthurZucker approved these changes on 2025-05-19
Cyrilvallez force pushed from 28e232cb to 5170e9d1 1 year ago
Cyrilvallez changed the title from "[core] Completely rewrite the masking logic for all attentions" to "🚨🚨[core] Completely rewrite the masking logic for all attentions" 1 year ago
Cyrilvallez force pushed from 13f9d2ea to 6a961044 1 year ago
Cyrilvallez force pushed from dee568cc to 4a2e906f 1 year ago
Cyrilvallez start
e083d5c9
Cyrilvallez start having a clean 4d mask primitive
e1d43c47
Cyrilvallez Update mask_utils.py
59a69c4a
Cyrilvallez Update mask_utils.py
8aa61b0b
Cyrilvallez switch name
ee7bafdd
Cyrilvallez Update masking_utils.py
bfcc5d84
ArthurZucker add a new AttentionMask tensor class
f92757a5
ArthurZucker fix import
932c17b0
ArthurZucker nits
1227356b
ArthurZucker fixes
542d0543
ArthurZucker use full and quadrants
99235bb8
Cyrilvallez general sdpa mask for all caches
6d16c6b2
Cyrilvallez style
c98bc686
Cyrilvallez start some tests
f3027fed
Cyrilvallez tests with sliding, chunked
7eeed31a
ArthurZucker add styling
ddd60596
Cyrilvallez test hybrid
9397d176
Cyrilvallez Update masking_utils.py
bb6ea159
Cyrilvallez small temp fixes
d7c4fa76
Cyrilvallez Update modeling_gemma2.py
3ea388c9
Cyrilvallez compile compatible
21652329
Cyrilvallez Update masking_utils.py
07fda4f7
Cyrilvallez improve
eed63837
Cyrilvallez start making it more general
15485c30
Cyrilvallez Update masking_utils.py
ce4080b9
Cyrilvallez generate
039e4440
Cyrilvallez make it work with flex style primitives!
8b99dde4
Cyrilvallez Update masking_utils.py
14f01634
Cyrilvallez Update masking_utils.py
6bb87425
Cyrilvallez Update masking_utils.py
1c16acb5
Cyrilvallez improve
8f75abba
Cyrilvallez Update cache_utils.py
387cdb46
Cyrilvallez Update masking_utils.py
346bfa99
Cyrilvallez simplify - starting to look good!
b826d915
Cyrilvallez Update masking_utils.py
83f20d37
Cyrilvallez name
6fa1d35e
Cyrilvallez Update masking_utils.py
05777fbd
Cyrilvallez style
6ab437e9
Cyrilvallez Update masking_utils.py
1d6c9009
Cyrilvallez Update masking_utils.py
b5e1ebd2
Cyrilvallez Update masking_utils.py
26a428cc
Cyrilvallez Update masking_utils.py
71162e49
Cyrilvallez small fix for flex
28d5a190
Cyrilvallez flex compile
85320339
Cyrilvallez FA2
f3c8e7cd
Cyrilvallez Update masking_utils.py
d0c0b40e
Cyrilvallez Escape for TGI/vLLM!
bbc6bece
Cyrilvallez Update masking_utils.py
ada16276
Cyrilvallez Update masking_utils.py
1620f1e2
Cyrilvallez Update masking_utils.py
8f3a2a03
Cyrilvallez General case without cache
d5b3285c
Cyrilvallez rename
ce227282
Cyrilvallez full test on llama4
7bf1352b
Cyrilvallez small fix for FA2 guard with chunk
05529fff
Cyrilvallez Update modeling_gemma2.py
5afd8983
Cyrilvallez post rebase cleanup
67bea3f4
Cyrilvallez FA2 supports static cache!
7bb501d1
Cyrilvallez Update modeling_flash_attention_utils.py
9bb62192
Cyrilvallez Update flex_attention.py
f4849ab0
Cyrilvallez Update masking_utils.py
44f9b657
Cyrilvallez Update masking_utils.py
dc52eb36
Cyrilvallez Update utils.py
d2f645db
Cyrilvallez override for export
07bd06c3
Cyrilvallez Update executorch.py
f6735f2d
Cyrilvallez Update executorch.py
f2d8a543
Cyrilvallez Update executorch.py
ee0afddb
Cyrilvallez Update executorch.py
73549e57
Cyrilvallez Update masking_utils.py
7031bc40
Cyrilvallez Update masking_utils.py
552b586e
Cyrilvallez output attentions
d2fb4de2
Cyrilvallez style
628dcd89
Cyrilvallez Update masking_utils.py
27fb93f0
Cyrilvallez Update executorch.py
a0915177
Cyrilvallez Add docstring
cbf1144f
Cyrilvallez Add license and put mask visualizer at the end
59eb3cc4
Cyrilvallez Update test_modeling_common.py
85ab5da6
Cyrilvallez fix broken test
daf5bee7
Cyrilvallez Update test_modeling_gemma.py
201da65c
Cyrilvallez Update test_modeling_gemma2.py
f06c7cd2
Cyrilvallez fixes
f3906750
Cyrilvallez make it better
e26eb845
Cyrilvallez generalize to other test models
bca422c3
Cyrilvallez fix
a66697e1
Cyrilvallez Update masking_utils.py
7ea8db75
Cyrilvallez fix
1d3751fb
Cyrilvallez do not check mask equivalence if layer types are different
df439179
Cyrilvallez executorch
095746a4
Cyrilvallez Update modeling_gemma2.py
770422c5
Cyrilvallez Update masking_utils.py
0b5a8176
Cyrilvallez use layer_idx instead
cf5212c9
Cyrilvallez adjust
e28d6638
Cyrilvallez Update masking_utils.py
53e9f472
Cyrilvallez test
8e2bdd11
Cyrilvallez fix imports
558c47e0
Cyrilvallez Update modeling_gemma2.py
df49780a
Cyrilvallez other test models
a87f7ddb
Cyrilvallez Update modeling_llama4.py
8426b341
Cyrilvallez Update masking_utils.py
413d446f
Cyrilvallez improve
7f0f9898
Cyrilvallez simplify
3ed17a2c
Cyrilvallez Update masking_utils.py
f23236d7
Cyrilvallez typos
0ffff1da
Cyrilvallez typo
09d32dfd
Cyrilvallez fix
e20ebabc
Cyrilvallez Update masking_utils.py
d273325e
Cyrilvallez default DynamicCache
5ae049cd
Cyrilvallez remove default cache
326bacf3
Cyrilvallez simplify
d58eaab7
Cyrilvallez Update masking_utils.py
02a91802
Cyrilvallez Update masking_utils.py
d67de199
Cyrilvallez Update masking_utils.py
3831ccc8
Cyrilvallez Update masking_utils.py
4b54f187
Cyrilvallez simplify
6edf1166
Cyrilvallez Update masking_utils.py
18614a59
Cyrilvallez Update masking_utils.py
bd931a07
Cyrilvallez Update masking_utils.py
93f8d82c
Cyrilvallez export
711ab9b1
Cyrilvallez Update executorch.py
58f198e9
Cyrilvallez Update executorch.py
9c69ae5e
Cyrilvallez Update flex_attention.py
4e405160
Cyrilvallez Update executorch.py
6a28a34a
Cyrilvallez upstream to modular gemma 1 & 2
c70bf3c6
Cyrilvallez Update modular_mistral.py
3a972d47
Cyrilvallez switch names
7ca132df
Cyrilvallez use dict
34a55c51
Cyrilvallez put it in the Layer directly
5c89d722
Cyrilvallez update copy model source for mask functions
e6891b63
Cyrilvallez apply so many modular (hopefully 1 shot)
ac021700
Cyrilvallez use explicit dicts to make style happy
59e11ab6
Cyrilvallez protect import
27041e08
Cyrilvallez check docstring
0cf18e21
Cyrilvallez better default in hybrid caches
47158dfd
Cyrilvallez qwens
022c4a9b
Cyrilvallez Update modular_qwen2.py
94896dc9
Cyrilvallez simplify core logic!
9bbe1cbd
Cyrilvallez Update executorch.py
0844a491
Cyrilvallez qwen3 moe
dbbecde5
Cyrilvallez Update masking_utils.py
a3502631
Cyrilvallez Update masking_utils.py
09b01489
Cyrilvallez simplify a lot sdpa causal skip
fcd21a40
Cyrilvallez Update masking_utils.py
8cb637f7
Cyrilvallez post-rebase
481f0862
Cyrilvallez gemma3 finally
91c87f8b
Cyrilvallez style
9bda8648
Cyrilvallez check it before
d24309f2
Cyrilvallez gemma3
8e153a14
Cyrilvallez More general with newer torch
ebc7f9d0
Cyrilvallez align gemma3
31008ba8
Cyrilvallez Update utils.py
3c385ea6
Cyrilvallez Update utils.py
b206cd5a
Cyrilvallez Update masking_utils.py
b0850bfa
Cyrilvallez Update test_modeling_common.py
79eac774
Cyrilvallez Update flex_attention.py
29a6bc25
Cyrilvallez Update flex_attention.py
bb2dda0a
Cyrilvallez Update flex_attention.py
1b85bbbc
Cyrilvallez test
f76df19b
Cyrilvallez executorch
3ff39084
Cyrilvallez Update test_modeling_common.py
fd8a6a21
Cyrilvallez Update masking_utils.py
84db8eea
Cyrilvallez Update masking_utils.py
83ba79f9
Cyrilvallez Update masking_utils.py
acbe4be9
Cyrilvallez Update masking_utils.py
b0333def
Cyrilvallez Update executorch.py
3c483348
Cyrilvallez Update test_modeling_common.py
cfd06948
Cyrilvallez fix copies
01810426
Cyrilvallez device
ad5fb366
Cyrilvallez sdpa can be used without mask -> pass the torchscript tests in this case
b477c1ef
Cyrilvallez Use enum for check
3b71b7bc
Cyrilvallez revert enum and add check instead
1a05ca1c
Cyrilvallez remove broken test
2029cfa2
Cyrilvallez cohere2
28d62da9
Cyrilvallez some doc & reorganize the Interface
9d7bd3a4
Cyrilvallez Update tensor_parallel.py
343ab956
Cyrilvallez force pushed from ffdd1424 to 343ab956 1 year ago
Cyrilvallez Update tensor_parallel.py
78a21ea3
Cyrilvallez doc and dummy
4c87caa7
Cyrilvallez Update test_modeling_paligemma2.py
1f21213f
Cyrilvallez Update modeling_falcon_h1.py
e3530673
Cyrilvallez Update masking_utils.py
7979ac66
Cyrilvallez executorch patch
ba6501c5
Cyrilvallez style
269969ef
Cyrilvallez CIs
75ccf7ab
Cyrilvallez use register in executorch
7bcd55f6
ArthurZucker
ArthurZucker approved these changes on 2025-05-22
Cyrilvallez final comments!
9245fcda
Cyrilvallez merged 163138a9 into main 1 year ago
Cyrilvallez deleted the refactor-mask branch 1 year ago
BenjaminBossan
guangy10
guangy10 commented on 2025-06-02
kimishpatel
Cyrilvallez
kimishpatel
guangy10
