Thanks for this PR.
I think a lot has to be changed (to simplify it).
Also I don't see any core logic to actually handle the fp8, are the kernels ready?
Is it possible to test/add tests ?
189 | /// version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead | ||
190 | /// supported for common inference criteria. | ||
191 | #[clap(default_value = "auto", long, env)] | ||
192 | kv_cache_dtype: Option<String>, |
Put a real enum.
enum KvDtype{
Fp8(Path)
}
...
#[clap(long, env, value_enum)]
Option<KvDtype>
This should work much better. None
is equivalent to auto (it just means the user hasn't specified anything we can do whateverwe want with it).
KvDtype
will automatically be sanitized/error checked (String isn't since all strings are available).
I tried putting Fp8(Path)
directly in clap, I'm not sure it actually works in clap internals but this is what we want, if fp8 is chosen we need a path for the scales. and Path
should also ensure the string is a valid path.
Maybe clap doesn't support algebraic enunms and we can't have Fp8(Path)
and need Fp8
instead.
In that case you need to handle validation early (There are other forms of validation in that layer, before pushing the args to the shard).
All CLI validation should happen here, as early as possible, with the best possible error messages.
199 | /// greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead | ||
200 | /// supported for common inference criteria. | ||
201 | #[clap(long, env)] | ||
202 | quantization_param_path: Option<String>, |
Why an external file is needed ? Can't the user specify a value ?
If this is linked per model, shouldn't the model config/repo contain that information ?
If it's needed with kvcache=fp8 let's try to make sure it's actually. Ideally it's one option for users, if not possible we need manual validation here (and vaildation can be skipped later)
(I'm trying to avoid adding too many flags, TGI already has too many, and since we don't break, we never remove stuff that was added, that's why if we can read the information from some consistant config in the repo it keeps the interface for the user simpler)
Updated to load the scales from the checkpoint!
16 | from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator | ||
17 | |||
18 | |||
19 | class KVCacheQuantSchema(BaseModel): | ||
20 | dtype: str | ||
21 | # Each key is a TP rank. Each value is a dictionary mapping a TP rank's | ||
22 | # layer indices to their per-tensor KV cache scaling factor. | ||
23 | # TODO: Consider pulling this and its validation methods out into its | ||
24 | # own schema class (tricky as its members are variable) | ||
25 | scaling_factor: Dict[int, Dict[int, float]] | ||
26 | |||
27 | @model_validator(mode="after") | ||
28 | def check_is_fp8(self) -> "KVCacheQuantSchema": | ||
29 | assert self.dtype == "float8_e4m3fn", ( | ||
30 | "Loaded scaling factors intended for KV cache dtype = " | ||
31 | f"{self.dtype} rather than float8_e4m3fn!" | ||
32 | ) | ||
33 | return self |
This should probably be done higher in the stack (ideally in launcher directly).
Rust is much more efficient at running these kind of checks but most importantly errors should happen as early as possible (and launcher has all the user flags too).
55 | return self | ||
56 | |||
57 | @model_validator(mode="after") | ||
58 | def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": |
Why are you not fusing this with the first validator ?
82 | if context: | ||
83 | model_type = context.get("model_type", None) | ||
84 | if model_type is not None: | ||
85 | assert model_type == self.model_type, ( |
I'm wondering what kind of bug this is supposed to prevent.
If the scaling are contained directly within a single repo (so not user supplied) the validity should be obvious (no need for extra keys).
If it is user supplied, well it is unlikely to contain a model_type, no ?
273 | 275 | else: | |
274 | 276 | raise RuntimeError(f"Unknown dtype {dtype}") | |
277 | |||
278 | if kv_cache_dtype not in {"auto", "fp8"}: |
Optional[str]
without validation is acceptable (since the launcher is responsible for validation already).
563 | 568 | quantize=quantize, | |
564 | 569 | speculator=speculator, | |
565 | 570 | dtype=dtype, | |
571 | kv_cache_dtype=kv_cache_dtype, |
Can we avoid passing 2 objects all the time (they seem highly interlinked, so we could probably fuse them in a single object,no ?)
Done
134 | # quantized_value * scaling_factor ~= true_value | ||
135 | # which is consistent with the practice of setting | ||
136 | # scaling_factor = tensor_amax / FPtype_max | ||
137 | self.kv_scale = 1.0 | ||
138 | self.kv_cache_dtype = "auto" |
Huge no-no.
We really don't like mutating self state of models (it makes reasoning about who modified your values (or not) much harder.
You need to be sending the scales directly at load time (possibly even transparent through config or weights)
self.kv_scale = config.get("kv_scale", 1.0)
We do the same for quantize
and speculator
for instance.
Codewise it touches a lot of places, but these are about to be factored out.
Done, the params should be created at load time
387 | 406 | def __init__(self, prefix, config, weights): | |
388 | 407 | super().__init__() | |
389 | 408 | ||
409 | process_group = weights.process_group | ||
410 | self.tp_rank = process_group.rank() | ||
411 | self.tp_world_size = process_group.size() |
Remove all these, state should already be correct.
Done
477 | layer_self_attn = self.model.layers[layer_idx].self_attn | ||
478 | |||
479 | if SYSTEM == "rocm": | ||
480 | # The scaling factor convention we are assuming is | ||
481 | # quantized_value * scaling_factor ~= true_value | ||
482 | # which is consistent with the practice of setting | ||
483 | # scaling_factor = tensor_amax / FPtype_max |
And why does that imply that rocm needs a factor of 2 ?
If the reason is that much global, shouldn't it be handled directly on load?
@Narsil I believe it is related to https://onnx.ai/onnx/technical/float8.html and the diff between e4m3fn and e4m3 (different exponent bias). Is that so @mht-sharma?
But shouldn't it be based on this param https://github.com/vllm-project/vllm/blob/319ad7f1d386699e94f629341c9988a926821f24/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json#L4 ? Also, we have to kind of decide whether we want to follow the scheme in vllm to store quantization params in a json and keep compatibility for e.g. models on the Hub, or not.
Comment from AMD in vllm:
We do *2 only for HIP, to deal with the difference in numeric from our chip. after *2 overall effect is identical as without it on NV.
I will add this a comment in code
Yes, this is the difference between e4m3fn and e4m3 formats. If using a json like the one liked, there should be a check on the dtype
.
483 | # scaling_factor = tensor_amax / FPtype_max | ||
484 | scaling_factor *= 2 | ||
485 | |||
486 | if hasattr(layer_self_attn, "kv_scale"): | ||
487 | layer_self_attn.kv_scale = scaling_factor | ||
488 | logger.info( | ||
489 | f"Loaded KV cache scaling factor for layer {layer_idx}: {scaling_factor}" | ||
490 | ) | ||
491 | else: | ||
492 | raise RuntimeError( | ||
493 | "Self attention has no KV cache scaling " "factor attribute!" | ||
494 | ) |
Not needed if done at load time.
692 | 692 | head_size: int, | |
693 | 693 | dtype: torch.dtype, | |
694 | 694 | device: torch.device, | |
695 | kv_cache_dtype: str = "auto", | ||
696 | quantization_param_path: Optional[str] = None, |
Fused?
724 | logger.info(f"Using KV cache data type: {kv_cache_dtype}") | ||
725 | # Currently scaled KV cache is only enabled on ROCm | ||
726 | if quantization_param_path is not None: | ||
727 | if callable(getattr(self.model, "load_kv_cache_scales", None)): | ||
728 | self.model.load_kv_cache_scales(quantization_param_path) |
If done correctly, you don't need any of that because you handle this at load time (and every model will need to be updated to support the new values)
739 | "provided. Defaulting to scaling factors of 1.0. " | ||
740 | "This may lead to less accurate results!" | ||
741 | ) | ||
742 | elif quantization_param_path is not None: |
This should be handled all the way back in the launcher (again ideally directly by clap, if not possible manually in the rust part).
And it should be hard error, not a soft one (paramters sent by the user don't make sense, we never silently ignore.)
29 | 29 | quantize: Optional[str] = None, | |
30 | 30 | speculator: Optional[str] = None, | |
31 | 31 | dtype: Optional[torch.dtype] = None, | |
32 | kv_cache_dtype: Optional[str] = "auto", | ||
33 | quantization_param_path: Optional[str] = None, |
Fused?
31 | layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] | ||
32 | return layer_scales_map.items() | ||
33 | |||
34 | except FileNotFoundError: | ||
35 | logger.error(f"File or directory '{filename}' not found.") | ||
36 | except json.JSONDecodeError: | ||
37 | logger.error(f"Error decoding JSON in file '{filename}'.") | ||
38 | except Exception as e: | ||
39 | logger.error(f"An error occurred while reading '{filename}': {e}") | ||
40 | # This section is reached if and only if any of the excepts are hit | ||
41 | # Return an empty iterable (list) => no KV cache scales are loaded | ||
42 | # which ultimately defaults to 1.0 scales | ||
43 | logger.warning( | ||
44 | "Defaulting to KV cache scaling factors = 1.0 " | ||
45 | f"for all layers in TP rank {tp_rank} " | ||
46 | "as an error occurred during loading." | ||
47 | ) |
Everything here should be hard error (I think the standard ones would do fine).
If a user sends invalid information, we shouldn't silently ignore. They should fix it.
168 | kv_cache[0], | ||
169 | kv_cache[1], | ||
170 | slots, | ||
171 | self.kv_cache_dtype, |
This one isn't necessary here, it should already be inferrable from kv_cache[0]
.
Happy to help with the rebase btw.
nothing shocking to me for benchmarking!
Thanks for the review @Narsil @fxmarty I will rebase and address the comments.
Regarding the format for loading the FP8 scales:
VLLM offers two methods:
quantization-param-path: This uses a JSON file (kv_cache_scales.json) containing per-tensor scaling factors for each layer. Example can be found here. This file is generated using the Nvidia AMMO quantizer available here.
Direct loading from checkpoints: This method has been introduced in one of the recent PRs and is located here.
VLLM intends to deprecate the quantization-param-path
method soon, favoring the use of checkpoints for loading scales. Therefore, I would update our approach to also load scales using checkpoints.
Thanks for the review @Narsil @fxmarty I will rebase and address the comments.
Regarding the format for loading the FP8 scales:
VLLM offers two methods:
- quantization-param-path: This uses a JSON file (kv_cache_scales.json) containing per-tensor scaling factors for each layer. Example can be found here. This file is generated using the Nvidia AMMO quantizer available here.
- Direct loading from checkpoints: This method has been introduced in one of the recent PRs and is located here.
VLLM intends to deprecate the
quantization-param-path
method soon, favoring the use of checkpoints for loading scales. Therefore, I would update our approach to also load scales using checkpoints.
Removed the quantization-param-path
altogether: This method is already deprecated in VLLM, based on discussions here: vllm-project/vllm#4532
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Also I don't see any core logic to actually handle the fp8, are the kernels ready? Is it possible to test/add tests ?
The core logic for handling FP8 is managed by the Paged Attention kernel in VLLM, with the necessary kernel tests. If you have any specific tests in mind, we can discuss them. VLLM includes tests that compare the output with the expected FP8 output, as seen https://github.com/comaniac/vllm/blob/main/tests/models/test_fp8.py. We can add a similar test if required.
Great work, that's cool! I left some general comments, @Narsil may have more about the design
1 | # Accelerating Inference with FP8 KV Cache | ||
2 | |||
3 | Text Generation Inference (TGI) now supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks |
Text Generation Inference (TGI) now supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks | |
Text Generation Inference (TGI) supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks. |
It would be worth to explain what is FP8 KV Cache
. Readers may not be familiar with it (does it mean attention computation is in fp8? etc)
enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation.
This is kind of vague.
we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks
It would be worth IMO to show numbers / a chart here to get a grasp of what greatly means, in which case there is indeed a speedup, etc
We have a couple of options for numbers and charts
Max Batch Total Tokens: The logs display the max batch total tokens, which increase when using the FP8 KV cache. We could create a chart showing the max batch total tokens in both cases (with and without the FP8 KV cache).
Throughput: Currently, I have created a custom script using AsyncClient
to send 500 requests simultaneously with asyncio.gather
. This provides a rough estimate of throughput. @Narsil , do you have any suggestions on calculating throughput more precisely?
22 | ## Current Hardware Support | ||
23 | |||
24 | * Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2. | ||
25 | * AMD GPUs: Supports FP8E4M3. |
Technically, Instinct supports E4M3FNUZ, not E4M3. https://onnx.ai/onnx/technical/float8.html
27 | ## FP8 E5M2 KV Cache | ||
28 | Example usage: | ||
29 | ``` | ||
30 | text-generation-launcher --model-id <> --kv-cache-dtype fp8_e5m2 |
Let's maybe put a full runnable command with docker run
etc? Good inspiration could be https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one?install=NVIDIA#flashattention-2
35 | |||
36 | Example usage: | ||
37 | ``` | ||
38 | text-generation-launcher --model-id <> --kv-cache-dtype fp8 |
same + maybe have a model_id which indeed has scaling factors here.
39 | ``` | ||
40 | |||
41 | ### Checkpoint structure for KV scales | ||
42 | The FP8 kv cache scaling factors required in the FP16 checkpoints are specified through the .kv_scale parameter present on the `Attention` module, such as: | ||
43 | |||
44 | ``` | ||
45 | model.layers.0.self_attn.kv_scale < F32 | ||
46 | model.layers.1.self_attn.kv_scale < F32 | ||
47 | ... | ||
48 | ``` |
How do we know whether they are for E4M3FNUZ or E4M3FN format?
I guess the current tools Nvidia AMMO and AutoFP8 uses E4M3FN
. Currently there is no flag to determine the format (unless ofcourse the weight is quantized). But I could add a check and add a parameter for this in the checkpoint.
Updated reading this from config. See: a7909e6
51 | |||
52 | Use [AutoFP8](https://github.com/neuralmagic/AutoFP8) with calibration data to generate per-tensor scales for FP8 quantized KV Cache. For more details, see the following example: https://github.com/neuralmagic/AutoFP8/blob/main/example_dataset.py | ||
53 | |||
54 | TGI provides a utility to extract the FP8 KV cache scales from an `AutoFP8` quantized model and save them to the FP16 model for use with TGI. For more information: <path to script> |
Todo
87 | 87 | [env: DTYPE=] | |
88 | 88 | [possible values: float16, bfloat16] | |
89 | 89 | ||
90 | ``` | ||
91 | ## KV_CACHE_DTYPE | ||
92 | ```shell | ||
93 | --kv-cache-dtype <KV_CACHE_DTYPE> | ||
94 | [env: KV_CACHE_DTYPE=] | ||
95 | [possible values: fp8, fp8_e5m2] | ||
96 |
Needs a description
80 | 83 | num_kv_heads=model.model.num_key_value_heads, | |
81 | 84 | head_size=model.model.head_size, | |
82 | dtype=dtype, | ||
85 | dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype, |
Things are a bit harder to read if dtype
attribute is used to mean the KV cache storage pytorch dtype. For some other models (gemma, idefics), self.dtype
is used with an other meaning, being the weights dtype.
791 | "Only support per-tensor scaling factor for `fp8 (fp8_e4m3)` KV cache" | ||
792 | ) | ||
793 | |||
794 | # ROCm uses FP8 format with fp8_e4m3fn, whereas Nvidia GPUs use fp8_e4m3. | ||
795 | # The multiplication by 2 compensates for the different numeric representation | ||
796 | # between ROCm and Nvidia GPUs, ensuring consistent effective scaling across platforms. | ||
797 | # After this adjustment, the overall effect is equivalent to the scaling applied without | ||
798 | # it on Nvidia GPUs. | ||
799 | if SYSTEM == "rocm": | ||
800 | kv_scale *= 2.0 |
How do we know whether serialized scales are in E4M3FN or E4M3FNUZ format? I think depending on that, the logic should be different here.
Updated the code to read the corresponding dtype as kv_cache_torch_dtype
from config.
Added the format in the README.md and utility script to add kv_cache_torch_dtype
when quantising model
See: a7909e6
32 | Path to save the FP16 model with the kv scales | ||
33 | ``` | ||
34 | |||
35 | ## Example usage | ||
36 | To extract KV cache scaling factors from a quantized FP8 model and save them to an unquantized FP16 model, use the following command: | ||
37 | |||
38 | ``` | ||
39 | python extract_fp8_kv_scales.py --quantized-model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV --model meta-llama/Meta-Llama-3-8B-Instruct --save-path Meta-Llama-3-8B-Instruct | ||
40 | ``` |
If we target vllm/tgi intercompatibility, why couldn't we load directly e.g. neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV
in TGI? Is it not loadable in vllm?
Also, by quantized FP8 model
, do you mean a model whose weights are quantized to fp8? How does it relate to FP8 KV cache? To me to obtain the KV cache scales you would simply need to have calibration data passing through the network & collecting stats on the KV cache.
It feels like the the KV cache scales from a quantized model (like neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV
), whose KV cache scales may have been obtained after the weights quantization (?), may have an unnecessary bias due to being inferred from calibration on the quantized model, not the unquantized one.
The model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV
has both the weights quantized along with KV scales. However, I don't think we can load the FP8 weights in TGI yet. And also one would run an FP16 model with FP8 KV cache so we need such a checkpoint
You are right; it may have an additional bias due to calibration after weight quantization. The bias might have been low, so I couldn't find a noticeable difference during inference.
I have tested two other models generated using scales from Nvidia AMMO (VLLM uses this also), which may not have this bias. The quantizer can accept the quantized format as full precision and provide the model with KV scales. We can extract scales from that.
Here is the link to the models: Llama-2-70b-chat-hf-FP8-KV
I will add these models and provide an example for this.
Updated the example with Nvidia AMMO
1 | # FP8 (fp8_e4m3) KV Cache Scaling Factor Extraction Utility | ||
2 | |||
3 | This utility is designed to extract KV cache scaling factors from a quantized `FP8(fp8_e4m3)` Hugging Face (HF) model. The extracted scaling factors are then saved to the corresponding unquantized HF model, which can be used with Text Generation Inference (TGI). |
Do you have an example of such model?
Updated and added a model to readme.
1 | # Accelerating Inference with FP8 KV Cache |
I think it would be worth to have some evaluation metrics for an example model
248 | 248 | ||
249 | FP8_KVCACHE_SUPPORTED_MODELS = { | ||
250 | "llama", | ||
251 | "baichun", |
is this a typo of "baichuan"?
Yes. Thanks for pointing it out
799 | f"used for generating kv scales. Expected 'float8_e4m3fn' or 'float8_e4m3fnuz', but got '{kv_cache_torch_dtype}'." | ||
800 | ) | ||
801 | |||
802 | # ROCm uses FP8 format with fp8_e4m3fn, whereas Nvidia GPUs use fp8_e4m3. |
Is it should be "ROCm uses FP8 format with fp8_e4m3fnuz, whereas NVIDIA GPU uses fp8_e4m3fn"?
Yes that's correct
Closing this as we added support for FP8 kv cache support in #2603.
More support is coming (for pre-scaled kv-cache fp8)
Login to write a write a comment.
What does this PR do?
This PR introduces support for FP8 KV Cache in Text Generation Inference (TGI), significantly enhancing performance and memory efficiency on both Nvidia and AMD GPUs. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint, leading to faster and more scalable text generation.
Hardware Compatibility:
Example Usage:
KV cache scaling factors should be included in the FP16 checkpoint for E4M3 format to maintain accuracy. Default scaling factor is set to 1.0 if not provided, which may lead to accuracy loss.
Checkpoint Structure for KV Scales:
The FP8 KV cache scaling factors are specified through the
.kv_scale
parameter in the attention moduleThis follows a structure proposed in vllm - https://docs.vllm.ai/en/stable/quantization/fp8.html#fp8-checkpoint-structure-explanation
When providing
.kv_scale
in model, the config should specify properkv_cache_torch_dtype
used to generate scales (float8_e4m3fn
orfloat8_e4m3fnuz
).Currently, users need to extract the KV scales from FP8 checkpoint and add to the FP16 model. A helper script is provided in the PR for the same.
Sample Models with KV scales: Models with FP8 KV Cache
Todos: