Add FP8 KVCache support #2028

mht-sharma wants to merge 22 commits into main from fp8_kvcache
mht-sharma commented 346 days ago (edited 320 days ago)

What does this PR do?

This PR introduces support for an FP8 KV cache in Text Generation Inference (TGI), improving performance and memory efficiency on both Nvidia and AMD GPUs. Quantizing the KV cache to 8-bit floating-point (FP8) formats roughly halves its memory footprint relative to FP16, allowing larger batches and faster, more scalable text generation.

Hardware Compatibility:

  • Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2 (TODO: Need VLLM update).
  • AMD GPUs: Supports FP8E4M3.

Example Usage:

text-generation-launcher --model-id <model_id> --kv-cache-dtype fp8/fp8_e5m2

For the E4M3 format, KV cache scaling factors should be included in the FP16 checkpoint to maintain accuracy. If no scaling factor is provided, it defaults to 1.0, which may lead to accuracy loss.

Checkpoint Structure for KV Scales:

The FP8 KV cache scaling factors are specified through the `.kv_scale` parameter in the attention module, for example:

model.layers.0.self_attn.kv_scale                < F32
model.layers.1.self_attn.kv_scale                < F32

This follows the structure proposed in vLLM: https://docs.vllm.ai/en/stable/quantization/fp8.html#fp8-checkpoint-structure-explanation

When providing `.kv_scale` in the model, the config should specify the kv_cache_torch_dtype used to generate the scales (float8_e4m3fn or float8_e4m3fnuz).

Currently, users need to extract the KV scales from an FP8 checkpoint and add them to the FP16 model. A helper script for this is provided in the PR.
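For illustration, here is a minimal sketch of what such an extraction step could look like. This is not the PR's actual helper script; the key layout follows the `.kv_scale` structure above, real checkpoints may be sharded across several safetensors files, and all paths are placeholders.

```python
# Hypothetical sketch: copy per-layer KV cache scales from an FP8-quantized
# checkpoint into an unquantized FP16 checkpoint so they appear as `.kv_scale`.
from safetensors.torch import load_file, save_file

def copy_kv_scales(fp8_ckpt: str, fp16_ckpt: str, out_path: str) -> None:
    fp8_weights = load_file(fp8_ckpt)    # quantized model, e.g. model.safetensors
    fp16_weights = load_file(fp16_ckpt)  # unquantized FP16 model

    # Keep only tensors named like "model.layers.N.self_attn.kv_scale" (F32 scalars).
    kv_scales = {k: v for k, v in fp8_weights.items() if k.endswith(".kv_scale")}

    # Attach the scales to the FP16 state dict and write a new checkpoint.
    fp16_weights.update(kv_scales)
    save_file(fp16_weights, out_path)

copy_kv_scales("fp8/model.safetensors", "fp16/model.safetensors", "fp16_with_scales/model.safetensors")
```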

Sample Models with KV scales: Models with FP8 KV Cache

Todos:

  • Documentation
  • Tests
  • Update VLLM for CUDA to support E5M2. @Narsil, could you help with this?
  • Currently only supports Llama; support for other models will follow in this or other PRs.
add kvcache fp8 support
8c437a80
Narsil commented on 2024-06-06
Narsil346 days ago

Thanks for this PR.

I think a lot has to be changed (to simplify it).

Also I don't see any core logic to actually handle the fp8, are the kernels ready?
Is it possible to test/add tests?

launcher/src/main.rs
    /// version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
    /// supported for common inference criteria.
    #[clap(default_value = "auto", long, env)]
    kv_cache_dtype: Option<String>,
Narsil346 days ago

Put a real enum.

enum KvDtype{
    Fp8(Path)
}

...
#[clap(long, env, value_enum)]
Option<KvDtype>

This should work much better. None is equivalent to auto (it just means the user hasn't specified anything; we can do whatever we want with it).

KvDtype will automatically be sanitized/error-checked (a String isn't, since all strings are accepted).

I tried putting Fp8(Path) directly in clap; I'm not sure it actually works with clap's internals, but this is what we want: if fp8 is chosen we need a path for the scales, and Path should also ensure the string is a valid path.

Maybe clap doesn't support algebraic enums, in which case we can't have Fp8(Path) and need plain Fp8 instead.
In that case you need to handle the validation early (there are other forms of validation in that layer, before pushing the args to the shard).

All CLI validation should happen here, as early as possible, with the best possible error messages.

launcher/src/main.rs
    /// greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
    /// supported for common inference criteria.
    #[clap(long, env)]
    quantization_param_path: Option<String>,
Narsil346 days ago

Why is an external file needed? Can't the user specify a value?
If this is linked per model, shouldn't the model config/repo contain that information?

If it's needed with kvcache=fp8, let's try to make sure it actually is. Ideally it's a single option for users; if that's not possible we need manual validation here (and validation can be skipped later).

Narsil346 days ago

(I'm trying to avoid adding too many flags; TGI already has too many, and since we don't break backwards compatibility, we never remove flags that were added. That's why, if we can read the information from some consistent config in the repo, it keeps the interface simpler for the user.)

mht-sharma328 days ago

Updated to load the scales from the checkpoint!

server/text_generation_server/layers/schema.py
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator


class KVCacheQuantSchema(BaseModel):
    dtype: str
    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
    # layer indices to their per-tensor KV cache scaling factor.
    # TODO: Consider pulling this and its validation methods out into its
    # own schema class (tricky as its members are variable)
    scaling_factor: Dict[int, Dict[int, float]]

    @model_validator(mode="after")
    def check_is_fp8(self) -> "KVCacheQuantSchema":
        assert self.dtype == "float8_e4m3fn", (
            "Loaded scaling factors intended for KV cache dtype = "
            f"{self.dtype} rather than float8_e4m3fn!"
        )
        return self
Narsil346 days ago

This should probably be done higher in the stack (ideally in launcher directly).

Rust is much more efficient at running these kinds of checks, but most importantly, errors should happen as early as possible (and the launcher has all the user flags too).

server/text_generation_server/layers/schema.py
        return self

    @model_validator(mode="after")
    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
Narsil346 days ago

Why are you not fusing this with the first validator?

server/text_generation_server/layers/schema.py
        if context:
            model_type = context.get("model_type", None)
            if model_type is not None:
                assert model_type == self.model_type, (
Narsil346 days ago

I'm wondering what kind of bug this is supposed to prevent.

If the scalings are contained directly within a single repo (so not user supplied), the validity should be obvious (no need for extra keys).

If it is user supplied, well, it is unlikely to contain a model_type, no?

server/text_generation_server/models/__init__.py
        else:
            raise RuntimeError(f"Unknown dtype {dtype}")

        if kv_cache_dtype not in {"auto", "fp8"}:
Narsil346 days ago

Optional[str] without validation is acceptable (since the launcher is responsible for validation already).

server/text_generation_server/models/__init__.py
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            kv_cache_dtype=kv_cache_dtype,
Narsil346 days ago

Can we avoid passing 2 objects all the time? They seem highly interlinked, so we could probably fuse them into a single object, no?

mht-sharma328 days ago

Done

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
        # quantized_value * scaling_factor ~= true_value
        # which is consistent with the practice of setting
        # scaling_factor = tensor_amax / FPtype_max
        self.kv_scale = 1.0
        self.kv_cache_dtype = "auto"
Narsil346 days ago

Huge no-no.

We really don't like mutating model self-state (it makes reasoning about who modified your values, or not, much harder).

You need to send the scales directly at load time (possibly even transparently through the config or weights).

Narsil346 days ago (edited 346 days ago)
self.kv_scale = config.get("kv_scale", 1.0)
Narsil346 days ago

We do the same for quantize and speculator for instance.

Codewise it touches a lot of places, but these are about to be factored out.

mht-sharma328 days ago

Done, the params are now created at load time.
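For context, creating the scale at load time (rather than mutating it afterwards) could look roughly like the sketch below; the `weights.get_tensor` helper and the `prefix` argument are assumptions modelled on TGI's weight loader, and the 1.0 fallback mirrors the PR description.

```python
# Hypothetical sketch: read the per-layer KV scale when the attention module is
# built, instead of patching self.kv_scale afterwards.
def load_kv_scale(weights, prefix: str) -> float:
    # prefix is e.g. "model.layers.0.self_attn"; the checkpoint stores an F32 scalar.
    try:
        return weights.get_tensor(f"{prefix}.kv_scale").item()
    except Exception:
        # No scale in the checkpoint: fall back to 1.0 (may reduce accuracy).
        return 1.0
```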

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
    def __init__(self, prefix, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
Narsil346 days ago

Remove all these, state should already be correct.

mht-sharma328 days ago

Done

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
        layer_self_attn = self.model.layers[layer_idx].self_attn

        if SYSTEM == "rocm":
            # The scaling factor convention we are assuming is
            # quantized_value * scaling_factor ~= true_value
            # which is consistent with the practice of setting
            # scaling_factor = tensor_amax / FPtype_max
Narsil346 days ago

And why does that imply that ROCm needs a factor of 2?

If the reason is that global, shouldn't it be handled directly at load time?

fxmarty338 days ago

@Narsil I believe it is related to https://onnx.ai/onnx/technical/float8.html and the difference between e4m3fn and e4m3fnuz (different exponent bias). Is that so @mht-sharma?

But shouldn't it be based on this param: https://github.com/vllm-project/vllm/blob/319ad7f1d386699e94f629341c9988a926821f24/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json#L4? Also, we have to decide whether or not we want to follow vllm's scheme of storing quantization params in a JSON, to keep compatibility with e.g. models on the Hub.

mht-sharma338 days ago (edited 338 days ago)

Comment from AMD in vllm:

We do *2 only for HIP, to deal with the difference in numeric from our chip. after *2 overall effect is identical as without it on NV.

I will add this as a comment in the code.

fxmarty338 days ago

Yes, this is the difference between the e4m3fn and e4m3fnuz formats. If using a JSON like the one linked, there should be a check on the dtype.
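For reference, the numeric gap between the two formats is visible from their maximum representable values; the following small check assumes a PyTorch build that exposes both fp8 dtypes.

```python
import torch

# e4m3fn (Nvidia) and e4m3fnuz (AMD Instinct) differ by one unit of exponent
# bias, so the fnuz variant covers roughly half the range.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0

# A scale computed against the e4m3fn range therefore needs roughly a 2x bump
# when the cache is stored as e4m3fnuz on ROCm.
```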

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
            # scaling_factor = tensor_amax / FPtype_max
            scaling_factor *= 2

        if hasattr(layer_self_attn, "kv_scale"):
            layer_self_attn.kv_scale = scaling_factor
            logger.info(
                f"Loaded KV cache scaling factor for layer {layer_idx}: {scaling_factor}"
            )
        else:
            raise RuntimeError(
                "Self attention has no KV cache scaling factor attribute!"
            )
Narsil346 days ago

Not needed if done at load time.

server/text_generation_server/models/flash_causal_lm.py
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        kv_cache_dtype: str = "auto",
        quantization_param_path: Optional[str] = None,
Narsil346 days ago

Fused?

server/text_generation_server/models/flash_causal_lm.py
        logger.info(f"Using KV cache data type: {kv_cache_dtype}")
        # Currently scaled KV cache is only enabled on ROCm
        if quantization_param_path is not None:
            if callable(getattr(self.model, "load_kv_cache_scales", None)):
                self.model.load_kv_cache_scales(quantization_param_path)
Narsil346 days ago

If done correctly, you don't need any of that because you handle this at load time (and every model will need to be updated to support the new values)

server/text_generation_server/models/flash_causal_lm.py
                "provided. Defaulting to scaling factors of 1.0. "
                "This may lead to less accurate results!"
            )
        elif quantization_param_path is not None:
Narsil346 days ago

This should be handled all the way back in the launcher (again ideally directly by clap, if not possible manually in the rust part).

And it should be a hard error, not a soft one (if parameters sent by the user don't make sense, we never silently ignore them).

server/text_generation_server/models/flash_llama.py
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        kv_cache_dtype: Optional[str] = "auto",
        quantization_param_path: Optional[str] = None,
Narsil346 days ago

Fused?

server/text_generation_server/utils/weights_utils.py
        layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
        return layer_scales_map.items()

    except FileNotFoundError:
        logger.error(f"File or directory '{filename}' not found.")
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON in file '{filename}'.")
    except Exception as e:
        logger.error(f"An error occurred while reading '{filename}': {e}")
    # This section is reached if and only if any of the excepts are hit
    # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
    logger.warning(
        "Defaulting to KV cache scaling factors = 1.0 "
        f"for all layers in TP rank {tp_rank} "
        "as an error occurred during loading."
    )
Narsil346 days ago

Everything here should be a hard error (I think the standard ones would do fine).

If a user sends invalid information, we shouldn't silently ignore. They should fix it.

Narsil commented on 2024-06-06
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
            kv_cache[0],
            kv_cache[1],
            slots,
            self.kv_cache_dtype,
Narsil346 days ago

This one isn't necessary here, it should already be inferrable from kv_cache[0].

Narsil346 days ago

Happy to help with the rebase btw.

fxmarty commented on 2024-06-14
fxmarty338 days ago

nothing shocking to me for benchmarking!

server/text_generation_server/layers/schema.py
technique. The format of this JSON should be specified by one or more
schemas contained here.

For example, when the KV cache is quantized to FP8-E4M3 (currently only
fxmarty338 days ago

to be precise, on Instinct I think it is E4M3FN that is used

server/text_generation_server/layers/schema.py
from typing import Dict, Optional

from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator


class KVCacheQuantSchema(BaseModel):
fxmarty338 days ago

I would use built-in dataclasses and __post_init__ here, but it is a matter of taste.

mht-sharma338 days ago

Thanks for the review @Narsil @fxmarty, I will rebase and address the comments.

Regarding the format for loading the FP8 scales:

VLLM offers two methods:

  • quantization-param-path: This uses a JSON file (kv_cache_scales.json) containing per-tensor scaling factors for each layer. Example can be found here. This file is generated using the Nvidia AMMO quantizer available here.

  • Direct loading from checkpoints: This method has been introduced in one of the recent PRs and is located here.

VLLM intends to deprecate the quantization-param-path method soon, favoring the use of checkpoints for loading scales. Therefore, I would update our approach to also load scales using checkpoints.

mht-sharma Merge branch 'main' into fp8_kvcache
084de990
mht-sharma rebase and update
81fd601c
mht-sharma fix
fb83e341
mht-sharma fixrs
f0d95b0f
mht-sharma add docs
8a0bb53e
mht-sharma328 days ago

Thanks for the review @Narsil @fxmarty, I will rebase and address the comments.

Regarding the format for loading the FP8 scales:

VLLM offers two methods:

  • quantization-param-path: This uses a JSON file (kv_cache_scales.json) containing per-tensor scaling factors for each layer. Example can be found here. This file is generated using the Nvidia AMMO quantizer available here.
  • Direct loading from checkpoints: This method has been introduced in one of the recent PRs and is located here.

VLLM intends to deprecate the quantization-param-path method soon, favoring the use of checkpoints for loading scales. Therefore, I would update our approach to also load scales using checkpoints.

Removed the quantization-param-path altogether: This method is already deprecated in VLLM, based on discussions here: vllm-project/vllm#4532

mht-sharma changed the title from "[WIP] Add kvcache fp8 support" to "Add kvcache fp8 support" 328 days ago
mht-sharma marked this pull request as ready for review 328 days ago
mht-sharma fix style
557e18e0
mht-sharma changed the title from "Add kvcache fp8 support" to "Add FP8 KVCache support" 328 days ago
mht-sharma update port
50806ffe
mht-sharma rename doc
001ec09d
HuggingFaceDocBuilderDev328 days ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

mht-sharma update doc
3cc2f4e9
mht-sharma update launcher
e81c4cf8
mht-sharma328 days ago

Also I don't see any core logic to actually handle the fp8, are the kernels ready? Is it possible to test/add tests?

The core logic for handling FP8 is managed by the Paged Attention kernel in VLLM, which comes with the necessary kernel tests. If you have any specific tests in mind, we can discuss them. VLLM includes tests that compare the output with the expected FP8 output, as seen in https://github.com/comaniac/vllm/blob/main/tests/models/test_fp8.py. We can add a similar test if required.
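If an end-to-end check is wanted, one rough option (not TGI's existing test harness; endpoints, ports, and prompts are placeholders) is to compare greedy outputs from a server launched with --kv-cache-dtype fp8 against a baseline server:

```python
# Hypothetical sketch: compare greedy generations from two running TGI servers,
# one started with `--kv-cache-dtype fp8` and one with the default KV cache.
from text_generation import Client

PROMPTS = ["The capital of France is", "Deep learning is"]

baseline = Client("http://localhost:8080")  # server without FP8 KV cache
fp8 = Client("http://localhost:8081")       # server with --kv-cache-dtype fp8

for prompt in PROMPTS:
    ref = baseline.generate(prompt, max_new_tokens=32, do_sample=False).generated_text
    out = fp8.generate(prompt, max_new_tokens=32, do_sample=False).generated_text
    # Exact equality is not guaranteed with a quantized cache; report differences.
    print(prompt, "->", "match" if ref == out else "differs")
```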

mht-sharma update heading
034686b1
mht-sharma revert makefile
3ae62304
mht-sharma requested a review from fxmarty 328 days ago
mht-sharma requested a review from Narsil 328 days ago
fxmarty commented on 2024-06-24
fxmarty328 days ago

Great work, that's cool! I left some general comments, @Narsil may have more about the design

docs/source/basic_tutorials/fp8_kv_cache.md
# Accelerating Inference with FP8 KV Cache

Text Generation Inference (TGI) now supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks
fxmarty328 days ago
Suggested change
Text Generation Inference (TGI) now supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks
Text Generation Inference (TGI) supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks.

It would be worth explaining what FP8 KV cache is. Readers may not be familiar with it (does it mean the attention computation is in fp8? etc.)

enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation.

This is kind of vague.

we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks

It would be worth IMO showing numbers / a chart here to get a grasp of what "greatly" means, in which cases there is indeed a speedup, etc.

mht-sharma320 days ago

We have a couple of options for numbers and charts:

  • Max Batch Total Tokens: The logs display the max batch total tokens, which increases when using the FP8 KV cache. We could create a chart showing the max batch total tokens in both cases (with and without the FP8 KV cache).

  • Throughput: Currently, I have created a custom script using AsyncClient to send 500 requests simultaneously with asyncio.gather (a rough sketch of this approach is shown after this list). This provides a rough estimate of throughput. @Narsil, do you have any suggestions on calculating throughput more precisely?
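The throughput measurement described above could look roughly like this sketch, assuming the text_generation Python client; the endpoint, prompt, and request count are placeholders.

```python
# Hypothetical sketch: fire N concurrent requests at a running TGI server and
# derive a rough generated-tokens-per-second figure.
import asyncio
import time

from text_generation import AsyncClient

async def main(n_requests: int = 500) -> None:
    client = AsyncClient("http://localhost:8080")
    prompt = "What is Deep Learning?"

    start = time.perf_counter()
    responses = await asyncio.gather(
        *[client.generate(prompt, max_new_tokens=128) for _ in range(n_requests)]
    )
    elapsed = time.perf_counter() - start

    generated = sum(r.details.generated_tokens for r in responses)
    print(f"{generated} tokens in {elapsed:.1f}s -> {generated / elapsed:.1f} tok/s")

asyncio.run(main())
```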

docs/source/basic_tutorials/fp8_kv_cache.md
## Current Hardware Support

* Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2.
* AMD GPUs: Supports FP8E4M3.
fxmarty328 days ago

Technically, Instinct supports E4M3FNUZ, not E4M3. https://onnx.ai/onnx/technical/float8.html

docs/source/basic_tutorials/fp8_kv_cache.md
## FP8 E5M2 KV Cache
Example usage:
```
text-generation-launcher --model-id <> --kv-cache-dtype fp8_e5m2
fxmarty328 days ago

Let's maybe put a full runnable command with docker run etc? Good inspiration could be https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one?install=NVIDIA#flashattention-2

docs/source/basic_tutorials/fp8_kv_cache.md

Example usage:
```
text-generation-launcher --model-id <> --kv-cache-dtype fp8
fxmarty328 days ago

same + maybe have a model_id which indeed has scaling factors here.

docs/source/basic_tutorials/fp8_kv_cache.md
```

### Checkpoint structure for KV scales
The FP8 kv cache scaling factors required in the FP16 checkpoints are specified through the .kv_scale parameter present on the `Attention` module, such as:

```
model.layers.0.self_attn.kv_scale < F32
model.layers.1.self_attn.kv_scale < F32
...
```
fxmarty328 days ago

How do we know whether they are for E4M3FNUZ or E4M3FN format?

mht-sharma328 days ago

I guess the current tools, Nvidia AMMO and AutoFP8, use E4M3FN. Currently there is no flag to determine the format (unless of course the weights are quantized). But I could add a check and a parameter for this in the checkpoint.

mht-sharma327 days ago

Updated to read this from the config. See: a7909e6

docs/source/basic_tutorials/fp8_kv_cache.md

Use [AutoFP8](https://github.com/neuralmagic/AutoFP8) with calibration data to generate per-tensor scales for FP8 quantized KV Cache. For more details, see the following example: https://github.com/neuralmagic/AutoFP8/blob/main/example_dataset.py

TGI provides a utility to extract the FP8 KV cache scales from an `AutoFP8` quantized model and save them to the FP16 model for use with TGI. For more information: <path to script>
fxmarty328 days ago

Todo

docs/source/basic_tutorials/launcher.md
    [env: DTYPE=]
    [possible values: float16, bfloat16]

```
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
    [env: KV_CACHE_DTYPE=]
    [possible values: fp8, fp8_e5m2]
fxmarty328 days ago👍 1

Needs a description

server/text_generation_server/models/flash_llama.py
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
-           dtype=dtype,
+           dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
fxmarty328 days ago

Things are a bit harder to read if the dtype attribute is used to mean the KV cache storage PyTorch dtype. For some other models (gemma, idefics), self.dtype is used with another meaning, being the weights dtype.
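For what it's worth, the uint8 here reflects that the FP8 cache is allocated as raw bytes and reinterpreted by the attention kernel; a small illustration of that allocation follows (shapes and names are made up for the sketch).

```python
import torch

# Hypothetical sketch: the KV cache keeps its logical shape, but the element
# type becomes a plain byte when the cache holds fp8 values.
num_blocks, block_size, num_kv_heads, head_size = 1024, 16, 8, 128
kv_cache_dtype = "fp8"

storage_dtype = torch.uint8 if "fp8" in kv_cache_dtype else torch.float16
key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size, dtype=storage_dtype)
value_cache = torch.empty_like(key_cache)
```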

server/text_generation_server/utils/weights.py
        "Only support per-tensor scaling factor for `fp8 (fp8_e4m3)` KV cache"
    )

    # ROCm uses FP8 format with fp8_e4m3fn, whereas Nvidia GPUs use fp8_e4m3.
    # The multiplication by 2 compensates for the different numeric representation
    # between ROCm and Nvidia GPUs, ensuring consistent effective scaling across platforms.
    # After this adjustment, the overall effect is equivalent to the scaling applied without
    # it on Nvidia GPUs.
    if SYSTEM == "rocm":
        kv_scale *= 2.0
fxmarty328 days ago

How do we know whether serialized scales are in E4M3FN or E4M3FNUZ format? I think depending on that, the logic should be different here.

mht-sharma327 days ago (edited 327 days ago)

Updated the code to read the corresponding dtype as kv_cache_torch_dtype from config.

Added the format to the README.md and updated the utility script to add kv_cache_torch_dtype when quantizing the model.

See: a7909e6
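A rough sketch of what that config-driven check might look like; the field and function names mirror the PR description rather than the final code.

```python
# Hypothetical sketch: decide whether a serialized scale needs the ROCm
# adjustment based on the kv_cache_torch_dtype recorded in the model config.
SUPPORTED_KV_DTYPES = {"float8_e4m3fn", "float8_e4m3fnuz"}

def adjust_kv_scale(kv_scale: float, kv_cache_torch_dtype: str, system: str) -> float:
    if kv_cache_torch_dtype not in SUPPORTED_KV_DTYPES:
        raise ValueError(
            f"Unsupported kv_cache_torch_dtype '{kv_cache_torch_dtype}', "
            f"expected one of {sorted(SUPPORTED_KV_DTYPES)}"
        )
    # Scales produced for e4m3fn need roughly a 2x bump when the cache is
    # stored as e4m3fnuz on ROCm (different exponent bias).
    if system == "rocm" and kv_cache_torch_dtype == "float8_e4m3fn":
        kv_scale *= 2.0
    return kv_scale
```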

examples/fp8_kvcache/README.md
    Path to save the FP16 model with the kv scales
```

## Example usage
To extract KV cache scaling factors from a quantized FP8 model and save them to an unquantized FP16 model, use the following command:

```
python extract_fp8_kv_scales.py --quantized-model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV --model meta-llama/Meta-Llama-3-8B-Instruct --save-path Meta-Llama-3-8B-Instruct
```
fxmarty328 days ago

If we target vllm/tgi intercompatibility, why couldn't we directly load e.g. neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV in TGI? Is it not loadable in vllm?

Also, by quantized FP8 model, do you mean a model whose weights are quantized to fp8? How does it relate to the FP8 KV cache? To me, to obtain the KV cache scales you would simply need calibration data passing through the network & collecting stats on the KV cache.

It feels like the KV cache scales from a quantized model (like neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV), whose KV cache scales may have been obtained after the weights quantization (?), may have an unnecessary bias due to being inferred from calibration on the quantized model, not the unquantized one.

mht-sharma327 days ago (edited 327 days ago)👍 1

The model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV has both quantized weights and KV scales. However, I don't think we can load the FP8 weights in TGI yet. Also, one would run an FP16 model with an FP8 KV cache, so we need such a checkpoint.

You are right; it may have an additional bias due to calibration after weight quantization. The bias might have been low, so I couldn't find a noticeable difference during inference.

I have tested two other models generated using scales from Nvidia AMMO (VLLM uses this also), which may not have this bias. The quantizer can accept "full precision" as the quantization format and still provide the model with KV scales, and we can extract the scales from that.

Here is the link to the models: Llama-2-70b-chat-hf-FP8-KV

I will add these models and provide an example for this.

mht-sharma327 days ago

Updated the example with Nvidia AMMO

examples/fp8_kvcache/README.md
# FP8 (fp8_e4m3) KV Cache Scaling Factor Extraction Utility

This utility is designed to extract KV cache scaling factors from a quantized `FP8(fp8_e4m3)` Hugging Face (HF) model. The extracted scaling factors are then saved to the corresponding unquantized HF model, which can be used with Text Generation Inference (TGI).
fxmarty328 days ago

Do you have an example of such a model?

mht-sharma327 days ago

Updated and added a model to readme.

docs/source/basic_tutorials/fp8_kv_cache.md
# Accelerating Inference with FP8 KV Cache
fxmarty328 days ago

I think it would be worth having some evaluation metrics for an example model.

mht-sharma remove example
f4714a8f
mht-sharma add torch dtype
a7909e6f
mht-sharma add AMMO example
1e6e7db0
mht-sharma updated doc
15b351b4
mht-sharma update launcher
5e38d353
mht-sharma updated doc
bf4db771
mht-sharma updated docs
f34560f7
mht-sharma fix formatting
6d6b0bdc
mht-sharma updated doc
0a5b19a3
danieldk assigned danieldk 297 days ago
yao-matrix commented on 2024-07-29
server/text_generation_server/models/__init__.py

FP8_KVCACHE_SUPPORTED_MODELS = {
    "llama",
    "baichun",
yao-matrix293 days ago

is this a typo of "baichuan"?

mht-sharma293 days ago

Yes. Thanks for pointing it out

server/text_generation_server/utils/weights.py
        f"used for generating kv scales. Expected 'float8_e4m3fn' or 'float8_e4m3fnuz', but got '{kv_cache_torch_dtype}'."
    )

    # ROCm uses FP8 format with fp8_e4m3fn, whereas Nvidia GPUs use fp8_e4m3.
yao-matrix293 days ago

Should it be "ROCm uses FP8 format with fp8_e4m3fnuz, whereas NVIDIA GPUs use fp8_e4m3fn"?

mht-sharma293 days ago

Yes that's correct

Narsil222 days ago

Closing this, as FP8 KV cache support was added in #2603.

More support is coming (for pre-scaled FP8 KV cache).

Narsil closed this 222 days ago
