llama.cpp
Add Intel Advanced Matrix Extensions (AMX) support to ggml #8998
Merged

mingfeima merged 1 commit into master from pr_add_intel_amx_support

mingfeima commented 282 days ago

Replacement of #7707, to trigger ggml-ci on AMX.

mingfeima marked this pull request as draft 282 days ago
github-actions added the build label
github-actions added the ggml label
mingfeima changed the title from "Pr add intel amx support" to "Add Intel Advanced Matrix Extensions (AMX) support to ggml" 282 days ago

ggerganov commented 282 days ago

To trigger ggml-ci you need to include the string "ggml-ci" somewhere in the commit message. For example: 5ef07e2

mingfeima force-pushed from 39d84e36 to 0b4de32e 282 days ago
mingfeima force-pushed from 2c95fa54 to 37ccb9d3 280 days ago
mingfeima marked this pull request as ready for review 280 days ago
mingfeima requested a review from ggerganov 280 days ago
mingfeima requested a review from slaren 280 days ago

mingfeima commented 280 days ago

@ggerganov could you please take a look at this one? I have moved the amx init code from ggml.c to ggml-amx/mmq.cpp according to previous comments.

slaren commented on 2024-08-14 (280 days ago)

Currently, GGML_AMX is always enabled with llama.cpp, and whether AMX is actually used depends entirely on using -march=native with GGML_NATIVE, and building on a machine with AMX support. This is not ideal because it is a departure from the way the other x86 instruction sets are handled, and it doesn't allow for cross-compilation. I think it would be better to handle this in the same way as any other x86 instruction set, and add an explicit compiler option to enable this architecture (-mamx-int8?) when using GGML_AMX without GGML_NATIVE, or let it be enabled automatically by -march=native when using GGML_NATIVE.

mingfeima commented 280 days ago

@slaren just updated cmake compiler options: -mamx-tile, -mamx-int8 and -mamx-bf16!
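
As an aside, a minimal illustration of how such flags gate code at compile time (this is not the actual ggml build logic, just a sketch; the GGML_AMX_EXAMPLE_ENABLED macro is made up): GCC and Clang predefine __AMX_TILE__, __AMX_INT8__ and __AMX_BF16__ when the corresponding -mamx-* flags (or -march=native on an AMX-capable machine) are in effect, so AMX code paths can be guarded like this:

// sketch only: compile the AMX kernels only when the compiler was given the
// matching -mamx-* flags (or -march=native on an AMX machine)
#if defined(__AMX_TILE__) && defined(__AMX_INT8__)
    // AMX tile/int8 intrinsics from <immintrin.h> are available here
    #define GGML_AMX_EXAMPLE_ENABLED 1
#else
    #define GGML_AMX_EXAMPLE_ENABLED 0
#endif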

ggerganov approved these changes on 2024-08-16
ggerganov requested a review from slaren 278 days ago

slaren commented on 2024-08-16
ggml/src/ggml-amx/mmq.cpp
// pack mat B to vnni format
if (src0->extra == nullptr) {
    const size_t row_size_B = get_row_size<type, blck_size>(K);

    src0->extra = aligned_alloc(64, N * row_size_B);

slaren commented 278 days ago (edited)

When is this memory freed?

mingfeima commented 274 days ago

Right now this memory is not freed until the program exits. After I updated the prepacking logic and went through several rounds of rebasing, the old free logic no longer works.

@slaren does the CPU backend have a mempool structure that allows dynamic allocation from it? Any suggestions?

slaren commented 273 days ago

The CPU backend does not have any mempool that could be used to allocate this memory. But even if it did, this approach would double the memory requirements for the model weights, which I don't think is an acceptable tradeoff in most cases. It is also likely to cause issues when the tensors are not static, such as when using KV cache quantization.

There is a way to implement this kind of tensor data layout transformations with ggml-backend. In short, it is possible to create a type of ggml_backend_buffer that performs these transformations transparently when accessing the tensor data through the set_tensor/get_tensor interfaces. The assumption would be that all tensors allocated in a buffer of this type would use the different layout. Before performing the matrix multiplication, you could check the type of buffer where the tensor is allocated to determine if AMX can be used. This would also require extensive changes to the llama.cpp model loading code to use this buffer type. Alternatively, the existing CPU buffer type could be modified to perform these transformations, but that would also require changing other operations of the CPU backend to be able to use the different layout.

In the past this also has been implemented by adding new tensor data types with the modified layout (eg. Q4_0_4_4 and others for ARM), but I think it would be hard to justify the maintenance cost of adding new quantization types.

My conclusion is that there isn't a simple solution to implement this, and it would need to be done along with significant refactoring of other parts of the code (eg. the llama.cpp model loading code). llama.cpp supports offloading large matrix multiplication to a GPU, which is usually a lot faster than even very optimized CPU code, so I find it hard to justify to spend too much effort into this, but of course contributions are welcome.
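
To make the buffer-type check concrete, here is a minimal sketch of the dispatch side, assuming such an AMX buffer type exists (the names below are placeholders, not the code that was eventually merged):

// sketch only: the packed layout is only guaranteed for tensors that were
// allocated in (and repacked by) the special buffer, so check that first
static bool can_use_amx_mul_mat(const struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];  // the weight matrix
    return src0->buffer != NULL &&
           ggml_backend_buffer_get_type(src0->buffer) == ggml_backend_amx_buffer_type();
}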

mingfeima commented 261 days ago

@slaren is it possible to do an in-place replacement of the original tensor data with the AMX-required layout when AMX can be used? That would be very easy to implement and would not need many code changes, but I am not sure whether there is any risk in doing this.

slaren commented 261 days ago

The most significant risk would be breaking the semantics of the way tensors are supposed to work in ggml. For example, after loading a model into tensors and evaluating it, an application may want to read the tensor data to serialize the model back to disk. The tensor data is not hidden from applications, so that is expected to work. Another consequence would be that the tensor data would no longer be in the format that other operations expect. For instance, in models where the token embeddings and output layer use the same tensor, this would break the get_rows operation.

mingfeima commented 245 days ago

I am changing AMX into a ggml-backend so that it handles the memory management.

mingfeima commented 235 days ago

@slaren could you please help review this one again? I have updated the implementation to use ggml-backend.

mingfeima force-pushed from a43f8e00 to 47b1a743 245 days ago
mingfeima force-pushed from c90a43a2 to 28cfc0ff 238 days ago

ggerganov commented 238 days ago

Failure 124 means that the run timed out. There is currently a 30-minute limit for these runs, and in this case it was exceeded.

It's not related to this PR - the CUDA build time has recently increased, so this run times out from time to time. I just restarted it to see if it would pass.

mingfeima force-pushed from b4d38b27 to 7c371fa6 237 days ago

slaren commented on 2024-09-29 (234 days ago)

Other than the issues with is_host pointed out below, the ggml-backend interface implementation looks good. Unfortunately, fixing these issues may result in a small performance hit, since it may cause additional copies when data needs to be moved between the CPU and AMX backends; fixing that will require changes to the ggml-backend interface.

The llama.cpp side will probably need some changes. I expect that the current implementation won't work with KV quantization. I cannot test this, but I think changing it so that the AMX buffer type is only used for the weights may work better, while also avoiding the need to use -ngl:

diff --git a/src/llama.cpp b/src/llama.cpp
index b85e8acd..13d70ec1 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3462,8 +3462,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
     }
 #elif defined(GGML_USE_CANN)
     buft = ggml_backend_cann_buffer_type(local_gpu);
-#elif defined(GGML_USE_AMX)
-    buft = ggml_backend_amx_buffer_type();
 #endif

     if (buft == nullptr) {
@@ -6865,7 +6863,14 @@ static bool llm_load_tensors(

     // assign cpu layers
     for (int i = 0; i < i_gpu_start; ++i) {
+    #ifdef GGML_USE_AMX
+        model.buft_layer[i] = {
+            llama_default_buffer_type_cpu(true),
+            ggml_backend_amx_buffer_type()
+        };
+    #else
         model.buft_layer[i] = llama_default_buffer_type_cpu(true);
+    #endif
     }

     if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
@@ -18587,11 +18592,6 @@ struct llama_model_params llama_model_default_params() {
     result.n_gpu_layers = 999;
 #endif

-#ifdef GGML_USE_AMX
-    // by default offload all layers to AMX
-    result.n_gpu_layers = 999;
-#endif
-
     return result;
 }

I also expect that this implementation will have issues when built with a GPU backend such as CUDA that allows the weights to be copied to VRAM when evaluating large batches (>32 tokens), although that could be fixed by implementing conversion back to standard ggml format in ggml_backend_amx_buffer_get_tensor.
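
To illustrate the last point, a get_tensor implementation that converts back to the standard layout could look roughly like the sketch below (amx_unpack_weight_from_vnni is a hypothetical helper, not part of this PR; qtype_has_amx_kernels is the predicate quoted later in this thread):

// sketch only: unpack from the packed (VNNI) layout back to the standard ggml
// layout when another backend, e.g. CUDA, reads the weights from this buffer
static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer,
        const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    if (qtype_has_amx_kernels(tensor->type)) {
        amx_unpack_weight_from_vnni(tensor, data, offset, size);  // hypothetical
    } else {
        memcpy(data, (const char *)tensor->data + offset, size);  // standard layout
    }
}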

ggml/src/ggml-amx.cpp
}

GGML_CALL static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    return true;
slaren commented 234 days ago

Returning true here will cause the CPU to assume that it can use tensors allocated in this backend, however this is not the case for the types that are changed by this backend.

Suggested change:
-    return true;
+    return false;
mingfeima commented 222 days ago

fixed!

ggml/src/ggml-amx.cpp
slaren commented 234 days ago

The check to use memcpy needs to be more strict, since the src tensor may be a quantized tensor allocated in a buffer with standard layout.

mingfeima commented 222 days ago

fixed, quantized tensors now go through the conversion path to do the packing.
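
For illustration, a stricter set_tensor along the lines slaren describes could look like this sketch (amx_pack_weight_vnni is a hypothetical helper; the merged code may differ):

// sketch only: take the memcpy shortcut only for types without AMX repacking,
// otherwise convert from the standard ggml layout into the packed layout
static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer,
        struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    if (qtype_has_amx_kernels(tensor->type)) {
        amx_pack_weight_vnni(tensor, data, offset, size);  // hypothetical
    } else {
        memcpy((char *)tensor->data + offset, data, size);  // standard layout
    }
}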

src/llama.cpp
     } else {
         GGML_ASSERT(weight->idx < files.size());
         const auto & file = files.at(weight->idx);
-        if (ggml_backend_buffer_is_host(cur->buffer)) {
+#if defined(GGML_USE_AMX)
+        const bool can_use_mmap = false;
+#else
+        const bool can_use_mmap = true;
+#endif
+        if (ggml_backend_buffer_is_host(cur->buffer) && can_use_mmap) {
slaren commented 234 days ago

This change should not be necessary.

mingfeima commented 222 days ago

this has been removed.

mingfeima commented 234 days ago

@slaren changing is_host to return false in the AMX backend leads to an abort in ggml_backend_sched_backend_id_from_cur (log attached below). Do you have any insight into how to fix it?

llama_new_context_with_model: n_ctx      = 8192
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        AMX KV buffer size =  1024.00 MiB
llama_new_context_with_model: KV self size  = 1024.00 MiB, K (f16):  512.00 MiB, V (f16):  512.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.49 MiB
ggml/src/ggml-backend.c:1204: pre-allocated tensor in a backend that cannot run the operation
[New LWP 2746117]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib64/libthread_db.so.1".
0x00007f7d1a7205a2 in waitpid () from /lib64/libpthread.so.0
#0  0x00007f7d1a7205a2 in waitpid () from /lib64/libpthread.so.0
#1  0x000000000048a648 in ggml_print_backtrace () at ggml/src/ggml.c:282
282             waitpid(pid, &wstatus, 0);
#2  ggml_abort (file=file@entry=0x6788d4 "ggml/src/ggml-backend.c", line=line@entry=1204, fmt=fmt@entry=0x678c50 "pre-allocated tensor in a backend that cannot run the operation") at ggml/src/ggml.c:309
309         ggml_print_backtrace();
#3  0x00000000004cf025 in ggml_backend_sched_backend_id_from_cur (sched=0x172fd20, tensor=0x5305e10) at ggml/src/ggml-backend.c:1204
1204            GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation");
#4  0x00000000004d127c in ggml_backend_sched_split_graph (sched=sched@entry=0x172fd20, graph=graph@entry=0x1ada190) at ggml/src/ggml-backend.c:1337
1337                *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
#5  0x00000000004d2dde in ggml_backend_sched_split_graph (graph=0x1ada190, sched=<optimized out>) at ggml/src/ggml-backend.c:1327
1327        if (sched->ctx == NULL) {
#6  ggml_backend_sched_reserve (sched=0x172fd20, measure_graph=0x1ada190) at ggml/src/ggml-backend.c:1992
1992        ggml_backend_sched_split_graph(sched, measure_graph);
#7  0x000000000053204b in llama_new_context_with_model (model=0x1729f30, params=...) at src/llama.cpp:19176
19176               if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
#8  0x000000000060a48d in llama_init_from_gpt_params (params=...) at common/common.cpp:843
843         llama_context * lctx = llama_new_context_with_model(model, cparams);
#9  0x000000000043894d in main (argc=<optimized out>, argv=0x7fffb9a8de18) at examples/main/main.cpp:200
200         llama_init_result llama_init = llama_init_from_gpt_params(params);
[Inferior 1 (process 2746116) detached]
./run_generate_cpu.sh: line 26: 2746116 Aborted                 (core dumped) $PREFIX $main -m ./models/Meta-Llama-3-8B-Instruct-GGUF/$MODEL -t $CORES -n 5 -p "$prompt" --no-mmap
slaren commented 234 days ago

I think that may be related to KV operations, which should be fixed with the change I suggested before. By making llama_default_buffer_type_offload return the AMX buffer type, it will cause the KV cache to be allocated on an AMX buffer, which is not good. If that doesn't fix it, please add some prints to show the tensor that is causing the error.

mingfeima force-pushed from 2f37e21d to fc709cfc 222 days ago

mingfeima commented 220 days ago

@slaren could you please help review this one again? I just changed ggml_backend_buft_is_host to return false for the AMX backend.

slaren commented on 2024-10-14
ggml/src/ggml-backend.cpp
     }
 }

+#ifndef GGML_USE_AMX
     if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) {
         // since the tensor is pre-allocated, it cannot be moved to another backend
         GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation");
     }
+#endif
slaren commented 219 days ago

I am not sure what the purpose of this change is, but it is not OK to change the behavior of ggml_backend_sched depending on which backends are included in the build. The source of the problem must be fixed instead.

mingfeima commented 217 days ago

the issue has been fixed!

src/llama.cpp
     }
 #endif

+#if defined(GGML_USE_AMX)
+    {
+        ggml_backend_t backend = ggml_backend_amx_init(cparams.n_threads);
+        if (backend == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to initialize AMX backend\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+        ctx->backends.push_back(backend);
+    }
+#endif
slaren commented 219 days ago

The AMX backend should implement the new backend reg and device interfaces, and this should be removed. See #9752 for a simple example of how to implement these interfaces.

mingfeima commented 217 days ago

implemented with the new backend reg and device interfaces.

src/llama.cpp
slaren commented 219 days ago

This doesn't need to be changed now, but just as a heads up, I will change the way this is done completely in a future refactor. I do not have a machine with AMX to test this backend, so it will be important to have CI tests for AMX.

mingfeima commented 217 days ago

covered in #7707 (comment)

mingfeima force-pushed from fc709cfc to 45451e23 217 days ago

slaren commented on 2024-10-16
ggml/src/ggml-amx.cpp
    return ggml_backend_amx_buffer_type();

    GGML_UNUSED(dev);
}

static ggml_backend_buffer_t ggml_backend_amx_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {

slaren commented 217 days ago

I think that this backend should not implement this function since it modifies the layout of the tensors. This is mainly intended to be used with mmap, which means that the backend must be able to use the tensors in the standard ggml layout.

mingfeima commented 217 days ago

Oh, thanks for the heads-up!

slaren commented on 2024-10-16
ggml/src/ggml-amx.cpp
    props->type = ggml_backend_amx_device_get_type(dev);
    ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
    props->caps = {
        /* .async = */ false,

slaren commented 217 days ago

Likewise this should be set to false.

mingfeima force-pushed from 60727752 to 3366c22c 217 days ago
mingfeima requested a review from slaren 217 days ago

slaren approved these changes on 2024-10-17

slaren commented 217 days ago

Looks good to me, feel free to merge this at any point.

ggerganov commented 216 days ago (edited)

@slaren Thank you for the detailed review.

@mingfeima Remember to squash the commits when merging as explained in the contributing guidelines. Btw, I just restarted the ggml-ci node with AMX instruction set support, so we might want to wait for ggml-ci to run before merging. Will run it on this branch shortly.

Edit: the AMX CI has passed

ggerganov commented 216 days ago

Would recommend using 4 spaces for indentation for conformance with the rest of the codebase.

mingfeima - add amx kernel for gemm (commit 2d3fc54a)
mingfeima force-pushed from d5e8aba7 to 2d3fc54a 216 days ago
mingfeima commented 215 days ago

@ggerganov changed the indentation to 4 spaces. The branch has also been rebased and squashed into a single commit.

ggerganov commented 215 days ago

Nice, great job! Feel free to merge this - you should have the access to do so.

mingfeima merged 60ce97c9 into master 215 days ago

mingfeima commented 215 days ago

@slaren thanks a lot for your review!

nai-kon commented 191 days ago (edited)

@mingfeima

Thank you for your work. I have a question.
qtype_has_amx_kernels() only returns true for Q4_0 and Q4_1, so does that mean AMX acceleration is only enabled for Q4_0 and Q4_1 models?
Is there any plan to support Q4_K or other types?

qtype_has_amx_kernels

inline bool qtype_has_amx_kernels(const enum ggml_type type) {
    // TODO: fix padding for vnni format
    return (type == GGML_TYPE_Q4_0) ||
        (type == GGML_TYPE_Q4_1);
        //(type == GGML_TYPE_Q8_0) ||
        //(type == GGML_TYPE_Q4_K) ||
        //(type == GGML_TYPE_Q5_K) ||
        //(type == GGML_TYPE_Q6_K) ||
        //(type == GGML_TYPE_IQ4_XS);
}
nai-kon commented 190 days ago

@mingfeima
I have an additional question.
When I tried the latest llama.cpp in my environment, the prompt eval time became 2x faster compared to a commit without AMX support, but the eval time (text generation speed) remained almost unchanged.
(I tried it with gemma-2b-q4_0 on an Intel Xeon with AMX.)

In your original PR, it was reported that eval time became 2x faster.

Do you have any idea why the eval time does not get faster?

mingfeima commented 185 days ago

@nai-kon originally I wrote kernels for:

        //(type == GGML_TYPE_Q8_0) ||
        //(type == GGML_TYPE_Q4_K) ||
        //(type == GGML_TYPE_Q5_K) ||
        //(type == GGML_TYPE_Q6_K) ||
        //(type == GGML_TYPE_IQ4_XS);

but later on these were disabled, because I did not figure out how to do the padding with ggml-backend (these formats require additional padding for the packed format). I suppose it can still be done; I just don't have spare time for this at the moment.

As for generation performance, I suppose this is because the original CPU implementation in ggml-quants has also been improved (the old code base had a huge thread-sync overhead, using pthreads and atomics; later on, when OpenMP was adopted, the overhead became much smaller). Additionally, the current version of the AMX kernels is actually slower than my first version; some optimizations were removed during the rebase onto ggml-backend. My estimate is that the current AMX kernels still have a ~1/3 gap to optimal.

The work here is a personal interest of mine (not a company task); I will get back to it once I have spare time (to add back the kernels for the other quant types and improve the performance).

nai-kon commented 184 days ago (edited)

Thank you for your reply. It is wonderful to hear that the generation speed can still be improved.
I truly appreciate your personal effort and work.

slaren commented 183 days ago

Additionally, the current version of the AMX kernels is actually slower than my first version; some optimizations were removed during the rebase onto ggml-backend.

Did you find any issue with the ggml-backend interface that forced you to remove these optimizations? The plan is to reintegrate the AMX backend in the CPU backend in the future (#10359), which may eliminate some overheads and allow using the optimized q8_0 quantization functions again.

mingfeima commented 177 days ago

@slaren the major problem I have with ggml-backend is that I didn't figure out how to do padding in the AMX backend (the packed weight for AMX, e.g. the VNNI format, has a different memory size than in the default CPU backend). I ran into a couple of issues when integrating the AMX backend with ggml-backend, so I only left in the dtypes that do not require padding. Again, I think it should be possible to do this elegantly; I just have not had the time to investigate recently.

Performance-wise, the TODOs from my side are: fuse the quantization of A into the GEMM loop; replace the Q8_K quantization, which is currently just a reference implementation (very slow); and so on.

slaren commented 174 days ago

@mingfeima as far as I can tell, you were already doing everything that's necessary to allow padding (returning the padded size in get_alloc_size is enough). I intend to move the AMX code to the CPU backend and enable all the supported types in #10570.
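
For reference, reporting the padded size could look roughly like the sketch below (padded_row_size_for_amx is a hypothetical helper; the actual ggml code may differ): the buffer type returns the size required by the packed layout so the allocator reserves enough room, even though ggml_nbytes() reports the standard size.

// sketch only: report the padded allocation size required by the packed layout
static size_t ggml_backend_amx_buffer_type_get_alloc_size(
        ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    (void) buft;  // unused in this sketch
    if (qtype_has_amx_kernels(tensor->type)) {
        return ggml_nrows(tensor) * padded_row_size_for_amx(tensor->type, tensor->ne[0]);  // hypothetical
    }
    return ggml_nbytes(tensor);  // fall back to the standard size
}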
