Fixes #7055.

This happens regularly, but it's never going to be OK to add backend-specific functions to `test-backend-ops`. Instead, add the necessary checks to the `supports_op` function in ggml-cuda.
Device 0: AMD Radeon RX 7900 XTX, compute capability 11.0, VMM: no
| model | size | params | backend | ngl | sm | fa | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------: | ---------- | ---------------: |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 1 | pp 512 | 687.77 ± 12.71 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 1 | tg 128 | 34.70 ± 0.30 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 0 | pp 512 | 767.92 ± 1.57 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 0 | tg 128 | 34.36 ± 0.14 |
| qwen2 13B Q5_K - Small | 9.33 GiB | 14.17 B | ROCm | 99 | none | 1 | pp 512 | 1511.97 ± 9.18 |
| qwen2 13B Q5_K - Small | 9.33 GiB | 14.17 B | ROCm | 99 | none | 1 | tg 128 | 57.05 ± 0.02 |
| qwen2 13B Q5_K - Small | 9.33 GiB | 14.17 B | ROCm | 99 | none | 0 | pp 512 | 1773.69 ± 5.63 |
| qwen2 13B Q5_K - Small | 9.33 GiB | 14.17 B | ROCm | 99 | none | 0 | tg 128 | 56.31 ± 0.72 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 1 | pp 1024 | 650.39 ± 8.36 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 1 | pp 2048 | 574.70 ± 3.05 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 1 | pp 4096 | 465.18 ± 3.77 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 1 | tg 128 | 35.06 ± 0.06 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 0 | pp 1024 | 760.63 ± 11.81 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 0 | pp 2048 | 726.50 ± 7.10 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 0 | pp 4096 | 669.02 ± 2.67 |
| qwen2 ?B IQ4_XS - 4.25 bpw | 16.51 GiB | 32.51 B | ROCm | 99 | none | 0 | tg 128 | 33.90 ± 0.30 |
build: de85f908 (2834)
Device 0: Tesla P40, compute capability 6.1, VMM: yes
| model | size | params | backend | ngl | fa | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | ---------- | ---------------: |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 1 | pp 512 | 180.33 ± 0.31 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 1 | tg 128 | 11.35 ± 0.01 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 0 | pp 512 | 201.03 ± 0.31 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 0 | tg 128 | 9.31 ± 0.02 |
| qwen2 13B Q5_K - Small | 9.33 GiB | 14.17 B | CUDA | 99 | 1 | pp 512 | 376.66 ± 0.05 |
| qwen2 13B Q5_K - Small | 9.33 GiB | 14.17 B | CUDA | 99 | 1 | tg 128 | 22.66 ± 0.03 |
| qwen2 13B Q5_K - Small | 9.33 GiB | 14.17 B | CUDA | 99 | 0 | pp 512 | 435.80 ± 0.16 |
| qwen2 13B Q5_K - Small | 9.33 GiB | 14.17 B | CUDA | 99 | 0 | tg 128 | 17.88 ± 0.02 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 1 | pp 1024 | 166.17 ± 0.07 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 1 | pp 2048 | 143.93 ± 0.05 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 1 | pp 4096 | 113.58 ± 0.09 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 1 | tg 128 | 11.30 ± 0.00 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 0 | pp 1024 | 196.44 ± 0.25 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 0 | pp 2048 | 189.21 ± 0.23 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 0 | pp 4096 | 177.17 ± 0.21 |
| qwen2 ?B Q4_K - Medium | 18.34 GiB | 32.51 B | CUDA | 99 | 0 | tg 128 | 9.31 ± 0.01 |
build: de85f908 (2834)
The TG speedup is significant, but PP is quite a bit slower; I don't know why.
There simply isn't yet a kernel optimized for large batch sizes.
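To make the trade-off concrete, here is a minimal sketch of the dispatch situation being described; all struct and function names are hypothetical, not the actual ggml-cuda code. Without tensor cores, every batch size falls back to the FP32 vector kernel, which is tuned for batch size 1, so token generation improves while large-batch prompt processing has no optimized path:

```cpp
// Hypothetical sketch of the kernel selection described above.
struct fattn_params {
    int device;   // CUDA device index
    int n_tokens; // batch size: 1 during token generation, large during prompt processing
};

bool has_tensor_cores(int device);                   // assumed helper: compute capability >= Volta
void fattn_f16_tensor_cores(const fattn_params & p); // existing FP16 kernel, fast at large batch sizes
void fattn_vec_f32(const fattn_params & p);          // this PR's kernel, tuned for batch size 1

void dispatch_fattn(const fattn_params & p) {
    if (has_tensor_cores(p.device)) {
        fattn_f16_tensor_cores(p); // handles both pp and tg efficiently
    } else {
        fattn_vec_f32(p); // tg speeds up; pp regresses until a large-batch FP32 kernel exists
    }
}
```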
```cpp
if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
    return true;
}
for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
    if (ggml_cuda_info().devices[id].cc < CC_VOLTA) {
        return false;
    }
}
```
I don't think it is necessary to check every device here; instead, get the context and check only the device for this context. Something like this:

```cpp
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_VOLTA) {
    return false;
}
```
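Put together, the resulting `GGML_OP_FLASH_ATTN_EXT` branch in `supports_op` would look roughly like this; a sketch only, assuming (as in the suggestion above) that `backend->context` is a `ggml_backend_cuda_context`:

```cpp
case GGML_OP_FLASH_ATTN_EXT: {
    // head sizes 64 and 128 are covered by the FP32 vector kernel on any device
    if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
        return true;
    }
    // other head sizes still need the FP16 tensor core kernel, i.e. Volta or newer,
    // checked only for the device this backend context actually runs on
    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
    return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
}
```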
Just adding a small data point: with KoboldCpp compiled with this, running a Q8_K 11B model on a 2x GTX 1080 Ti (Pascal) setup, I see a significant improvement, whereas with FP16 FA I saw a decrease. So it definitely has utility for a subset of users.
📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2 `-q4_0`: 541 iterations 🚀
Attached benchmark charts for "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 541 iterations", plotting the following server metrics over the run:

- llamacpp:prompt_tokens_seconds
- llamacpp:predicted_tokens_seconds
- llamacpp:kv_cache_usage_ratio
- llamacpp:requests_processing
I don't have any ALiBi models set up for testing, but according to `tests/test-backend-ops` the implementation works correctly.
Hi, I get an error when trying to run with `-fa` on my P100. Is support dropped?
Pascal is still supported; please open an issue.
This PR adds an FP32 FlashAttention kernel that is very similar to the FP16 kernel. It enables using FlashAttention on NVIDIA GPUs without fast FP16 and without tensor cores. It should also provide a speedup on more recent NVIDIA GPUs for batch size 1 and FP32 precision. I have moved the FP16 and FP32 FlashAttention vector kernels to separate files in order to speed up compilation. I also added a function `ggml_backend_cuda_get_device_cc` to `ggml-cuda.h` in order to avoid breaking `tests/test-backend-ops` on NVIDIA GPUs without tensor cores. Unlike with the FP16 kernel, there are no weird issues with arrays of size 1 vs. regular variables.

Performance on 1x P40: