llama: implement YaRN RoPE scaling #2268

cebtenzzre merged 36 commits into ggml-org:master from cebtenzzre:ntkv2
cebtenzzre
cebtenzzre1 year ago (edited 1 year ago)πŸ‘ 9πŸš€ 5

This is an implementation of YaRN RoPE scaling. See https://github.com/jquesnelle/yarn and the paper and errata.

TODO:

  • Add new GGUF key for how much context the base model was trained on
  • Support converting the new models to GGUF
  • Add backward implementations
  • Test new LLaMA implementation
  • Finish and test Falcon implementation
cebtenzzre cebtenzzre force pushed from e78801a5 to 3b122de8 1 year ago
cebtenzzre cebtenzzre force pushed from 3b122de8 to ce59171f 1 year ago
cebtenzzre cebtenzzre force pushed from ce59171f to f3b9eae4 1 year ago
cebtenzzre cebtenzzre changed the title from "llama: implement NTK-By-Parts (NTKv2)" to "llama: implement NTK-By-Parts (NTKv2) RoPE scaling" 1 year ago
cebtenzzre cebtenzzre force pushed from f3b9eae4 to f30c5710 1 year ago
FNsi
FNsi1 year ago (edited 1 year ago)πŸ‘€ 1

Is there any guide for setting the extrapolation and NTK parameters? How do they interact with the previous two RoPE parameters?

cebtenzzre
cebtenzzre1 year agoπŸ‘ 4

The upstream NTKv2 doesn't use --rope-freq-base, so it probably doesn't make sense to use it here. It does use --rope-freq-scale, which works like linear scaling and is supposed to be calibrated so that e.g. a .25 scale actually gives you 8192 context. To use the default NTKv2, set --rope-ntk-factor and --rope-extrapolation-factor to 1, and set --rope-freq-scale appropriately. The lower the factors are, the less the respective scaling methods are mixed in; I believe the graphs were generated with both at 100%. The code automatically ramps them based on experimentally determined thresholds.
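For intuition, here is a rough sketch (illustrative pseudocode in the spirit of the reference NTKv2, not the exact code in this PR) of how the two factors blend the scaling methods per RoPE dimension; ramp_ntk / ramp_ext stand for the 0..1 ramps derived from those thresholds:

```cpp
// Illustrative sketch only: ntk_factor and ext_factor act as per-dimension mix
// weights that blend the NTK-scaled and unscaled ("extrapolated") angles into
// the plain linearly interpolated RoPE angle.
float ntkv2_theta(
    float theta_extrap,  // unscaled RoPE angle for this dimension and position
    float theta_ntk,     // the same angle computed with an NTK-adjusted base
    float freq_scale,    // linear scale, e.g. 0.25 for 4x context
    float ramp_ntk,      // 0..1: how strongly NTK scaling applies to this dimension
    float ramp_ext,      // 0..1: how strongly this dimension is left unscaled
    float ntk_factor,
    float ext_factor)
{
    float theta = freq_scale * theta_extrap;  // plain linear interpolation
    const float mix_ntk = ramp_ntk * ntk_factor;
    const float mix_ext = ramp_ext * ext_factor;
    theta = theta * (1.0f - mix_ntk) + theta_ntk    * mix_ntk;  // blend in NTK-by-parts
    theta = theta * (1.0f - mix_ext) + theta_extrap * mix_ext;  // blend in extrapolation
    return theta;
}
```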

cebtenzzre cebtenzzre force pushed from c62b01b8 to fe2413c2 1 year ago
cebtenzzre cebtenzzre marked this pull request as ready for review 1 year ago
cebtenzzre
cebtenzzre1 year ago

I would appreciate help with the following:

  • Should I try to write a backwards implementation? NTKv1 still doesn't have one, so I don't have much to base it on.
  • I don't have a Mac to test the Metal code on. If anyone sees obvious flaws or can test it locally, let me know.
  • I'm going to try to run a perplexity benchmark against NTKv1 and linear scaling, but I don't know if my current hardware is up to the task.
ggerganov
ggerganov commented on 2023-07-22
ggerganov1 year ago

Rename extrapolation_factor to ext_factor everywhere

ggml.c
#ifndef __clang__
#pragma GCC pop_options
#endif
#endif
ggerganov1 year ago

Don't think we need minf and maxf. Use the existing MIN and MAX

cebtenzzre1 year ago

Yeah, probably premature optimization on my part. I was concerned about adding so many branches to a loop that was previously quite simple, but I haven't done any actual benchmarks.

ggerganov
ggerganov1 year ago

No need for backwards implementation for now

cebtenzzre llama: implement NTK-By-Parts (NTKv2) RoPE scaling
8dec38c3
cebtenzzre CUDA implementation
6aeb46b3
cebtenzzre Metal implementation
9348aa4d
cebtenzzre cebtenzzre force pushed from 2b610014 to 9348aa4d 1 year ago
cebtenzzre
cebtenzzre1 year ago (edited 1 year ago)

Base command (with LLAMA_CUBLAS=1): ./perplexity -m llama-7b.ggmlv3.q4_0.bin -f wiki.test.raw -ngl 100 -mmq -c 8192

Perplexity results on WikiText-2:

| Method | Arguments | Perplexity |
| --- | --- | --- |
| Linear | --rope-freq-scale .25 | 10.39 |
| NTKv1 | --rope-freq-scale 0.75 --rope-freq-base 57200 | 7.03 |
| NTKv2 | --rope-freq-scale .25 --rope-ntk-factor 1 --rope-ext-factor 1 | 9.24 |

One problem that remains is that the context size the model was originally trained on is hardcoded to 2048. I could either add a parameter for it, or wait for GGUF.

cebtenzzre
cebtenzzre1 year ago

Perplexity with NTKv2 may be worse because neither implementation here is the dynamic version, which AFAIK works better on non-finetuned models. But fine-tuned models are far superior anyway.

NTKv1 does not converge when fine-tuning, which is why NTKv2 exists. So until somebody publishes a model fine-tuned with NTKv2 (maybe LLongMAv2 will be released after jquesnelle publishes the paper based on scaled-rope), the existing LLongMA, which uses regular linear interpolation (just like SuperHOT), is the state of the art for long contexts.

cebtenzzre
cebtenzzre1 year ago (edited 1 year ago)πŸŽ‰ 8πŸ‘€ 6

The paper has been released. The resulting method is called YaRN. Apparently the models that use this technique are good to about 120k tokens of context.
[screenshot omitted]

More work will definitely be needed to use these models with llama.cpp.

cebtenzzre implement new YaRN algorithm
a30ae209
cebtenzzre cebtenzzre changed the title from "llama: implement NTK-By-Parts (NTKv2) RoPE scaling" to "llama: implement YaRN RoPE scaling" 1 year ago
cebtenzzre
cebtenzzre1 year agoπŸ‘€ 1

There are NaNs getting in somewhere:

llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size =   72.03 MB
main: ggml.c:12228: ggml_compute_forward_soft_max_f32: Assertion `!isnan(sp[i])' failed.
bloc97
bloc971 year agoπŸ‘ 1

Thank you for the llamacpp implementation of YaRN!

I'm just letting you know that

constant float max_pos_emb = 2048;

should be changed to 4096 for llama 2 models when using YaRN (default was 2048 because we did the most tests with llama 1 models)
This value should probably be saved inside of the model configs and be loaded on inference...

cebtenzzre
cebtenzzre1 year agoπŸ‘ 1

should be changed to 4096 for llama 2 models

Thanks for reminding me. I originally made this PR before GGUF was finished, so I hardcoded it in the meantime. I believe I can now use the value of llama.context_length for this purpose.

KerfuffleV2
KerfuffleV21 year ago

Would it be worth testing this with non-YaRN fine-tuned models? If so, any suggested settings? I can test it with ROCM.

Green-Sky
Green-Sky1 year ago (edited 1 year ago)πŸ‘ 2

Thank you for the llamacpp implementation of YaRN!

I'm just letting you know that

constant float max_pos_emb = 2048;

should be changed to 4096 for llama 2 models when using YaRN (default was 2048 because we did the most tests with llama 1 models) This value should probably be saved inside of the model configs and be loaded on inference...

this needs to be a new GGUF kv, something like "rope_yarn_orig_ctx"

Thanks for reminding me. I originally made this PR before GGUF was finished, so I hardcoded it in the meantime. I believe I can now use the value of llama.context_length for this purpose.

llama.context_length should be the size of the finetune, e.g. 128Ki

cebtenzzre cebtenzzre marked this pull request as draft 1 year ago
bloc97
bloc971 year ago (edited 1 year ago)πŸ‘ 7

this needs to be a new GGUF kv, something like "rope_yarn_orig_ctx"

Exactly. After finetuning a model with YaRN, we have to keep track of two values: the original context length (2048 for LLaMA or 4096 for Llama 2) and the final context length (which can be calculated by multiplying the original context length by the scale factor, e.g. 4096 x 32 = 128Ki).

In this case, the max_pos_emb constant used in the equations must be equal to the original context size, not the final context size.

cebtenzzre Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
b5ced4fb
cebtenzzre ggml : increase GGML_MAX_OP_PARAMS
826269ad
cebtenzzre YaRN : avoid NaN if unused betas are zero
cf731d56
cebtenzzre YaRN : fix missing parameter in CUDA impl
dcb058ce
cebtenzzre convert : reduce unnecessary variables in Params
281b26e6
cebtenzzre Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
a06c7292
cebtenzzre llama : simplify use of context params
dc26a0dd
cebtenzzre llama : store YaRN parameters in GGUF
904d4edf
cebtenzzre fix convert scripts
56abb9a4
cebtenzzre cebtenzzre force pushed from f3c213a8 to 56abb9a4 1 year ago
cebtenzzre llama : fix C compatibility
43eaf06a
cebtenzzre don't hardcode max_pos_emb
fe788c45
cebtenzzre cebtenzzre marked this pull request as ready for review 1 year ago
Green-Sky
Green-Sky commented on 2023-09-21
Green-Sky1 year ago

downloading the 7b 128k model rn
will test later

common/common.cpp
            break;
        }
        std::string value(argv[i]);
        /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
Green-Sky1 year ago

does "none" make sense? isn't it just "linear" with a factor of 1.0?

cebtenzzre1 year ago (edited 1 year ago)

I thought it would make sense to provide a simple way to disable the RoPE scaling if the GGUF has it enabled. In general, I'm trying to make the GGUF keys and command line options correlate 1:1 with the HF config.json, for simplicity.

Green-Sky1 year ago

alright, sounds reasonable.
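For reference, the scaling types behind --rope-scaling map onto an enum along these lines (a sketch; the exact definition lives in llama.h):

```cpp
// Sketch of the rope-scaling-type enum behind --rope-scaling (illustrative).
// UNSPECIFIED lets the model's GGUF metadata choose, while NONE force-disables
// scaling even if the GGUF requests it, which is the case discussed above.
enum llama_rope_scaling_type {
    LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
    LLAMA_ROPE_SCALING_NONE        =  0,
    LLAMA_ROPE_SCALING_LINEAR      =  1,
    LLAMA_ROPE_SCALING_YARN        =  2,
};
```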

common/common.cpp
     printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
-    printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
+    printf(" --rope-scaling {none,linear,yarn}\n");
+    printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
+    printf(" --rope-scale N RoPE context scaling factor, inverse of --rope-freq-scale\n");
Green-Sky1 year ago

I came across this a while back: just saying this option is the "inverse" of another option, which in turn says it is the "inverse" of this one, gives you effectively no information.
Something like "larger than 1.0 to elongate RoPE for larger context sizes" on one of the two options would be gold.

common/common.cpp
    printf(" --yarn-ext-factor N YaRN extrapolation mix factor (default: %.1f)\n", params.yarn_ext_factor);
    printf(" --yarn-attn-factor N YaRN magnitude scaling factor (default: %.1f)\n", params.yarn_attn_factor);
    printf(" --yarn-beta-fast N YaRN low correction dim (default: %.1f)\n", params.yarn_beta_fast);
    printf(" --yarn-beta-slow N YaRN high correction dim (default: %.1f)\n", params.yarn_beta_slow);
Green-Sky1 year ago

Depending on how close the parameter names are to the paper, we might want to add a README with more details.

cebtenzzre1 year ago (edited 1 year ago)

I don't believe the extrapolation mix factor is described in the paper because it describes an implementation that assumes it is always $1$. attn_factor is not described in the paper, but it multiplies the value described as $\sqrt t$. The slow and fast beta are known as $\alpha$ and $\beta$ in the paper, respectively.
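To connect the option names to the paper (hedged from the paper and the reference implementation): the betas are rotation counts that get converted into the RoPE dimension indices delimiting the correction range, and attn_factor multiplies the $\sqrt{t}$ magnitude correction:

$$d(\beta) = \frac{|D|\,\ln\!\bigl(L_{\text{orig}}/(2\pi\beta)\bigr)}{2\ln b}, \qquad \sqrt{t} = 0.1\,\ln s + 1$$

where $|D|$ is the number of RoPE dimensions, $b$ is the frequency base, $L_{\text{orig}}$ is the original training context, and $s$ is the context scale factor.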

convert.py
    n_head = (n_head := config["num_attention_heads"]),
    n_head_kv = config.get("num_key_value_heads", n_head),
    f_norm_eps = config["rms_norm_eps"],
    f_rope_freq_base = config.get("rope_theta"),
Green-Sky1 year ago

why the different syntax for different values? (i am not a real python dev)

cebtenzzre1 year agoπŸ‘ 1

dict.get returns the specified fallback value (or None if no fallback is given) when the key is missing. The square brackets raise KeyError instead.

ggml.c

// ggml_compute_forward_rope

static inline float rope_yarn_ramp(const float low, const float high, const int i0) {
Green-Sky1 year ago

The only thing inline does here is suppress the warning if the function is unused.
Not sure that is desired...
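For context, the helper under discussion is just a clamped linear ramp over the dimension index; roughly (a sketch, modulo the MIN/MAX macros used in ggml.c):

```cpp
#include <algorithm>

// Clamped linear ramp: returns 1 for dimensions below the correction range
// [low, high] (derived from beta_fast/beta_slow) and 0 for dimensions above it.
static float rope_yarn_ramp(float low, float high, int i0) {
    const float y = (i0 / 2 - low) / std::max(0.001f, high - low);
    return 1.0f - std::min(1.0f, std::max(0.0f, y));
}
```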

Green-Sky
Green-Sky commented on 2023-09-21
convert.py
    elif typ == "yarn":
        rope_scaling_type = gguf.RopeScalingType.YARN
        n_orig_ctx = rope_scaling['original_max_position_embeddings']
        rope_finetuned = rope_scaling['finetuned']
Green-Sky1 year ago

assume true if original_max_position_embeddings exists?

cebtenzzre1 year agoπŸ‘€ 1

But in theory, someone could create a non-finetuned model that uses some variant of RoPE scaling by default. Maybe instead of "no" the status should be "unknown"?

Green-Sky1 year ago

non-finetuned model that uses some variant of RoPE scaling by default.

but would there then be an "original" ctx size?

cebtenzzre1 year ago

The original context size would be whatever the model usually uses, and the new context size would be the maximum it chooses to support via YaRN. Finetuning is not mandatory; YaRN is the best context extension method either way.

Green-Sky1 year ago

ok, then why do we look at the value anyway?

cebtenzzre1 year ago

I think it's useful information for the user (it's printed when the model is loaded), since finetuned models perform better, but a model doesn't have to be finetuned to specify YaRN.

Green-Sky1 year ago

ah, yea sure, i thought it was used for a calculation.

cebtenzzre address review comments
e0b120c3
cebtenzzre
cebtenzzre commented on 2023-09-21
gguf-py/gguf/gguf.py
KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
KEY_ROPE_SCALING_TYPE = "{arch}.rope.scaling.type"
KEY_ROPE_SCALING_FACTOR = "{arch}.rope.scaling.factor"
cebtenzzre1 year ago (edited 1 year ago)

The removal of scale_linear is a breaking change. I suppose I should at least implement backwards compatibility (edit: done). Should there be a deprecation notice?

Green-Sky1 year ago

did you request a change in the spec in the ggml repo?

cebtenzzre restore backwards compatiblity with *.rope.scale_linear
19bb74e7
cebtenzzre better option descriptions in help
4d5fe734
cebtenzzre gguf : store scaling type as a string instead of an int
74664157
cebtenzzre improve printing of YaRN parameters
4f4e9480
cebtenzzre allow forcing ext_factor to zero if scaling type is YaRN
5d7a3a5c
cebtenzzre Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
9bd050f1
cebtenzzre
cebtenzzre1 year agoπŸ‘€ 1

How do we move forward with this? philpax never replied. The GGUF spec isn't technically official yet since it was never merged, right?

ggerganov
ggerganov1 year ago

We should keep this PR open for now. I don't see a huge benefit of merging this method atm, since I believe it is not adopted by popular foundation models, only used for fine-tuning (correct me if I'm wrong).

ggerganov ggerganov added demo
cebtenzzre
cebtenzzre1 year agoπŸ‘ 2

I don't see a huge benefit of merging this method atm

YaRN also claims to be the state-of-the-art for context scaling without finetuning. I can put together some updated perplexity numbers.

cebtenzzre fix rope_cuda parameter order
babf0e0c
cebtenzzre default n_yarn_orig_ctx to n_ctx_train
0050e1ec
cebtenzzre fix uninitialized cparams
09c31027
cebtenzzre cebtenzzre force pushed from cd9d17c5 to 09c31027 1 year ago
cebtenzzre make printed param formatting more consistent
57c3442e
cebtenzzre
cebtenzzre1 year ago (edited 1 year ago)πŸ‘€ 1

This is how it currently performs on WikiText-2 with 16K context and a 7B LLaMA-v1 model:

| Method | Arguments | Perplexity |
| --- | --- | --- |
| Linear | --rope-freq-scale .125 | 116.9527 +/- 0.83609 |
| Linear + NTK | --rope-freq-scale .375 --rope-freq-base 57200 | 35.0422 +/- 0.23598 |
| YaRN | --rope-freq-scale .125 --rope-scaling yarn | 86.4788 +/- 0.67614 |

@bloc97 Does this seem right to you?

cebtenzzre fix missing import
a20b3e6c
bloc97
bloc971 year ago (edited 1 year ago)πŸ‘€ 2

This is how it currently performs on WikiText-2 with 16K context and a 7B LLaMA-v1 model:

| Method | Arguments | Perplexity |
| --- | --- | --- |
| Linear | --rope-freq-scale .125 | 116.9527 +/- 0.83609 |
| NTK | --rope-freq-scale .375 --rope-freq-base 57200 | 35.0422 +/- 0.23598 |
| YaRN | --rope-freq-scale .125 --rope-scaling yarn | 86.4788 +/- 0.67614 |
@bloc97 Does this seem right to you?

I am not so sure. All my tests with NTK-aware scaling were based on the scale factor (alpha); I have never tested changing freq-scale and freq-base at the same time. With equal hyperparameters, YaRN should outperform everything else by a significant margin.

YaRN without finetuning at 2x scaling has almost zero PPL degradation. This means that, for example, if you give Llama 2 a prompt of 4k and obtain a PPL of 4.21, then apply 2x YaRN scaling to the model (effective context of 8k) and give it the same 4k prompt, the PPL should be around 4.23. If you give it an 8k prompt, PPL would decrease to, let's say, 3.91.
In my tests, all other interpolation methods show significant PPL degradation at all scaling factors.

cebtenzzre
cebtenzzre1 year ago (edited 1 year ago)πŸ‘€ 1

WikiText-2 with 6144 context and a 7B LLaMA-v1 model, trying pure NTK without linear scaling:

| Method | Arguments | Perplexity |
| --- | --- | --- |
| Linear | --rope-freq-scale .333 | 7.2696 +/- 0.03999 |
| NTK | --rope-freq-base 60000 | 6.1330 +/- 0.03285 |
| YaRN | --rope-freq-scale .333 --rope-scaling yarn | 6.1653 +/- 0.03305 |

Still not better than NTK, but acceptable I guess. My implementation is probably correct.

bloc97
bloc971 year ago (edited 1 year ago)πŸ‘€ 1

WikiText-2 with 16K context and a 7B LLaMA-v1 model, trying pure NTK without linear scaling:

| Method | Arguments | Perplexity |
| --- | --- | --- |
| Linear | --rope-freq-scale .333 | 7.2696 +/- 0.03999 |
| NTK | --rope-freq-base 60000 | 6.1330 +/- 0.03285 |
| YaRN | --rope-freq-scale .333 --rope-scaling yarn | 6.1653 +/- 0.03305 |
Still not better than NTK, but acceptable I guess. My implementation is probably correct.

It's hard for me to compare against these numbers, as this is 3x context extension, right? That means 2k * 3 = 6k extension. But the tests are on 16k context without sliding windows, right? Meanwhile, all of our tests were done using sliding windows (from Press et al. 2022), such that the sum was performed at the minimal-PPL point for all models and all hyperparameters. The two methods are measuring different things (even if the PPL metric is the same).

cebtenzzre
cebtenzzre1 year agoπŸ‘€ 1

But the tests are on 16k context without sliding windows

Sorry, I made a copy-paste error. This is regular 6144-context, non-sliding-window perplexity. Sliding window perplexity is implemented here but AFAIK is currently very slow.

bloc97
bloc971 year agoπŸ‘ 1πŸ‘€ 3

No worries. My point was that as long as the implemented version of YaRN at 2x scaling doesn't negatively impact PPL (on a non-finetuned model) and successfully extends the context to 2x the original length, it's implemented correctly.

For example, this is Mistral 7B with YaRN at 2x scaling without any finetuning (sorry, it's the only model I had at hand when writing this), but you'll just have to trust me that Llama 2 has the exact same behaviour (with a 4k base model instead of 8k)... Note that sliding window attention was disabled for this test.

[perplexity plot omitted]
As you can see, both lines are literally overlapping (maybe +0.01 PPL for YaRN when the context is under 7k, but the context is extended for free).

cebtenzzre
cebtenzzre1 year agoπŸ‘ 2

Here is the 16K (8x context extension) test again, with another row added to the table:

| Method | Arguments | Perplexity |
| --- | --- | --- |
| Linear | --rope-freq-scale .125 | 116.9527 +/- 0.83609 |
| Linear + NTK | --rope-freq-scale .375 --rope-freq-base 57200 | 35.0422 +/- 0.23598 |
| YaRN | --rope-freq-scale .125 --rope-scaling yarn | 86.4788 +/- 0.67614 |
| YaRN + NTK | --rope-freq-scale .375 --rope-freq-base 57200 --rope-scaling yarn | 25.7769 +/- 0.17455 |

@ggerganov Without finetuning, YaRN behaves like linear scaling, but with better perplexity scores, especially with longer contexts. I think this is more than just a demo; I would really like to have this in master.

cebtenzzre Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
9ef91b13
ggerganov
ggerganov1 year ago

@cebtenzzre The PPL values (> 25.0) are far too high compared to the ~5-6 PPL at 2k context - I don't think this has any practical value if the numbers are correct. At 3x context scaling (6144 ctx size) there seems to be no benefit compared to the existing NTK implementation, or the proposed implementation is not correct. So far I don't see a compelling reason to merge this change.

cebtenzzre
cebtenzzre1 year agoπŸ‘ 3

I don't see a significant perplexity benefit using YaRN this way with 4-5x context extension, and perplexity does start to get high outside of that range. So I guess this doesn't make sense to merge until there are more finetuned models that use it.

| model | context | freq base | freq scale | linear ppl | YaRN ppl | improvement |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA-2 7B | 16384 | 57200 | 0.75 | 6.1875 +/- 0.03310 | 6.1324 +/- 0.03275 | 1.009 |
| LLaMA-1 7B | 8192 | 57200 | 0.75 | 7.1863 +/- 0.03970 | 7.0246 +/- 0.03864 | 1.023 |
| LLaMA-1 7B | 8192 | 57200 | 0.60 | 8.7159 +/- 0.04907 | 8.1126 +/- 0.04524 | 1.074 |
jquesnelle
jquesnelle1 year ago (edited 1 year ago)πŸš€ 4

There is a slight bug in this implementation that caused part of the YaRN scaling not to be applied (see cebtenzzre#1). When the fix is applied, the PPL improvements get a lot better

| model | context | freq base | freq scale | linear ppl | YaRN ppl | improvement |
| --- | --- | --- | --- | --- | --- | --- |
| Yarn-Llama-2-7B-64K | 16384 | 57200 | 0.75 | - | 5.4893 +/- 0.02914 | |
| LLaMA-2 7B | 16384 | 57200 | 0.75 | 6.1875 +/- 0.03310 | 5.8984 +/- 0.03145 | 1.049 |
| LLaMA-1 7B | 8192 | 57200 | 0.75 | 7.1863 +/- 0.03970 | 6.5401 +/- 0.03591 | 1.098 |
| LLaMA-1 7B | 8192 | 57200 | 0.60 | 8.7159 +/- 0.04907 | 7.0159 +/- 0.03899 | 1.242 |

I would note that the PR as it stands applies YaRN all the time. I think it will need to be adjusted so that the GGML code selects the appropriate scaling type based on the GGUF rope.scaling.type key. (edit: a very simple fix, I believe; just move the calculation of mscale into the ext_factor conditional above)
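A sketch of what that fix amounts to (names follow the rope_yarn helper in ggml.c; the exact change is in the linked PR):

```cpp
#include <cmath>

// rope_yarn_ramp is the clamped 1 -> 0 ramp shown earlier in this review.
float rope_yarn_ramp(float low, float high, int i0);

// Sketch of the per-dimension YaRN rotation with the fix described above: the
// extrapolation mix and the sqrt(t) magnitude correction (mscale) are applied
// only when ext_factor is non-zero, so plain linear RoPE is left untouched.
void rope_yarn(
    float theta_extrap, float freq_scale, const float corr_dims[2], int i0,
    float ext_factor, float mscale, float * cos_theta, float * sin_theta)
{
    float theta = freq_scale * theta_extrap;  // interpolated angle (linear scaling)
    if (ext_factor != 0.0f) {
        const float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta = theta * (1.0f - ramp_mix) + theta_extrap * ramp_mix;
        // magnitude correction: the paper's sqrt(t) = 0.1*ln(s) + 1, with s = 1/freq_scale
        mscale *= 1.0f + 0.1f * std::log(1.0f / freq_scale);
    }
    *cos_theta = std::cos(theta) * mscale;
    *sin_theta = std::sin(theta) * mscale;
}
```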

bloc97
bloc971 year ago (edited 1 year ago)πŸ‘ 1

Just wondering how we should handle changing freq_base and freq_scale in the context of YaRN... YaRN by itself already finds the optimal freq_base and freq_scale for each dimension individually; in some sense, it can be seen as an automatic and adaptive version of Linear + NTK interpolation applied to each RoPE dimension slightly differently, plus an mscale attention correction factor that further improves PPL.

IMHO when YaRN is enabled, both freq_base and freq_scale should be disabled as any changes will result in an inferior PPL compared to the default values.

To put it simply, all YaRN needs is the original model context length and the target context extension (the ratio $s$ is computed by dividing the two). The alpha, beta and mscale hyperparameters should be determined in advance for every model on a case-by-case basis and be hidden from the end user. The default alpha, beta and mscale in YaRN are currently only optimal for the LLaMA and Llama 2 families of models.
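Illustratively, in terms of the options used in this thread:

$$s = \frac{L_{\text{target}}}{L_{\text{orig}}}, \qquad \texttt{freq\_scale} = \frac{1}{s}$$

e.g. a model trained at 4096 tokens and extended to 65536 gives $s = 16$ and a freq scale of 0.0625, matching the example command further down in the thread.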

Also, there might be some other subtle implementation differences between huggingface transformers and llama.cpp, as the PPL improvements seen above are fairly minimal. YaRN should have significantly better PPL compared to other methods in non-finetuned scenarios.

The advantages of YaRN are threefold:

  • Fine-tuned YaRN models have much better PPL compared to fine-tuned Linear and NTK models at all context sizes.
  • Non-finetuned YaRN models' PPL is closer to that of a fine-tuned YaRN model than it is with other methods.
  • At small scales (<3x extension), non-finetuned YaRN is almost equivalent to fine-tuned YaRN (a difference of around +0.01 PPL), so in some sense fine-tuning is not necessary, as all other methods are inferior whether fine-tuned or not. Using Dynamic-YaRN removes all penalties, and a non-finetuned model can enjoy 3x context extension for free, without any PPL degradation.
bloc97
bloc971 year ago (edited 1 year ago)

Here are some benchmarks I got from huggingface transformers on the GovReport dataset. There might still be a small bug in this implementation somewhere, as I'm getting better PPL improvements across the board with the reference implementation.

Fixed s=4 YaRN:

| model | context | freq base | freq scale | scale-base ppl | YaRN ppl (s=4) | improvement |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA-1 7B | 8192 | 10000 | 0.25 | 6.4290 | 4.3511 | 1.477 |
| LLaMA-1 7B | 8192 | 74000 | 1.00 | 4.8251 | 4.3511 | 1.109 |
| LLaMA-1 7B | 8192 | 57200 | 0.75 | 4.9294 | 4.3511 | 1.133 |
| LLaMA-1 7B | 8192 | 57200 | 0.60 | 5.5975 | 4.3511 | 1.286 |

Dynamic-YaRN:

| model | context | freq base | freq scale | scale-base ppl | Dynamic YaRN ppl | improvement |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA-1 7B | 8192 | 10000 | 0.25 | 6.4290 | 4.1972 | 1.532 |
| LLaMA-1 7B | 8192 | 74000 | 1.00 | 4.8251 | 4.1972 | 1.150 |
| LLaMA-1 7B | 8192 | 57200 | 0.75 | 4.9294 | 4.1972 | 1.174 |
| LLaMA-1 7B | 8192 | 57200 | 0.60 | 5.5975 | 4.1972 | 1.334 |

[perplexity plot omitted]

bloc97
bloc971 year ago

At s=8, there's no contest... (Again, note these are reference benchmarks using huggingface transformers.)
[perplexity plot omitted]

Fixed s=8 YaRN:

| model | context | freq base | freq scale | scale-base ppl | YaRN ppl (s=8) | improvement |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA-1 7B | 16384 | 10000 | 0.125 | 45.066 | 4.5578 | 9.888 |
| LLaMA-1 7B | 16384 | 200000 | 1.00 | 7.7015 | 4.5578 | 1.690 |
| LLaMA-1 7B | 16384 | 120000 | 0.75 | 8.5140 | 4.5578 | 1.868 |
| LLaMA-1 7B | 16384 | 120000 | 0.60 | 10.523 | 4.5578 | 2.309 |

Dynamic-YaRN:

| model | context | freq base | freq scale | scale-base ppl | Dynamic YaRN ppl | improvement |
| --- | --- | --- | --- | --- | --- | --- |
| LLaMA-1 7B | 16384 | 10000 | 0.125 | 45.066 | 4.3574 | 10.34 |
| LLaMA-1 7B | 16384 | 200000 | 1.00 | 7.7015 | 4.3574 | 1.767 |
| LLaMA-1 7B | 16384 | 120000 | 0.75 | 8.5140 | 4.3574 | 1.954 |
| LLaMA-1 7B | 16384 | 120000 | 0.60 | 10.523 | 4.3574 | 2.415 |
jquesnelle Fix YaRN inverted scaling and add "rope.scaling.type" to GGUF (#1)
9ae10b3a
jquesnelle
jquesnelle1 year agoπŸ‘ 2❀ 1πŸš€ 4

Okay, I think I have some definitive answers now! There was a second bug in the implementation, but it looks like we have it all squared away now (cebtenzzre#2 awaiting merge).

There are two scenarios to consider: non-finetuned (extending any model) and finetuned (using a model trained with YaRN). With the updated code we're getting quite good results under both scenarios. All evals are done with Q4_0.

For reproducibility, here is an example command line (from "Finetuned" below): ./perplexity -m yarn-llama-2-7b-64k.Q4_0.gguf -f wiki.test.raw -ngl 100 -c 16384 --rope-scaling yarn --rope-freq-scale 0.0625 --yarn-orig-ctx 4096

Non-finetuned

Model: LLaMA 2 7B

Commit: jquesnelle@f51eed1

| ctx | base | scale | linear ppl | YaRN ppl | improvement |
| --- | --- | --- | --- | --- | --- |
| 16384 | 57200 | 0.6000 | 7.2699 +/- 0.03976 | 6.0714 +/- 0.03235 | 1.197 |
| 16384 | 57200 | 0.7500 | 6.1870 +/- 0.03310 | 5.8867 +/- 0.03129 | 1.051 |
| 16384 | 57200 | 1.0000 | 9.6980 +/- 0.06093 | 9.6980 +/- 0.06093 | 1.000 |
| 16384 | 10000 | 0.1250 | 54.4799 +/- 0.36304 | 6.5957 +/- 0.03584 | 8.300 |

Commit: f3b25e4

| ctx | base | scale | linear ppl | YaRN ppl |
| --- | --- | --- | --- | --- |
| 16384 | 57200 | 1.0000 | 9.6980 +/- 0.06093 | n/a |

Finetuned

Model: YaRN LLaMA 2 7B 64K

Commit: jquesnelle@f51eed1

| ctx | base | scale | YaRN ppl |
| --- | --- | --- | --- |
| 16384 | 10000 | 0.0625 | 5.1497 +/- 0.02717 |

In both scenarios, YaRN performs better than regular linear scaling. We additionally see that the YaRN code is equivalent to linear when the scale is 1 with or without a base change. Moreover, the perplexities match the existing code on master, meaning the changes are backward-compatible. Given this, I think it's a good candidate to merge πŸ™‚

jquesnelle fix YaRN ramp, make mscale conditional, add --yarn-orig-ctx (#2)
14cf93b1
cebtenzzre Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
237f1e79
cebtenzzre
cebtenzzre1 year agoπŸ‘ 1πŸ‘€ 1

IMHO when YaRN is enabled, both freq_base and freq_scale should be disabled as any changes will result in an inferior PPL compared to the default values.

@bloc97 I have --rope-freq-scale set up to configure the YaRN scale factor when "--rope-scaling yarn" is passed, which seemed simpler than making it separately configurable but mutually exclusive. And based on jquesnelle's results above, the perplexity with YaRN is 12% better with freq_base=57200 (5.8867) than with freq_base=10000 (6.5957), even after fixing the bugs in the implementation. So I'm not inclined to disable the --rope-freq-base option.

cebtenzzre Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
bc8395d5
bloc97
bloc971 year ago (edited 1 year ago)πŸ‘ 1πŸ‘€ 1

@bloc97 I have --rope-freq-scale set up to configure the YaRN scale factor when "--rope-scaling yarn" is passed, which seemed simpler than making it separately configurable but mutually exclusive. And based on jquesnelle's results above, the perplexity with YaRN is 12% better with freq_base=57200 (5.8867) than with freq_base=10000 (6.5957), even after fixing the bugs in the implementation. So I'm not inclined to disable the --rope-freq-base option.

I'll do a few more tests using the reference implementation in huggingface (to figure out whether that's a bug or actual behaviour), but I think since the finetuned YaRN models now work, we can go ahead with the merge and look at the remaining stuff as we go...

cebtenzzre Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
4d5ed834
cebtenzzre
cebtenzzre1 year ago

@ggerganov What are your thoughts on the current state of this PR?

ggerganov
ggerganov1 year agoπŸ‘ 2

The numbers look better; I think we can merge. Let me take one more look tomorrow and then proceed.

Perplexity aside, do we have studies showing that when using YaRN the model is still able to recover information from the entire context? I'm thinking, for example, that with context shift (a.k.a. StreamingLLM, a.k.a. the old context swap) we also get good perplexity for very long contexts, but the problem is that the model "forgets" stuff that goes out of scope. Does YaRN solve this?

cebtenzzre
cebtenzzre1 year ago

do we have studies that show when using YARN the model is still able to recover information from the entire context

There is a passkey test here but I don't know if the results of it were published anywhere.

bloc97
bloc971 year ago (edited 1 year ago)πŸ‘ 2

@ggerganov @cebtenzzre The 128k YaRN FTed models have 99.4% random passkey retrieval accuracy across their entire context size. This can be tested using the file that @cebtenzzre linked. Non-FTed YaRN also has relatively high passkey accuracies but I don't have the results on hand.

ggerganov
ggerganov approved these changes on 2023-10-28
ggerganov1 year agoπŸš€ 3

The implementation is very well done, so let's merge it. I'm worried we are adding quite a lot of extra code for this feature, but hopefully it will be useful.

Btw, I continue to be skeptical about the usefulness of YaRN. When extending the context size, I expect to see the PPL drop, not remain at the same level as with the original context.

The passkey test passing is OK, but I think that sliding-window processing using the original context size would also pass it. I would like to be proven wrong. But in any case, adding a passkey test example to llama.cpp would be useful in general.

ggerganov
ggerganov1 year ago

@cebtenzzre Are you planning to finish and test the Falcon implementation before merging?

bloc97
bloc971 year ago (edited 1 year ago)

Btw, I continue to be skeptical about the usefulness of YARN. I expect with extending the context size to see PPL drop and not remain at the same level as the original context.

PPL curves are always decreasing for finetuned YaRN models. The Llama 2 YaRN models are currently the only models that have the property of PPL always decreasing up to 128k context. Codellama has similar properties but much worse PPL, due to it using the older NTK-aware scaling and other confounding factors (it being a code-focused model).
Also, the plots shown above are for non-FTed LLaMA 1. In this scenario, YaRN is also a huge step up from previous methods.

The passkey test passing is OK, but I think that a sliding-window processing using the original context size would also pass it. Would like to be proven wrong. But in any case, adding a passkey test example to llama.cpp would be useful in general.

Sliding window (either on the prompt or on the attention logits) won't work for passkey at all unless you know exactly where the passkey was in advance in the prompt (for example using an oracle), which defeats the purpose of using a long context model in the first place. In any case, you can put three passkeys (one in the beginning, one in the middle and one in the end), and ask the model to retrieve all three at the same time.

cebtenzzre
cebtenzzre1 year ago

Are you planning to finish and test the Falcon implementation before merging?

Right now it's deactivated for Falcon. I'm looking into it, but I don't really understand the Metal implementation of GPT-NeoX RoPE (#3024 (comment)) - so I'm not 100% sure what to put in place of i0 (number of rotations) for each backend.

ggerganov
ggerganov1 year agoπŸ‘ 1πŸ‘€ 1

@bloc97 By sliding window I mean the strategy where the KV cache is shifted to evict old tokens and free space for new tokens. This is implemented in the main example of llama.cpp for infinite text generation:

https://github.com/ggerganov/llama.cpp/blob/3b778a4af967b2c576564a564521c2f0fe5704ed/examples/main/main.cpp#L490-L501

This strategy retains past information beyond the limit of the context size because new tokens attend to the KV of old tokens which have "seen" the evicted tokens. I just don't know to what extent this information is retained, but intuitively, the passkey test has such low entropy (because of the repeated text) that I wouldn't be surprised if the sliding window on the KV cache passed it. It's something that can be easily tested.
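Roughly, the linked snippet does the following (a simplified sketch, not the exact main.cpp code):

```cpp
#include <vector>

// Simplified context-shift sketch: when the token history reaches n_ctx, keep
// the first n_keep tokens, drop the oldest half of the rest, and re-evaluate the
// remaining tokens at their new positions so generation can continue indefinitely.
void context_shift(std::vector<int> & tokens, int & n_past, int n_ctx, int n_keep) {
    if ((int) tokens.size() < n_ctx) {
        return;  // still fits in the context window
    }
    const int n_left    = (int) tokens.size() - n_keep;
    const int n_discard = n_left / 2;

    // evict the oldest n_discard tokens that follow the keep region
    tokens.erase(tokens.begin() + n_keep, tokens.begin() + n_keep + n_discard);

    // everything after the keep region must be re-evaluated at its shifted position
    n_past = n_keep;
}
```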

bloc97
bloc971 year ago

This strategy retains past information beyond the limit of the context size because new tokens attend to the KV of old tokens which have "seen" the evicted tokens. I just don't know to what extend this information is retained, but intuitively, the passkey test has so small entropy (because of the repeated text) that I won't be surprised the sliding window on the KV cache to pass it. It's something that can be easily tested.

We have tested a similar strategy using the SWA implementation from Mistral 7B, and the passkey results drop to 0% after evicting the tokens from the KV cache, while PPL stays stable. It seems that the attention algorithm does not compress past information into the new tokens (because it was never trained to do so)...

ggerganov
ggerganov1 year agoπŸ‘ 1

Yup, I did the test as well just now and indeed it fails: #3856
So my expectation was not correct.

jquesnelle fix loading rope.scaling.original_context_length from GGUF (#3)
9fc82382
ggerganov
ggerganov1 year agoπŸš€ 3

With the recent refactoring, I've created quite a few conflicts (sorry about that).
After we resolve them, we can merge straight away.

cebtenzzre implement YaRN for GPT-NeoX RoPE
15f26efd
cebtenzzre Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
081f7381
cebtenzzre cebtenzzre merged 898aeca9 into master 1 year ago
cebtenzzre cebtenzzre deleted the ntkv2 branch 1 year ago
slaren
slaren1 year ago

llama_context_default_params() needs to be updated: the comments don't match the initializers, and there are missing initializers.

redthing1
redthing11 year ago

Based 🫑

cebtenzzre
cebtenzzre1 year ago

@ggerganov Could you go to Settings > Moderation options > Interaction limits and block the above user (Dezzj) from commenting? The spam is getting annoying.
https://docs.github.com/en/communities/moderating-comments-and-conversations/limiting-interactions-in-your-repository

LostRuins
LostRuins1 year ago (edited 1 year ago)

@cebtenzzre seems like this commit broke CI? The last 3 builds have failures on the v100.

[CI failure screenshot omitted]

redthing1
redthing11 year ago

I can confirm that this is working on my mac :)

cebtenzzre cebtenzzre restored the head branch 1 year ago
cebtenzzre
cebtenzzre1 year ago

seems like this commit broke CI? The last 3 builds have failures on the v100.

Erm... yeah, I can reproduce this locally. I think I screwed up something in 15f26ef. I thought I tested that change, but I must not have had CUDA enabled. I thought the PR's CI would have caught this...

cebtenzzre cebtenzzre deleted the ntkv2 branch 1 year ago
IridiumMaster
IridiumMaster1 year ago

This commit appears to have caused one of my Phind CodeLlama 34B 16K Q5 models to emit gibberish on a CUDA machine, but not on my Mac. I can provide a detailed reproduction if you want, or I can wait until after your fix.

ggerganov
ggerganov1 year ago

@IridiumMaster Can you confirm latest master is stable?

@cebtenzzre There is no per-user option, or at least I cannot find it

LostRuins
LostRuins1 year ago

I am integrating this commit and I don't know how to disable YaRN across all other modes.

Specifically, there are now 5 new arguments required for ggml_rope_custom_inplace. Zeroing all these new values out results in incoherent output. Only by using the values 0 NAN 1 32 1 am I able to get coherent output again - but I have no idea if that behavior matches how RoPE behaved previously. Does that mean YaRN RoPE scaling is disabled? What values should I use to get exactly the same behavior in other models as before this commit was merged? Or must the "disable" state match these values?

IridiumMaster
IridiumMaster1 year ago

@IridiumMaster Can you confirm latest master is stable?

@cebtenzzre There is no per-user option, or at least I cannot find it

Hi, the latest master is stable for me. I do not see the gibberish that I did before. In case you're wanting to reproduce in future, here are some steps:

  1. Grab this model: https://huggingface.co/TheBloke/Phind-CodeLlama-34B-v2-GGUF/resolve/main/phind-codellama-34b-v2.Q5_K_M.gguf
  2. Feed this to the model via the server api completion endpoint with a max of 100 predictions:
### System Prompt 
### User Message
What is the capital of Nebraska?
### Assistant
  3. If the model is working correctly, 99/100 times the correct answer 'Lincoln' will be a word within the text produced. In the case of the errant commit, a bunch of UTF-8 characters were produced instead.
cebtenzzre
cebtenzzre1 year agoπŸ‘ 1

I am integrating this commit and I don't know how to disable YaRN across all other modes.

Currently, attn_factor must be 1.0f, ext_factor must be 0.0f, and the rest don't matter but can be zero. Then YaRN should be fully disabled.
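In other words, something along these lines should reproduce the pre-YaRN behavior (an illustrative sketch; the names correspond to the new arguments):

```cpp
// Parameter values that make the YaRN path a no-op (plain RoPE), per the answer
// above; pass them in the corresponding ggml_rope_custom_inplace slots. With
// ext_factor == 0 the extrapolation mix and the sqrt(t) correction are skipped,
// so the remaining YaRN parameters have no effect.
const float ext_factor  = 0.0f;   // disables the YaRN mixing entirely
const float attn_factor = 1.0f;   // leaves the attention magnitude unscaled
const float beta_fast   = 32.0f;  // ignored when ext_factor == 0
const float beta_slow   = 1.0f;   // ignored when ext_factor == 0
const int   n_orig_ctx  = 0;      // ignored when ext_factor == 0
```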

jxy
jxy1 year ago

After this merge, every single model I tried failed to produce meaningful words. The previous commit c43c2da works fine.

jxy
jxy1 year ago

OK. I found the culprit. I have to explicitly pass --yarn-ext-factor 0.0 to main. Otherwise it gives gibberish.

This is alright:

[1698963713] Log start
[1698963713] Cmd: ./main -m models/codellama-7b-instruct.Q8_0.gguf -n 1 --top-k 1 -p " 1+1=" --yarn-ext-factor 0.0
[1698963713] main: build = 1477 (629f917)
[1698963713] main: built with clang version 17.0.4 for arm64-apple-darwin23.1.0
[1698963713] main: seed  = 1698963713
[1698963713] main: llama backend init
[1698963713] main: load the model and apply lora adapter, if any
[1698963713] warming up the model with an empty run
[1698963713] n_ctx: 512
[1698963713] 
[1698963713] system_info: n_threads = 4 / 8 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | 
[1698963713] add_bos: 1
[1698963713] tokenize the prompt
[1698963713] prompt: " 1+1="
[1698963713] tokens: [ '':1, ' ':29871, '1':29896, '+':29974, '1':29896, '=':29922 ]
[1698963713] recalculate the cached logits (check): embd_inp.empty() false, n_matching_session_tokens 0, embd_inp.size() 6, session_tokens.size() 0, embd_inp.size() 6
[1698963713] inp_pfx: [ '':1, '':13, '':13, '##':2277, '#':29937, ' Inst':2799, 'ruction':4080, ':':29901, '':13, '':13 ]
[1698963713] inp_sfx: [ '':13, '':13, '##':2277, '#':29937, ' Response':13291, ':':29901, '':13, '':13 ]
[1698963713] sampling: 
        repeat_last_n = 64, repeat_penalty = 1.100, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 1, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
[1698963713] generate: n_ctx = 512, n_batch = 512, n_predict = 1, n_keep = 0
[1698963713] 

[1698963713] embd_inp.size(): 6, n_consumed: 0
[1698963713] eval: [ '':1, ' ':29871, '1':29896, '+':29974, '1':29896, '=':29922 ]
[1698963713] n_past = 6
[1698963713] sampled token: 29906: '2'

This is bad:

[1698963773] Log start
[1698963773] Cmd: ./main -m models/codellama-7b-instruct.Q8_0.gguf -n 1 --top-k 1 -p " 1+1="
[1698963773] main: build = 1477 (629f917)
[1698963773] main: built with clang version 17.0.4 for arm64-apple-darwin23.1.0
[1698963773] main: seed  = 1698963773
[1698963773] main: llama backend init
[1698963773] main: load the model and apply lora adapter, if any
[1698963773] warming up the model with an empty run
[1698963773] n_ctx: 512
[1698963773] 
[1698963773] system_info: n_threads = 4 / 8 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | 
[1698963773] add_bos: 1
[1698963773] tokenize the prompt
[1698963773] prompt: " 1+1="
[1698963773] tokens: [ '':1, ' ':29871, '1':29896, '+':29974, '1':29896, '=':29922 ]
[1698963773] recalculate the cached logits (check): embd_inp.empty() false, n_matching_session_tokens 0, embd_inp.size() 6, session_tokens.size() 0, embd_inp.size() 6
[1698963773] inp_pfx: [ '':1, '':13, '':13, '##':2277, '#':29937, ' Inst':2799, 'ruction':4080, ':':29901, '':13, '':13 ]
[1698963773] inp_sfx: [ '':13, '':13, '##':2277, '#':29937, ' Response':13291, ':':29901, '':13, '':13 ]
[1698963773] sampling: 
	repeat_last_n = 64, repeat_penalty = 1.100, frequency_penalty = 0.000, presence_penalty = 0.000
	top_k = 1, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
[1698963773] generate: n_ctx = 512, n_batch = 512, n_predict = 1, n_keep = 0
[1698963773] 

[1698963773] embd_inp.size(): 6, n_consumed: 0
[1698963773] eval: [ '':1, ' ':29871, '1':29896, '+':29974, '1':29896, '=':29922 ]
[1698963773] n_past = 6
[1698963773] sampled token:     0: 'β–…'
jxy
jxy1 year agoπŸ‘€ 1

The issue is because this code depends on NAN and std::isnan, which unfortunately breaks when compiled with LLAMA_FAST, or -Ofast.
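For anyone else hitting this: -Ofast implies -ffast-math, which allows the compiler to assume NaNs never occur, so a NaN sentinel can be silently optimized away. A minimal illustration (assuming GCC/Clang fast-math semantics):

```cpp
#include <cmath>
#include <cstdio>

// Built with -Ofast (or -ffast-math), the compiler may assume x is never NaN and
// fold std::isnan(x) to false, so the "not set" sentinel is never detected.
int main() {
    float yarn_ext_factor = NAN;                   // sentinel meaning "not set"
    if (std::isnan(yarn_ext_factor)) {
        std::puts("unset: applying the default");  // may never run under -Ofast
    } else {
        std::puts("treated as a real value");      // what was observed above
    }
}
```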

jxy
jxy1 year agoπŸ‘ 2

If ext_factor is never supposed to go negative, something like this would work:

diff --git a/common/common.h b/common/common.h
index 72a49b8..7760fb5 100644
--- a/common/common.h
+++ b/common/common.h
@@ -61,7 +61,7 @@ struct gpt_params {
     int32_t n_beams                         = 0;    // if non-zero then use beam search of given width.
     float   rope_freq_base                  = 0.0f; // RoPE base frequency
     float   rope_freq_scale                 = 0.0f; // RoPE frequency scaling factor
-    float   yarn_ext_factor                 = NAN;  // YaRN extrapolation mix factor
+    float   yarn_ext_factor                 = -1.0f;// YaRN extrapolation mix factor
     float   yarn_attn_factor                = 1.0f; // YaRN magnitude scaling factor
     float   yarn_beta_fast                  = 32.0f;// YaRN low correction dim
     float   yarn_beta_slow                  = 1.0f; // YaRN high correction dim
diff --git a/llama.cpp b/llama.cpp
index bb60044..5748c52 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8125,7 +8125,7 @@ struct llama_context * llama_new_context_with_model(
         cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
     }
 
-    if (std::isnan(cparams.yarn_ext_factor)) { // NaN indicates 'not set'
+    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f;
     }
 
cebtenzzre
cebtenzzre1 year ago (edited 1 year ago)

If ext_factor would never go negative,

I'd be fine with that solution. Would you like to make a PR?

edit: For some reason, I can't reproduce this on Linux with clang or gcc, or on an M2 Mac, at least on CPU.

edit 2: I can't build llama.cpp with Metal on my Mac:

c++ -I. -Icommon -D_XOPEN_SOURCE=600 -D_DARWIN_C_SOURCE -DNDEBUG -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_METAL  -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -pthread  -Ofast -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi  examples/main/main.cpp ggml.o llama.o common.o sampling.o grammar-parser.o build-info.o console.o ggml-metal.o ggml-alloc.o ggml-backend.o ggml-quants.o -o main -framework Accelerate -framework Foundation -framework Metal -framework MetalKit 
0  0x102f3b648  __assert_rtn + 72
1  0x102e63c5c  ld::Fixup::applyFixup(ld::Atom const*, ld::LayoutLinkedImage const&, unsigned char*) const + 8268
2  0x102ef67d8  ___ZN2ld16LayoutExecutable27writeContentWithoutLinkEditENSt3__14spanIhLm18446744073709551615EEEy_block_invoke + 332
3  0x102ef6a14  void mapReduce<ld::Atom const*, mach_o::Error>(std::__1::span<ld::Atom const*, 18446744073709551615ul>, unsigned long, void (unsigned long, mach_o::Error&, std::__1::span<ld::Atom const*, 18446744073709551615ul>) block_pointer, void (std::__1::span<mach_o::Error, 18446744073709551615ul>) block_pointer) + 384
4  0x102ef6594  ld::LayoutExecutable::writeContentWithoutLinkEdit(std::__1::span<unsigned char, 18446744073709551615ul>, unsigned long long) + 1180
5  0x102efc020  ld::LayoutExecutable::writeToFile(char const*) + 15248
6  0x102eae2e8  main + 9424
ld: Assertion failed: (extras.otherInstrOffset != 0 && "Kind::arm64_adrp_ldr missing extra info"), function applyFixup, file Fixup.cpp, line 793.
clang: error: linker command failed with exit code 1 (use -v to see invocation)
make: *** [main] Error 1

Seems like a bug in the Xcode-provided clang 15?

KerfuffleV2
KerfuffleV21 year ago

#2268 (comment) - this seems to fix my problem. Really weird that it only has an effect when offloading that last non-repeating layer.

jxy
jxy1 year ago

@cebtenzzre thanks for pushing the pr.

Now I'm testing this https://huggingface.co/TheBloke/Yarn-Mistral-7B-64k-GGUF and I'm getting

$ ./perplexity -t 1 -ngl 1 -m models/yarn-mistral-7b-64k.Q8_0.gguf -c 512 -f ../wikitext-2-raw/wiki.test.raw 2>/dev/null
[1]24.7243,[2]31.1885,[3]36.5431,[4]41.0809,^C

so something must be wrong, as the base model has

$ ./perplexity -t 1 -ngl 1 -m models/mistral-7b-v0.1.Q8_0.gguf -c 512 -f ../wikitext-2-raw/wiki.test.raw 2>/dev/null   
[1]3.9958,[2]4.4960,[3]5.2987,[4]5.9971,^C

The gguf is recognized correctly

llm_load_print_meta: rope scaling     = yarn
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 0.125
llm_load_print_meta: n_yarn_orig_ctx  = 8192
llm_load_print_meta: rope_finetuned   = yes

and

llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 0.125
llama_new_context_with_model: kv self size  =   64.00 MB
cebtenzzre cebtenzzre removed demo
