llama.cpp
server : avoid breaking KV cache when prompt >= n_ctx (#6855) #8359
Closed

prfd wants to merge 1 commit into ggml-org:master from server
prfd commented 317 days ago 👍 1

This is my unrushed follow-up to my previous attempt in #6855.

As described there, shifting the context by a proper offset currently only happens during decoding, via n_discard. If the prompt exceeds n_ctx before decoding starts, the prompt is instead truncated by roughly half, which breaks the KV cache. This PR addresses that by adding a second tweakable offset, n_truncate, which works like n_discard and lets the client decide which offset to use depending on the situation.
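
To make the difference concrete, here is a toy sketch of the two strategies. This is not the code touched by this PR; the halving formula and the helper names are rough approximations, included only to show why one path invalidates the cached prefix while an explicit offset can preserve it.

```cpp
// Toy sketch, NOT the server code changed by this PR.
// The halving formula below approximates the current behaviour from memory.
#include <cstdio>
#include <vector>

using tokens = std::vector<int>;

// Current behaviour (approximate): keep n_keep tokens, then drop blocks of size
// (n_ctx - n_keep) / 2 until the prompt fits. The retained suffix generally does
// not match what is already in the KV cache, so everything after the n_keep
// prefix has to be re-evaluated.
static tokens truncate_by_halving(const tokens & prompt, int n_ctx, int n_keep) {
    if ((int) prompt.size() < n_ctx) {
        return prompt;
    }
    const int n_left   = n_ctx - n_keep;
    const int n_block  = n_left / 2;
    const int n_erased = ((int) prompt.size() - n_keep - n_block) / n_block * n_block;

    tokens out(prompt.begin(), prompt.begin() + n_keep);
    out.insert(out.end(), prompt.begin() + n_keep + n_erased, prompt.end());
    return out;
}

// Proposed behaviour (sketch): drop exactly n_truncate tokens right after the
// n_keep prefix, mirroring how n_discard shifts the context during decoding.
// If the client chooses n_truncate to match the tokens that were already shifted
// out of the cache, the retained prompt lines up with the cached tokens and only
// the genuinely new tokens need to be decoded. Bounds checks omitted for brevity.
static tokens truncate_by_offset(const tokens & prompt, int n_keep, int n_truncate) {
    tokens out(prompt.begin(), prompt.begin() + n_keep);
    out.insert(out.end(), prompt.begin() + n_keep + n_truncate, prompt.end());
    return out;
}

int main() {
    // numbers loosely based on the test added in this PR:
    // n_ctx = 256, n_keep = 82, n_truncate = 52
    tokens prompt(300);
    for (size_t i = 0; i < prompt.size(); ++i) {
        prompt[i] = (int) i;
    }
    const tokens a = truncate_by_halving(prompt, /*n_ctx =*/ 256, /*n_keep =*/ 82);
    const tokens b = truncate_by_offset (prompt, /*n_keep =*/ 82, /*n_truncate =*/ 52);
    printf("halving keeps %zu tokens, offset keeps %zu tokens\n", a.size(), b.size());
    return 0;
}
```

The numbers in main() loosely mirror the new test scenario (256 KV cache size, 82 tokens to keep, 52 tokens to truncate), where keeping the truncation aligned with the cache is what allows the follow-up request to process only 28 prompt tokens instead of re-evaluating everything past the kept prefix.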

github-actions added the examples, python and server labels
prfd force pushed from c03a78d1 to 559c98dd 316 days ago
ggerganov commented on 2024-07-09 (316 days ago) 👍 1

The context shifting logic has become quite complicated and it's hard to review the changes. If more people can confirm that this works as expected, we can merge. Otherwise, it will have to wait for an eventual refactoring of the logic.

HanClinto commented on 2024-07-10
examples/server/tests/features/truncation.feature
# run with: ./tests.sh --no-skipped --tags truncation

@trucation
HanClinto commented 314 days ago 😄 2

It took me quite a while to figure out why I couldn't get this new test to run.

Suggested change:
-@trucation
+@truncation

I finally turned to Claude, who helped me track down this typo. Once it was corrected, I was able to run this test scenario with the command `./tests.sh --no-skipped --tags truncation` as given.

HanClinto commented 314 days ago 👍 2

I'm also not able to easily follow the logic, but I can at least confirm that when I run the included tests against `master`, they fail as expected:

Failing test log run against `master`
./tests.sh --no-skipped --tags truncation
@truncation @slow
Feature: Chat truncation # features/truncation.feature:4
Starting new scenario: Correctly truncate the prompt when the prompt exceeds the context size!

  Background: Server startup  # features/truncation.feature:6

  Scenario: Correctly truncate the prompt when the prompt exceeds the context size  # features/truncation.feature:17
    Given a server listening on localhost:8080                                      # features/steps/steps.py:25 0.001s
    And a model file mistral-7b-v0.2-iq3_s-imat.gguf from HF repo ggml-org/models   # features/steps/steps.py:82 0.000s
    And prompt caching is enabled                                                   # features/steps/steps.py:154 0.000s
    And a list of stop strings ["\n"]                                               # features/steps/steps.py:431 0.000s
    And 82 tokens to keep                                                           # features/steps/steps.py:358 0.000s
    And 256 KV cache size                                                           # features/steps/steps.py:129 0.000s
    And 32 server max tokens to predict                                             # features/steps/steps.py:139 0.000s
    Then the server is starting                                                     # features/steps/steps.py:174 0.117s
    Then the server is healthy                                                      # features/steps/steps.py:198 4.813s
    Given a prompt                                                                  # features/steps/steps.py:525
    Given a prompt                                                                  # features/steps/steps.py:525 0.000s
      """tinue the chat below.
      Continue the chat below.going?
      Me: Hey there, how's it going?r asking! How are you?
      You: I'm doing well, thanks for asking! How are you?e. How's your day?
      Me: I'm doing good, just trying to get some work done. How's your day?new project.
      You: My day has been pretty productive so far. I've been working on a new project.
      Me: That's great to hear! What's the new project you're working on?heir personal finances. I'm really excited about it.
      You: It's a web application that's designed to help people manage their personal finances. I'm really excited about it.have it ready to launch?
      Me: That sounds really useful, I'd be interested to hear more about it. Do you have a timeframe for when you expect to have it ready to launch?
      You: I'm aiming to have the initial version ready within the next few months. I want to ensure it's robust before launching it.
      Me: That's really nice, are you happy with the progress so far?
      """
      """
    And an ongoing completion request                                               # features/steps/steps.py:250 1.880s
    Then -1 tokens are predicted matching You:                                      # features/steps/steps.py:282 0.000s
    Given an ongoing prompt                                                         # features/steps/steps.py:547
    Given an ongoing prompt                                                         # features/steps/steps.py:547 0.000s
      """
      Me: I have one more question for you my friend. What's the most value thing you learned during your development journey?
      Me: I have one more question for you my friend. What's the most value thing you learned during your development journey?
      """
      """
    And 52 tokens to truncate                                                       # features/steps/steps.py:363 0.000s
    And a completion request with no api error                                      # features/steps/steps.py:244 1.770s
    Then -1 tokens are predicted matching You:                                      # features/steps/steps.py:282 0.000s
    And 28 prompt tokens are processed                                              # features/steps/steps.py:332 0.000s
      Assertion Failed: n_prompt=102
      Captured stdout:
      bench: starting server with: ../../../build/bin/llama-server --host localhost --port 8080 --model mistral-7b-v0.2-iq3_s-imat.gguf --hf-repo ggml-org/models --hf-file mistral-7b-v0.2-iq3_s-imat.gguf --ctx-size 256 --n-predict 32 --log-format text
      server pid=92566, behave pid=92502
      waiting for server to start, connect error code = 61...
      INFO [                    main] build info | tid="0x1e89e4c00" timestamp=1720645902 build=3370 commit="587f058b"
      INFO [                    main] system info | tid="0x1e89e4c00" timestamp=1720645902 n_threads=8 n_threads_batch=-1 total_threads=10 system_info="AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 0 | NEON = 1 | SVE = 0 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 0 | "
      server started!
      INFO [                    init] initializing slots | tid="0x1e89e4c00" timestamp=1720645907 n_slots=1
      INFO [                    init] new slot | tid="0x1e89e4c00" timestamp=1720645907 id_slot=0 n_ctx_slot=256
      INFO [                    main] model loaded | tid="0x1e89e4c00" timestamp=1720645907
      INFO [                    main] chat template | tid="0x1e89e4c00" timestamp=1720645907 chat_example="<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant\n" built_in=true
      INFO [                    main] HTTP server listening | tid="0x1e89e4c00" timestamp=1720645907 port="8080" n_threads_http="9" hostname="localhost"
      INFO [            update_slots] all slots are idle | tid="0x1e89e4c00" timestamp=1720645907
      INFO [     process_single_task] slot data | tid="0x1e89e4c00" timestamp=1720645907 id_task=0 n_idle_slots=1 n_processing_slots=0
      INFO [            update_slots] all slots are idle | tid="0x1e89e4c00" timestamp=1720645907
      INFO [      log_server_request] request | tid="0x16c11f000" timestamp=1720645907 remote_addr="127.0.0.1" remote_port=50382 status=200 method="GET" path="/health" params={}
      INFO [   launch_slot_with_task] slot is processing task | tid="0x1e89e4c00" timestamp=1720645907 id_slot=0 id_task=1
      INFO [            update_slots] kv cache rm [p0, end) | tid="0x1e89e4c00" timestamp=1720645907 id_slot=0 id_task=1 p0=0
      INFO [           print_timings] prompt eval time     =    1048.86 ms /   221 tokens (    4.75 ms per token,   210.70 tokens per second) | tid="0x1e89e4c00" timestamp=1720645909 id_slot=0 id_task=1 t_prompt_processing=1048.86 n_prompt_tokens_processed=221 t_token=4.745972850678733 n_tokens_second=210.70495585683506
      INFO [           print_timings] generation eval time =     828.09 ms /    23 runs   (   36.00 ms per token,    27.77 tokens per second) | tid="0x1e89e4c00" timestamp=1720645909 id_slot=0 id_task=1 t_token_generation=828.088 n_decoded=23 t_token=36.00382608695652 n_tokens_second=27.774825863917847
      INFO [           print_timings]           total time =    1876.95 ms | tid="0x1e89e4c00" timestamp=1720645909 id_slot=0 id_task=1 t_prompt_processing=1048.86 t_token_generation=828.088 t_total=1876.9479999999999
      INFO [            update_slots] slot released | tid="0x1e89e4c00" timestamp=1720645909 id_slot=0 id_task=1 n_ctx=256 n_past=243 n_system_tokens=0 n_cache_tokens=243 truncated=false
      INFO [            update_slots] all slots are idle | tid="0x1e89e4c00" timestamp=1720645909
      INFO [      log_server_request] request | tid="0x16c1ab000" timestamp=1720645909 remote_addr="127.0.0.1" remote_port=50385 status=200 method="POST" path="/completion" params={}
      INFO [   launch_slot_with_task] slot is processing task | tid="0x1e89e4c00" timestamp=1720645909 id_slot=0 id_task=25
      INFO [            update_slots] kv cache rm [p0, end) | tid="0x1e89e4c00" timestamp=1720645909 id_slot=0 id_task=25 p0=82
      INFO [           print_timings] prompt eval time     =     594.46 ms /   102 tokens (    5.83 ms per token,   171.58 tokens per second) | tid="0x1e89e4c00" timestamp=1720645911 id_slot=0 id_task=25 t_prompt_processing=594.463 n_prompt_tokens_processed=102 t_token=5.82806862745098 n_tokens_second=171.58342907800824
      INFO [           print_timings] generation eval time =    1173.25 ms /    32 runs   (   36.66 ms per token,    27.27 tokens per second) | tid="0x1e89e4c00" timestamp=1720645911 id_slot=0 id_task=25 t_token_generation=1173.253 n_decoded=32 t_token=36.66415625 n_tokens_second=27.274594652645252
      INFO [           print_timings]           total time =    1767.72 ms | tid="0x1e89e4c00" timestamp=1720645911 id_slot=0 id_task=25 t_prompt_processing=594.463 t_token_generation=1173.253 t_total=1767.716
      INFO [            update_slots] slot released | tid="0x1e89e4c00" timestamp=1720645911 id_slot=0 id_task=25 n_ctx=256 n_past=215 n_system_tokens=0 n_cache_tokens=215 truncated=true
      INFO [            update_slots] all slots are idle | tid="0x1e89e4c00" timestamp=1720645911
      INFO [      log_server_request] request | tid="0x16c2c3000" timestamp=1720645911 remote_addr="127.0.0.1" remote_port=50389 status=200 method="POST" path="/completion" params={}

      Captured stderr:
      llama_download_file: previous metadata file found mistral-7b-v0.2-iq3_s-imat.gguf.json: {"etag":"\"21686aae6bc0a7a16b85e24c034ae14a-199\"","lastModified":"Mon, 25 Mar 2024 13:07:43 GMT","url":"https://huggingface.co/ggml-org/models/resolve/main/mistral-7b-v0.2-iq3_s-imat.gguf"}
      llama_model_loader: loaded meta data with 23 key-value pairs and 291 tensors from mistral-7b-v0.2-iq3_s-imat.gguf (version GGUF V3 (latest))
      llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
      llama_model_loader: - kv   0:                       general.architecture str              = llama
      llama_model_loader: - kv   1:                               general.name str              = huggingface
      llama_model_loader: - kv   2:                           llama.vocab_size u32              = 32000
      llama_model_loader: - kv   3:                       llama.context_length u32              = 32768
      llama_model_loader: - kv   4:                     llama.embedding_length u32              = 4096
      llama_model_loader: - kv   5:                          llama.block_count u32              = 32
      llama_model_loader: - kv   6:                  llama.feed_forward_length u32              = 14336
      llama_model_loader: - kv   7:                 llama.rope.dimension_count u32              = 128
      llama_model_loader: - kv   8:                 llama.attention.head_count u32              = 32
      llama_model_loader: - kv   9:              llama.attention.head_count_kv u32              = 8
      llama_model_loader: - kv  10:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
      llama_model_loader: - kv  11:                       llama.rope.freq_base f32              = 1000000.000000
      llama_model_loader: - kv  12:                          general.file_type u32              = 26
      llama_model_loader: - kv  13:                       tokenizer.ggml.model str              = llama
      llama_model_loader: - kv  14:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
      llama_model_loader: - kv  15:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
      llama_model_loader: - kv  16:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
      llama_model_loader: - kv  17:                tokenizer.ggml.bos_token_id u32              = 1
      llama_model_loader: - kv  18:                tokenizer.ggml.eos_token_id u32              = 2
      llama_model_loader: - kv  19:            tokenizer.ggml.unknown_token_id u32              = 0
      llama_model_loader: - kv  20:               tokenizer.ggml.add_bos_token bool             = true
      llama_model_loader: - kv  21:               tokenizer.ggml.add_eos_token bool             = false
      llama_model_loader: - kv  22:               general.quantization_version u32              = 2
      llama_model_loader: - type  f32:   65 tensors
      llama_model_loader: - type q4_K:   32 tensors
      llama_model_loader: - type q6_K:    1 tensors
      llama_model_loader: - type iq3_s:  193 tensors
      llm_load_vocab: special tokens cache size = 259
      llm_load_vocab: token to piece cache size = 0.1637 MB
      llm_load_print_meta: format           = GGUF V3 (latest)
      llm_load_print_meta: arch             = llama
      llm_load_print_meta: vocab type       = SPM
      llm_load_print_meta: n_vocab          = 32000
      llm_load_print_meta: n_merges         = 0
      llm_load_print_meta: vocab_only       = 0
      llm_load_print_meta: n_ctx_train      = 32768
      llm_load_print_meta: n_embd           = 4096
      llm_load_print_meta: n_layer          = 32
      llm_load_print_meta: n_head           = 32
      llm_load_print_meta: n_head_kv        = 8
      llm_load_print_meta: n_rot            = 128
      llm_load_print_meta: n_swa            = 0
      llm_load_print_meta: n_embd_head_k    = 128
      llm_load_print_meta: n_embd_head_v    = 128
      llm_load_print_meta: n_gqa            = 4
      llm_load_print_meta: n_embd_k_gqa     = 1024
      llm_load_print_meta: n_embd_v_gqa     = 1024
      llm_load_print_meta: f_norm_eps       = 0.0e+00
      llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
      llm_load_print_meta: f_clamp_kqv      = 0.0e+00
      llm_load_print_meta: f_max_alibi_bias = 0.0e+00
      llm_load_print_meta: f_logit_scale    = 0.0e+00
      llm_load_print_meta: n_ff             = 14336
      llm_load_print_meta: n_expert         = 0
      llm_load_print_meta: n_expert_used    = 0
      llm_load_print_meta: causal attn      = 1
      llm_load_print_meta: pooling type     = 0
      llm_load_print_meta: rope type        = 0
      llm_load_print_meta: rope scaling     = linear
      llm_load_print_meta: freq_base_train  = 1000000.0
      llm_load_print_meta: freq_scale_train = 1
      llm_load_print_meta: n_ctx_orig_yarn  = 32768
      llm_load_print_meta: rope_finetuned   = unknown
      llm_load_print_meta: ssm_d_conv       = 0
      llm_load_print_meta: ssm_d_inner      = 0
      llm_load_print_meta: ssm_d_state      = 0
      llm_load_print_meta: ssm_dt_rank      = 0
      llm_load_print_meta: model type       = 7B
      llm_load_print_meta: model ftype      = IQ3_S - 3.4375 bpw
      llm_load_print_meta: model params     = 7.24 B
      llm_load_print_meta: model size       = 2.96 GiB (3.51 BPW)
      llm_load_print_meta: general.name     = huggingface
      llm_load_print_meta: BOS token        = 1 '<s>'
      llm_load_print_meta: EOS token        = 2 '</s>'
      llm_load_print_meta: UNK token        = 0 '<unk>'
      llm_load_print_meta: LF token         = 13 '<0x0A>'
      llm_load_print_meta: max token length = 48
      llm_load_tensors: ggml ctx size =    0.27 MiB
      ggml_backend_metal_log_allocated_size: allocated buffer, size =  2980.56 MiB, ( 2980.62 / 10922.67)
      llm_load_tensors: offloading 32 repeating layers to GPU
      llm_load_tensors: offloading non-repeating layers to GPU
      llm_load_tensors: offloaded 33/33 layers to GPU
      llm_load_tensors:        CPU buffer size =    53.71 MiB
      llm_load_tensors:      Metal buffer size =  2980.56 MiB
      .................................................................................................
      llama_new_context_with_model: n_ctx      = 256
      llama_new_context_with_model: n_batch    = 256
      llama_new_context_with_model: n_ubatch   = 256
      llama_new_context_with_model: flash_attn = 0
      llama_new_context_with_model: freq_base  = 1000000.0
      llama_new_context_with_model: freq_scale = 1
      ggml_metal_init: allocating
      ggml_metal_init: found device: Apple M1 Pro
      ggml_metal_init: picking default device: Apple M1 Pro
      ggml_metal_init: using embedded metal library
      ggml_metal_init: GPU name:   Apple M1 Pro
      ggml_metal_init: GPU family: MTLGPUFamilyApple7  (1007)
      ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
      ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
      ggml_metal_init: simdgroup reduction support   = true
      ggml_metal_init: simdgroup matrix mul. support = true
      ggml_metal_init: hasUnifiedMemory              = true
      ggml_metal_init: recommendedMaxWorkingSetSize  = 11453.25 MB
      llama_kv_cache_init:      Metal KV buffer size =    32.00 MiB
      llama_new_context_with_model: KV self size  =   32.00 MiB, K (f16):   16.00 MiB, V (f16):   16.00 MiB
      llama_new_context_with_model:        CPU  output buffer size =     0.24 MiB
      llama_new_context_with_model:      Metal compute buffer size =    40.25 MiB
      llama_new_context_with_model:        CPU compute buffer size =     4.25 MiB
      llama_new_context_with_model: graph nodes  = 1030
      llama_new_context_with_model: graph splits = 2

server is listening on localhost:8080...
shutting down server pid=92566 ...


Failing scenarios:
  features/truncation.feature:17  Correctly truncate the prompt when the prompt exceeds the context size

0 features passed, 1 failed, 9 skipped
0 scenarios passed, 1 failed, 51 skipped
16 steps passed, 1 failed, 834 skipped, 0 undefined
Took 0m8.583s

And when run within the PR (with the above typo corrected), it passes as expected:

Passing test log run within PR
% ./tests.sh --no-skipped --tags truncation
@truncation @slow
Feature: Chat truncation # features/truncation.feature:4
Starting new scenario: Correctly truncate the prompt when the prompt exceeds the context size!

  Background: Server startup  # features/truncation.feature:6

  Scenario: Correctly truncate the prompt when the prompt exceeds the context size  # features/truncation.feature:17
    Given a server listening on localhost:8080                                      # features/steps/steps.py:25 0.001s
    And a model file mistral-7b-v0.2-iq3_s-imat.gguf from HF repo ggml-org/models   # features/steps/steps.py:82 0.000s
    And prompt caching is enabled                                                   # features/steps/steps.py:154 0.000s
    And a list of stop strings ["\n"]                                               # features/steps/steps.py:431 0.000s
    And 82 tokens to keep                                                           # features/steps/steps.py:358 0.000s
    And 256 KV cache size                                                           # features/steps/steps.py:129 0.000s
    And 32 server max tokens to predict                                             # features/steps/steps.py:139 0.000s
    Then the server is starting                                                     # features/steps/steps.py:174 0.634s
    Then the server is healthy                                                      # features/steps/steps.py:198 5.603s
    Given a prompt                                                                  # features/steps/steps.py:525
    Given a prompt                                                                  # features/steps/steps.py:525 0.000s
      """tinue the chat below.
      Continue the chat below.going?
      Me: Hey there, how's it going?r asking! How are you?
      You: I'm doing well, thanks for asking! How are you?e. How's your day?
      Me: I'm doing good, just trying to get some work done. How's your day?new project.
      You: My day has been pretty productive so far. I've been working on a new project.
      Me: That's great to hear! What's the new project you're working on?heir personal finances. I'm really excited about it.
      You: It's a web application that's designed to help people manage their personal finances. I'm really excited about it.have it ready to launch?
      Me: That sounds really useful, I'd be interested to hear more about it. Do you have a timeframe for when you expect to have it ready to launch?
      You: I'm aiming to have the initial version ready within the next few months. I want to ensure it's robust before launching it.
      Me: That's really nice, are you happy with the progress so far?
      """
      """
    And an ongoing completion request                                               # features/steps/steps.py:250 1.879s
    Then -1 tokens are predicted matching You:                                      # features/steps/steps.py:282 0.000s
    Given an ongoing prompt                                                         # features/steps/steps.py:547
    Given an ongoing prompt                                                         # features/steps/steps.py:547 0.000s
      """
      Me: I have one more question for you my friend. What's the most value thing you learned during your development journey?
      Me: I have one more question for you my friend. What's the most value thing you learned during your development journey?
      """
      """
    And 52 tokens to truncate                                                       # features/steps/steps.py:363 0.000s
    And a completion request with no api error                                      # features/steps/steps.py:244 1.370s
    Then -1 tokens are predicted matching You:                                      # features/steps/steps.py:282 0.000s
    And 28 prompt tokens are processed                                              # features/steps/steps.py:332 0.000s
shutting down server pid=95769 ...

1 feature passed, 0 failed, 9 skipped
1 scenario passed, 0 failed, 51 skipped
17 steps passed, 0 failed, 834 skipped, 0 undefined
Took 0m9.488s

It's a bit beyond me to confirm any more than that at this point, but I thought this was a useful datapoint worth sharing.

server : avoid breaking KV cache when prompt >= n_ctx (#6855) (commit e6a5a6c6)
prfd force pushed from 559c98dd to e6a5a6c6 314 days ago
mofosyne added the Review Complexity : Medium label
prfd closed this 46 days ago
