Enable U8 KV caching in SDPA operator for ARM (#33567)
[About]
This PR enables u8 KV cache precision for the SDPA operator and optimizes it with NEON and SVE kernels.
- Improves performance by 27% over the OSS master version, where only the reference implementation is available.
- However, for smaller models such as TinyLlama-1.1B in the single-inference case, we are 2.7% slower than the non-quantized f16 cache precision, due to the added overhead of quantizing and dequantizing the cache (see the sketch after this list).
- The benefit of u8 quantization shows up only when inference is memory bound: u8 halves the KV cache footprint relative to f16, so the bandwidth savings must outweigh the quantization overhead. We see speedups of around 3-5% when running single inference on an int8-quantized Llama-70B model.
- Therefore, even though we achieve a 27% speedup over the reference implementation, we assume the general case to be compute bound and keep the default at f16 for now.
- As models get larger and in multi-batch scenarios, setting the KV cache precision to "u8" yields a significant boost at the inference level.
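For reference, below is a minimal sketch of the per-tensor affine u8 quantize/dequantize round trip that a u8 KV cache implies. This is plain PyTorch for illustration only; the function names are hypothetical, and this is not the NEON/SVE kernel code this PR adds.

```python
import torch

def quantize_u8(x: torch.Tensor):
    # Per-tensor affine quantization: map the float range [min, max] onto [0, 255].
    # (Illustrative sketch, not the kernel implemented in this PR.)
    xmin, xmax = x.min(), x.max()
    scale = (xmax - xmin).clamp(min=1e-8) / 255.0
    zero_point = (-xmin / scale).round().clamp(0, 255)
    q = (x / scale + zero_point).round().clamp(0, 255).to(torch.uint8)
    return q, scale, zero_point

def dequantize_u8(q: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor):
    # Inverse transform back to float before the attention matmuls.
    return (q.to(torch.float32) - zero_point) * scale

# Example: a KV-cache slice of shape [batch, heads, seq_len, head_dim].
kv = torch.randn(1, 32, 128, 64)
q, s, zp = quantize_u8(kv)        # stored at 1 byte/element instead of 2 for f16
kv_hat = dequantize_u8(q, s, zp)
print((kv - kv_hat).abs().max())  # small round-trip error
```

This round trip is the overhead behind the 2.7% regression on small models; the payoff is that the stored cache occupies half the bytes of f16, which is why the gain appears once cache traffic dominates.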
| OSS reference impl (u8) | This PR (u8) |
|:----------:|:----------:|
| 10.8 | 13.7 |

Single-inference throughput in tokens per second (TPS) on the Llama-2-7B model on a 32-core Graviton machine.
This work was contributed by @ashwins990 and @abhijain1204fujitsu.