llama.cpp
CUDA: add FP32 FlashAttention vector kernel
#7188
Merged

JohannesGaessler · 1 year ago

This PR adds an FP32 FlashAttention kernel that is very similar to the FP16 kernel. It enables using FlashAttention on NVIDIA GPUs without fast FP16 and without tensor cores. It should also provide a speedup on more recent NVIDIA GPUs for batch size 1 and FP32 precision. I have moved the FP16 and FP32 FlashAttention vector kernels to separate files in order to speed up compilation. I also added a function ggml_backend_cuda_get_device_cc to ggml-cuda.h in order to avoid breaking tests/test-backend-ops on NVIDIA GPUs without tensor cores. Unlike with the FP16 kernel there are no weird issues with arrays of size 1 vs. regular variables.
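
As a rough illustration of the decision this PR enables (a sketch, not the PR's actual code; the wrapper names and dispatch site are assumptions), the host-side choice between the two vector kernels comes down to the device's compute capability:

```cpp
// Sketch only: pick the FP32 vector kernel on GPUs without tensor cores and
// without fast FP16 (e.g. the P40, compute capability 6.1), otherwise keep
// the FP16 path. flash_attn_vec_f32/f16 are hypothetical wrappers, not
// functions from this PR; the real dispatch also honors the precision
// requested for the op.
static void launch_flash_attn_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int cc = ggml_cuda_info().devices[ctx.device].cc;

    if (cc < CC_VOLTA) {
        flash_attn_vec_f32(ctx, dst); // new FP32 vector kernel path
    } else {
        flash_attn_vec_f16(ctx, dst); // existing FP16 vector kernel path
    }
}
```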

Performance on 1x P40:

| model          | backend | ngl | n_batch | fa | test    |           t/s |
| -------------- | ------- | --: | ------: | -: | ------- | ------------: |
| llama 8B Q4_0  | CUDA    |  99 |       1 |  0 | pp 4096 |  40.17 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |       1 |  1 | pp 4096 |  52.69 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |       2 |  0 | pp 4096 |  42.57 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |       2 |  1 | pp 4096 |  97.77 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |       4 |  0 | pp 4096 |  71.13 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |       4 |  1 | pp 4096 | 117.34 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |       8 |  0 | pp 4096 |  83.96 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |       8 |  1 | pp 4096 | 143.52 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |      16 |  0 | pp 4096 | 110.77 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |      16 |  1 | pp 4096 | 125.14 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |      32 |  0 | pp 4096 | 210.72 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |      32 |  1 | pp 4096 | 209.30 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |      64 |  0 | pp 4096 | 387.33 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |      64 |  1 | pp 4096 | 315.04 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |     128 |  0 | pp 4096 | 532.72 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |     128 |  1 | pp 4096 | 357.28 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |     256 |  0 | pp 4096 | 664.68 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |     256 |  1 | pp 4096 | 374.81 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |     512 |  0 | pp 4096 | 748.74 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |     512 |  1 | pp 4096 | 375.69 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |    1024 |  0 | pp 4096 | 749.28 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |    1024 |  1 | pp 4096 | 375.78 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |    2048 |  0 | pp 4096 | 749.41 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |    2048 |  1 | pp 4096 | 375.89 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |    4096 |  0 | pp 4096 | 749.52 ± 0.00 |
| llama 8B Q4_0  | CUDA    |  99 |    4096 |  1 | pp 4096 | 375.90 ± 0.00 |
JohannesGaessler · 1 year ago

Fixes #7055.

JohannesGaessler added performance
JohannesGaessler added Nvidia GPU
JohannesGaessler added Review Complexity : High
slaren · 1 year ago

This happens regularly, but it's never going to be ok to add backend-specific functions to test-backend-ops. Instead, add the necessary checks to the supports_op function in ggml-cuda.

JohannesGaessler force-pushed from 29e01c3b to de85f908 1 year ago
sorasoras · 1 year ago (edited)
  Device 0: AMD Radeon RX 7900 XTX, compute capability 11.0, VMM: no
| model                          |       size |     params | backend    | ngl | sm         |         fa | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------: | ---------- | ---------------: |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          1 | pp 512     |   687.77 ± 12.71 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          1 | tg 128     |     34.70 ± 0.30 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          0 | pp 512     |    767.92 ± 1.57 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          0 | tg 128     |     34.36 ± 0.14 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | ROCm       |  99 | none       |          1 | pp 512     |   1511.97 ± 9.18 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | ROCm       |  99 | none       |          1 | tg 128     |     57.05 ± 0.02 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | ROCm       |  99 | none       |          0 | pp 512     |   1773.69 ± 5.63 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | ROCm       |  99 | none       |          0 | tg 128     |     56.31 ± 0.72 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          1 | pp 1024    |    650.39 ± 8.36 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          1 | pp 2048    |    574.70 ± 3.05 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          1 | pp 4096    |    465.18 ± 3.77 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          1 | tg 128     |     35.06 ± 0.06 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          0 | pp 1024    |   760.63 ± 11.81 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          0 | pp 2048    |    726.50 ± 7.10 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          0 | pp 4096    |    669.02 ± 2.67 |
| qwen2 ?B IQ4_XS - 4.25 bpw     |  16.51 GiB |    32.51 B | ROCm       |  99 | none       |          0 | tg 128     |     33.90 ± 0.30 |
build: de85f908 (2834)

  Device 0: Tesla P40, compute capability 6.1, VMM: yes
| model                          |       size |     params | backend    | ngl |         fa | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | ---------- | ---------------: |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          1 | pp 512     |    180.33 ± 0.31 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          1 | tg 128     |     11.35 ± 0.01 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          0 | pp 512     |    201.03 ± 0.31 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          0 | tg 128     |      9.31 ± 0.02 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | CUDA       |  99 |          1 | pp 512     |    376.66 ± 0.05 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | CUDA       |  99 |          1 | tg 128     |     22.66 ± 0.03 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | CUDA       |  99 |          0 | pp 512     |    435.80 ± 0.16 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | CUDA       |  99 |          0 | tg 128     |     17.88 ± 0.02 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          1 | pp 1024    |    166.17 ± 0.07 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          1 | pp 2048    |    143.93 ± 0.05 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          1 | pp 4096    |    113.58 ± 0.09 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          1 | tg 128     |     11.30 ± 0.00 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          0 | pp 1024    |    196.44 ± 0.25 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          0 | pp 2048    |    189.21 ± 0.23 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          0 | pp 4096    |    177.17 ± 0.21 |
| qwen2 ?B Q4_K - Medium         |  18.34 GiB |    32.51 B | CUDA       |  99 |          0 | tg 128     |      9.31 ± 0.01 |
build: de85f908 (2834)

The TG speedup is significant, but PP is quite a bit slower; I don't know why.

JohannesGaessler · 1 year ago

There simply isn't yet a kernel optimized for large batch sizes.

slaren commented on 2024-05-10
ggml-cuda.cu
    if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
        return true;
    }
    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
        if (ggml_cuda_info().devices[id].cc < CC_VOLTA) {
            return false;
        }
    }
slaren · 1 year ago

I don't think it is necessary to check every device here, instead get the context and check only the device for this context. Something like this:

    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
    if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_VOLTA) {
        return false;
    }
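
For reference, a plausible shape of the resulting `GGML_OP_FLASH_ATTN_EXT` check once this suggestion is applied (a sketch combining the reviewed hunk with the per-context device lookup, not necessarily the merged code):

```cpp
case GGML_OP_FLASH_ATTN_EXT: {
    // head sizes 64 and 128 are handled by the vector kernels on any GPU
    if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
        return true;
    }
    // other head sizes need the tensor core kernels, so only the device
    // belonging to this backend context has to be checked
    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
    return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
}
```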
JohannesGaessler force-pushed from de85f908 to e0d11842 1 year ago
scottmudge · 1 year ago

Just adding a small data point: with KoboldCPP compiled with this change, running a Q8_K 11B model on a 2x 1080 Ti (Pascal) setup, I get:

  • ~20.2 T/s avg (proc + gen) with FP32 FA enabled.
  • ~13.4 T/s avg (proc + gen) with FP32 FA disabled.

So this is a significant improvement in my case, whereas with FP16 FA I saw a decrease. It definitely has utility for a subset of users.

github-actions · 1 year ago (edited)

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 541 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8664.65ms p(95)=20209.22ms fails=, finish reason: stop=481 truncated=60
  • Prompt processing (pp): avg=106.85tk/s p(95)=491.78tk/s
  • Token generation (tg): avg=32.65tk/s p(95)=46.81tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=cuda-fa-no-tc-11 commit=aa9cbd76608e8aacf5e02e9568d935e9c4e9fbfe

Benchmark charts for this run (llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 541 iterations); raw chart data omitted:

  • prompt_tokens_seconds
  • predicted_tokens_seconds
  • kv_cache_usage_ratio
  • requests_processing

JohannesGaessler CUDA: add FP32 FlashAttention vector kernel
bbeb952a
JohannesGaessler fixup! CUDA: add FP32 FlashAttention vector kernel
41f5f3a4
JohannesGaessler force-pushed from e0d11842 to 41f5f3a4 1 year ago
JohannesGaessler fixup! fixup! CUDA: add FP32 FlashAttention vector kernel
f3c3eafa
JohannesGaessler fixup! fixup! fixup! CUDA: add FP32 FlashAttention vector kernel
aa9cbd76
JohannesGaessler · 1 year ago

I don't have any ALiBi models set up for testing, but according to tests/test-backend-ops the implementation works correctly.
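
For context, ALiBi models add a per-head linear bias to the attention scores instead of relying only on the additive mask; in ggml-style code the slope for head `h` is typically derived from the model's `max_bias` roughly as in the sketch below (illustrative only, not this PR's kernel code):

```cpp
#include <cmath>
#include <cstdint>

// Illustrative: derive an ALiBi slope per attention head from max_bias,
// following the scheme commonly used in ggml-style implementations.
static float alibi_slope(float max_bias, uint32_t n_head, uint32_t h) {
    if (max_bias <= 0.0f) {
        return 1.0f; // no ALiBi: the bias reduces to the regular mask
    }
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias)        / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
}
```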

slaren
slaren approved these changes on 2024-05-12
JohannesGaessler merged dc685be4 into master 1 year ago
VinnyG9 · 291 days ago

Hi, I get an error when trying to run with -fa on my P100. Is support dropped?

JohannesGaessler · 291 days ago

Pascal is still supported; please open an issue.
