direct copy from llama work
3f74b1ff
mistral modules forward pass working
b194fb96
flax mistral forward pass with sliding window
6126bcd5
added tests
09717fd4
added layer collection approach
0e2905bf
Revert "added layer collection approach"
fb17b618
Revert "Revert "added layer collection approach""
41ed9a99
fixed attention outputs
89a0fd79
added mistral to init and auto
beca7bef
fixed import name
299c07aa
fixed layernorm weight dtype
d7ced4de
freeze initialized weights
2985a61a
make sure conversion considers bfloat16
f39a798a
added backend
47a5311a
added docstrings
93cfc3a1
added cache
948cc06d
fixed sliding window causal mask
46d33ada
passes cache tests
6ebac9a7
passed all tests
33024510
applied make style
9307ba8a
removed commented out code
4788867a
applied fix-copies ignored other model changes
9a555316
Merge branch 'huggingface:main' into flax-mistral
c0b4429f
applied make fix-copies
a5c3fa4b
removed unused functions
e3f10784
passed generation integration test
69b223af
slow tests pass
3a484781
fixed slow tests
3ed1ab8a
changed default dtype from jax.numpy.float32 to float32 for docstring…
ebf50bb6
skip cache test for FlaxMistralForSequenceClassification since if pa…
8d4b56a2
updated checkpoint since from_pt not included
acf5a96e
applied black style
876c49a2
removed unused args
d9fdd15c
Merge branch 'main' into flax-mistral
7867646c
Merge branch 'main' into flax-mistral
bc9345a8
Applied styling and fixup
8d409005
changed checkpoint for doc back
60fdad7f
fixed rf after adding it to hf hub
dac618a1
Add dummy ckpt
71671d62
applied styling
c5693407
added tokenizer to new ckpt
b0ef5a14
fixed slice format
5e31d5da
Merge branch 'main' into flax-mistral
d376f6a8
fix init and slice
0f87db4b
changed ref for placeholder TODO
d018f6ef
Merge branch 'main' into flax-mistral
c415cd90
added copies from Llama
73acb3c6
applied styling
5d0a6792
Merge branch 'main' into flax-mistral
2d9211ed
applied fix-copies
a3ce45cc
fixed docs
4a218a38
Merge branch 'main' into flax-mistral
dd147599
Merge branch 'main' into flax-mistral
9c42f95b
update weight dtype reconversion for sharded weights
edd9cc68
removed Nullable input ids
c2950d94
Removed unnecessary output attentions in Module
471e3e4a
added embedding weight initialization
3aaa0144
removed unused past_key_values
e33327ce
fixed deterministic
1e00d305
Fixed RMS Norm and added copied from
5bef1d24
removed input_embeds
5b2d914a
applied make style
adcac1c1
removed nullable input ids from sequence classification model
a5a6d705
added copied from GPTJ
85d282a2
added copied from Llama on FlaxMistralDecoderLayer
c1758cb4
added copied from to FlaxMistralPreTrainedModel methods
05d62d08
fix test deprecation warning
a2c28085
freeze gpt neox random_params and fix copies
ca00fabf
applied make style
0ba0feaa
fixed doc issue
535ef004
skipped docstring test to align # copied from
faac78c8
Merge branch 'main' into flax-mistral
8c34572b
applied make style
9b028d28
removed FlaxMistralForSequenceClassification
212cf5d7
removed unused padding_idx
a1d20c8e
removed more sequence classification
432db636
removed sequence classification
3b1d8c7d
applied styling and consistency
2b11ce8d
Merge branch 'main' into flax-mistral
72ac552f
added copied from in tests
23d1289b
removed sequence classification test logic
df023d8c
Merge branch 'main' into flax-mistral
977690ea
applied styling
f794296b
Merge branch 'main' into flax-mistral
28e77c1a
applied make style
e5729775
removed freeze and fixed copies
ff103d07
undo test change
80bce8db
changed repeat_kv to tile
6281c606
fixed to key value groups
c278516d
updated copyright year
67d71a05
split causal_mask
df76af39
empty to rerun failed pt_flax_equivalence test FlaxWav2Vec2ModelTest
88e86c6e
went back to 2023 for tests_pr_documentation_tests
5caed6b5
went back to 2024
7764c12a
changed tile to repeat
501cc222
Merge branch 'main' into flax-mistral
9d46eebe
applied make style
ed4461fd
empty for retry on Wav2Vec2
ab28806a