llama.cpp: Support for Phi-2 (#4490)

Merged: ggerganov merged 17 commits into ggml-org:master from ebeyabraham:master

ebeyabraham (1 year ago, edited)

This PR adds support for running Phi-2.

The implementation replaces the cross-attentions in the HF implementation with simple attention layers, similar to the MLX implementation: MLX/Phi-2

Closes: #4437, #3146

Example Usage

# convert hf model to GGUF
python convert-hf-to-gguf.py phi-2

# fp-16 inference
./main -m phi-2/ggml-model-f16.gguf -p "Question: Write a python function to print the first n numbers in the fibonacci series"
eabraham-1: phi2 implementation (12cc80cb)
eabraham-1: fix breaking change (e2076553)
Nick-infinity (1 year ago)

Got the below error when converting the model:

Loading model: phi-2
gguf: This GGUF file is for Little Endian only
Set model parameters
Set model tokenizer
Traceback (most recent call last):
  File "convert-hf-to-gguf.py", line 1058, in <module>
    model_instance.set_vocab()
  File "convert-hf-to-gguf.py", line 52, in set_vocab
    self._set_vocab_gpt2()
  File "convert-hf-to-gguf.py", line 252, in _set_vocab_gpt2
    if tokenizer.added_tokens_decoder[i].special:
AttributeError: 'CodeGenTokenizerFast' object has no attribute 'added_tokens_decoder'

ggerganov (1 year ago)

@Nick-infinity I get the same error. Using the following patch fixes it:

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index b56be844..aed84f1b 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -249,7 +249,7 @@ class Model:
                 toktypes.append(gguf.TokenType.USER_DEFINED)
             elif reverse_vocab[i] in added_vocab:
                 tokens.append(reverse_vocab[i])
-                if tokenizer.added_tokens_decoder[i].special:
+                if hasattr(tokenizer, "added_tokens_decoder") and tokenizer.added_tokens_decoder[i].special:
                     toktypes.append(gguf.TokenType.CONTROL)
                 else:
                     toktypes.append(gguf.TokenType.USER_DEFINED)

But I am not sure if this is the correct way to proceed

FiveTechSoft (1 year ago, edited)

Could someone upload the resulting GGUF to HF or somewhere else?

ggerganov: phi-2 : various fixes (a2a3d2c8)
ggerganov (1 year ago)

@mrgraycode Can you allow push access to this branch? I want to push some fixes.

ggerganov: phi-2 : use layer norm eps (aa5c881a)
ggerganov: py : whitespaces (7500fa2f)
ggerganov: llama : fix meta KV override bug (5469d82d)
ggerganov: convert : phi don't add BOS token (a878be4c)
salykova (1 year ago, edited)

Maybe we should rename phi2 to phi, since all Phi models (1, 1.5, 2) have the same architecture? What do you think? @mrgraycode @ggerganov

FiveTechSoft (1 year ago)

ggerganov (1 year ago)

@FiveTechSoft are you using a Mac? CPU and CUDA are still broken I think

salykova (1 year ago, edited)

> @FiveTechSoft are you using a Mac? CPU and CUDA are still broken I think

@ggerganov CPU works fine with F32; CUDA is broken (segfault).

FiveTechSoft (1 year ago)

> @FiveTechSoft are you using a Mac? CPU and CUDA are still broken I think

A Mac with an Intel Xeon. It runs very fast. Just wondering why the phi-2 GGUF is much larger than the Orca GGUF.

salykova (1 year ago, edited)

> Could someone upload the resulting GGUF to HF or somewhere else?

https://huggingface.co/kroonen/phi-2-GGUF
@FiveTechSoft

FiveTechSoft (1 year ago)

@salykovaa many thanks!

Working great and the phi-2_Q4_K_M.gguf size is nice!

ebeyabraham (1 year ago)

> Got below error when converting the model ... AttributeError: 'CodeGenTokenizerFast' object has no attribute 'added_tokens_decoder'

@Nick-infinity @ggerganov I am able to reproduce this error with transformers==4.33.3, but it works fine with >=4.34.0.
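For reference, a quick sanity check is a short Python sketch like the one below; it assumes the HF checkout lives in ./phi-2 and that trust_remote_code is acceptable, and it only verifies that the tokenizer exposes the added_tokens_decoder attribute the conversion script relies on:

# Quick sanity check, not part of the conversion script.
import transformers
from transformers import AutoTokenizer

print(transformers.__version__)  # conversion reportedly needs >= 4.34.0

# "phi-2" is assumed to be the local HF model directory used above
tok = AutoTokenizer.from_pretrained("phi-2", trust_remote_code=True)
print(hasattr(tok, "added_tokens_decoder"))  # False on transformers 4.33.x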

ggerganov commented on 2023-12-16 (resolved)
ggml-cuda.cu
     const int ib = col / n_dims;
     const int ic = col % n_dims;

-    const int i = row*ncols + ib*n_dims + ic/2;
+    // IMPORTANT: consider the case ncols == 80 and n_dims == 32 (phi-2)
+    // I don't know what we are supposed to compute, because the row is not divisible by n_dims
+    // this check matches the CPU code, but it is likely wrong as well
+    // I can't understand the Python code, so if you know what to do here, please fix it
+    // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
+    if (ncols % n_dims != 0 && ib == ncols/n_dims) {
+        return;
+    }
ggerganov (1 year ago)

Will need some assistance in fixing this

slaren (1 year ago)

The Python code looks like it only applies the rope to the first n_dims columns and leaves the rest unchanged. Is that correct?

ggerganov (1 year ago)

Ok, I think I understand how it works now. Will fix it
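For readers following along, here is a minimal numpy sketch (not the llama.cpp code) of the partial NeoX-style rotation being described: only the first n_dims components of each row are rotated and the remaining ncols - n_dims pass through unchanged. The pairing and frequency conventions below are assumptions based on the usual NeoX RoPE formulation.

import numpy as np

def neox_partial_rope(x, pos, n_dims, base=10000.0):
    """Rotate only the first n_dims components of x (shape (..., ncols));
    the remaining ncols - n_dims components are returned unchanged."""
    assert n_dims % 2 == 0 and n_dims <= x.shape[-1]
    x_rot, x_pass = x[..., :n_dims], x[..., n_dims:]

    half = n_dims // 2
    # theta_i = pos * base^(-2i/n_dims), i = 0 .. n_dims/2 - 1
    inv_freq = base ** (-2.0 * np.arange(half) / n_dims)
    cos, sin = np.cos(pos * inv_freq), np.sin(pos * inv_freq)

    # NeoX pairing: component i rotates with component i + n_dims/2
    x1, x2 = x_rot[..., :half], x_rot[..., half:]
    rotated = np.concatenate([x1 * cos - x2 * sin,
                              x1 * sin + x2 * cos], axis=-1)
    return np.concatenate([rotated, x_pass], axis=-1)

# e.g. phi-2: head dim 80, rotary dims 32
q = np.random.randn(80)
q_roped = neox_partial_rope(q, pos=3, n_dims=32)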

ggerganov: convert : revert "added_tokens_decoder" change (0b6ffa58)
ggerganov: phi-2 : scale Q instead of KQ for better precision (0644c3be)
ggerganov (1 year ago)

The CPU inference with ARM_NEON was producing garbage. It turned out that the queries Q can become quite large in absolute value and produce inf during the F16 dot products. I changed the 1/sqrt(n_embd_head) scaling to be applied to Q instead of to KQ, and this resolved the issue.

However, the CUDA code still produces garbage; I am not sure yet why.

Also:

  • fixed the norm eps (1e-5) not being read from the meta data
  • disabled "add BOS" flag
  • fixed a bug where "add BOS" flag (and all bool flags in the GGUF meta data) was ignored

I also suspect that our NeoX RoPE computation for ncols == 80 and n_dims == 32 is wrong

Any help with this one will be appreciated; I will look more into it tomorrow.
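For illustration, a small numpy sketch (not the ggml code) of the scaling change: both orderings are mathematically equivalent, but applying 1/sqrt(n_embd_head) to Q before the dot product keeps the intermediate values small enough to stay within F16 range.

import numpy as np

def scores_scale_after(Q, K, scale):
    # scale applied to KQ: the unscaled dot products can overflow in F16
    return (Q @ K.T) * scale

def scores_scale_q(Q, K, scale):
    # scale applied to Q first: same result, smaller intermediates
    return (Q * scale) @ K.T

d = 80  # phi-2 head dim
Q = np.random.randn(4, d).astype(np.float32)
K = np.random.randn(4, d).astype(np.float32)
scale = 1.0 / np.sqrt(d)
assert np.allclose(scores_scale_after(Q, K, scale), scores_scale_q(Q, K, scale))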

ggerganov commented on 2023-12-16
llama.cpp
         target = override->bool_value;
         return true;
     }
-    return true;
+    return false;
ggerganov (1 year ago)

This was the bug related to reading bool flags from the GGUF meta data

x4080 (1 year ago)

Phi-2 works great on my M2 Pro. I suggest using a better prompt:

Question: How to calculate 12+7*28, explain
Answer:
devic1 approved these changes on 2023-12-17
ggerganov: ggml : fix NeoX rope to rotate just first n_dims (f703ca8a)
ggerganov (1 year ago, edited)

This should now work correctly on CPU and Metal:

make -j main && ./main -m models/phi-2/ggml-model-q4_0.gguf -e -p "Question: Prove that sqrt(2) is irrational.\nAnswer:" -ngl 99 -n 512 -s 1 --temp 0 --repeat-penalty 1 --no-penalize-nl
system_info: n_threads = 16 / 24 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | 
sampling: 
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temp 
generate: n_ctx = 512, n_batch = 512, n_predict = 512, n_keep = 0


Question: Prove that sqrt(2) is irrational.
Answer: To prove that sqrt(2) is irrational, we will use a proof by contradiction. Assume that sqrt(2) is rational, which means it can be expressed as a fraction p/q, where p and q are integers with no common factors. Then, we can square both sides of the equation to get 2 = p^2/q^2. This implies that p^2 is even, since it is divisible by 2. Therefore, p must also be even, since the square of an odd number is odd. Let p = 2k, where k is an integer. Then, we can substitute this into the equation to get 2 = (2k)^2/q^2, or 2 = 4k^2/q^2. This simplifies to q^2 = 2k^2, or q^2 is even. This means that q must also be even, since the square of an odd number is odd. However, this contradicts our assumption that p and q have no common factors, since they are both even. Therefore, our assumption that sqrt(2) is rational must be false, and sqrt(2) is irrational.
 [end of text]

llama_print_timings:        load time =     137.53 ms
llama_print_timings:      sample time =       8.09 ms /   246 runs   (    0.03 ms per token, 30400.40 tokens per second)
llama_print_timings: prompt eval time =      39.82 ms /    16 tokens (    2.49 ms per token,   401.77 tokens per second)
llama_print_timings:        eval time =    1912.37 ms /   245 runs   (    7.81 ms per token,   128.11 tokens per second)
llama_print_timings:       total time =    1992.12 ms

However, CUDA remains somewhat of a problem. I've been tracing the source of the NaNs, and the conclusion is that we lack enough precision in the matrix multiplications in the attention.

I have found two solutions so far, both of which are not ideal:

The latter solution has a negative performance effect on some video cards.

One option to proceed is to make #3816 a runtime option, so people would be able to switch between the two modes via a llama.h parameter.

ggerganov: cuda : less diff in the rope_neox kernel (42e95258)
FiveTechSoft (1 year ago)

Dear Georgi,

Great work, thank you so much for your great effort!

slaren (1 year ago)

I tried this rope implementation with stablelm rocket 3b, and it reduces perplexity from ~28 to ~25 at chunk 100.

slaren (1 year ago, edited)

> However, CUDA remains somewhat of a problem. I've been tracing the source of the NaNs, and the conclusion is that we lack enough precision in the matrix multiplications in the attention.

Since this only applies to the attention mat muls, we could also add an op_param to ggml_mul_mat that disallows using F16 compute. The transparent downcasting to F16 that we do currently seems a bit sketchy, but we don't really have a good way to address that until ggml supports F16 outputs.

ggerganov (1 year ago)

> we could also add an op_param to ggml_mul_mat that disallows using F16 compute

Yes, this sounds like a better idea. Should I add a ggml_mul_mat_f32 overload?

slaren (1 year ago)

> Should I add a ggml_mul_mat_f32 overload?

A function like ggml_mul_mat_set_f32 to add the flag afterwards to an existing tensor could also work, but I am not sure what would be better.

QwertyJack (1 year ago)

I just tried this exciting feature, and the main binary works perfectly! However, the server generates repeated tokens forever. Weird.

ggerganov (1 year ago)

@QwertyJack What hardware / commands are you using? There is a known issue with CUDA atm.

I'll look into finalizing this PR today and merging

QwertyJack (1 year ago)

Pure CPU works fine. The issue happens when offloading all layers to a T4.

ggerganov (1 year ago)

Ok, that's expected atm. Will be fixed before merging the PR

ggerganov: Merge branch 'master' into HEAD (a8d2a6f3)
ggerganov: ggml : add ggml_mul_mat_set_prec (18c67bdd)
ggerganov force-pushed from 494f4b29 to 18c67bdd (1 year ago)
ggerganov requested a review from slaren (1 year ago)
slaren commented on 2023-12-18 (resolved)
ggml-cuda.cu
    const float alpha_f32 = 1.0f;
    const float beta_f32 = 0.0f;

    const char * alpha = (const char *) &alpha_f16;
    const char * beta = (const char *) &beta_f16;

slaren (1 year ago) suggested change:

-    const char * alpha = (const char *) &alpha_f16;
-    const char * beta = (const char *) &beta_f16;
+    const void * alpha = &alpha_f16;
+    const void * beta = &beta_f16;
ggml-cuda.cu
    cu_compute_type = CUBLAS_COMPUTE_32F;
    cu_data_type = CUDA_R_32F;

    alpha = (const char *) &alpha_f32;
    beta = (const char *) &beta_f32;

slaren (1 year ago) suggested change:

-    alpha = (const char *) &alpha_f32;
-    beta = (const char *) &beta_f32;
+    alpha = &alpha_f32;
+    beta = &beta_f32;
slaren (1 year ago)

ggml_cuda_op_mul_mat_cublas should only require this change:

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 4c986db9..553cf9d9 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -7387,7 +7387,7 @@ inline void ggml_cuda_op_mul_mat_cublas(

     const int compute_capability = g_compute_capabilities[id];

-    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
ggerganov: Update ggml-cuda.cu (3c8d6b16)
ggerganov: Update ggml-cuda.cu (30338c56)
ggerganov: cuda : ggml_cuda_op_mul_mat_cublas support F32 precision (7ea427db)
ggerganov: cuda : remove oboslete comment (c02412c3)
ggerganov approved these changes on 2023-12-18

ggerganov (1 year ago)

I think it should be good to merge.

ggerganov merged b9e74f9b into master (1 year ago)
Slider2k (1 year ago)

Has anyone experienced the model basically shutting down after long conversations, i.e. starting to produce silence or gibberish? I experience this issue with both the base model and a recently quantized Dolphin fine-tune.

x4080 (1 year ago)

@Slider2k I just experienced it too; I don't know what happened (using the server, with Dolphin Phi-2 and a long context as well).

teleprint-me (1 year ago, edited)

What was the sequence length? PhiForCausalLM models were trained with a 2048-token context window. It could be increased with YaRN. Also, the base models aren't fine-tuned for instruct or chat. It's impressive they respond in that format at all.

Slider2k (1 year ago)

Yep, it dies upon reaching the context size limit of 2048. The Dolphin fine-tune that I recently tested is instruction-tuned with ChatML.
I created issue #4625 with a log of the llama.cpp session.

teleprint-me (1 year ago, edited)

The solution is pretty simple. Just treat the sequence length (the context window) as a "sliding window". You end up with effectively infinite context that way.

Technically, the model won't know anything outside of that window aside from its training, but there are ways around that too. Like I said before, YaRN is another option. Fine-tuning the model is another option.

This isn't specific to Phi models. It's well known. That's why it seems like the model is producing gibberish. It doesn't know any better because you've gone outside of the scope of its training. There's no way around this as the problem isn't with the model or llama.cpp.
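For illustration, a minimal Python sketch of the sliding-window idea (this is not how llama.cpp's context shifting is implemented; n_keep and n_predict are just illustrative parameters): before each generation step, drop the oldest tokens, optionally preserving a fixed prefix, so the sequence never exceeds the context size.

def trim_context(tokens, n_ctx, n_keep=0, n_predict=64):
    """Keep a fixed prefix of n_keep tokens plus the most recent tokens,
    so that len(result) + n_predict never exceeds n_ctx."""
    budget = n_ctx - n_predict
    if len(tokens) <= budget:
        return list(tokens)
    return list(tokens[:n_keep]) + list(tokens[-(budget - n_keep):])

# e.g. a 2048-token window for phi-2, reserving room for 64 new tokens
history = list(range(5000))  # stand-in for token ids
window = trim_context(history, n_ctx=2048, n_keep=16, n_predict=64)
assert len(window) + 64 <= 2048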

Slider2k (1 year ago)

@teleprint-me I'm sorry, it seems that you don't understand the actual problem here. The Phi-2 model literally breaks, not figuratively. It stops responding after the context limit, or produces random letters. That's far from normal behavior, and unlike other models when they reach their context limit.
