llama.cpp
ggml-quants : ternary packing for TriLMs and BitNet b1.58
#8151
Merged

ggml-quants : ternary packing for TriLMs and BitNet b1.58 #8151

compilade merged 33 commits into master from compilade/bitnet-ternary
compilade
compilade1 year ago (edited 348 days ago)👍 28🎉 12❤ 11🚀 8

This adds 1.6875 bpw and 2.0625 bpw quant types for TriLMs and BitNet b1.58 models. For now, these are named TQ1_0 and TQ2_0, respectively.
I had given glimpses of this idea starting from #7931 (comment).

The 1.6875 bpw type mostly relies on the fact that 35 == 243 < 256 == 28 to pack 5 trits per byte.

(I also made a blog post about ternary packing in an attempt to explain the core idea a bit more (storing the values in fixed-point to extract the most significant digit first with multiplications))

Huge thanks to @Eddie-Wang1120, who motivated this by adding initial BitNet b1.58 support in #7931.

How to try it

Using TriLM models is the easiest because all of their models have row sizes divisible by 256.

Important

To quantize the token embeddings and the output tensor to Q4_K and Q6_K, you need to use llama-quantize on the model files produced by convert_hf_to_gguf.py --outtype tq1_0 (and also for tq2_0). Otherwise these two tensors are kept as f16 and are responsible for most of the size of the models.

$ python3 convert_hf_to_gguf.py /path/to/TriLM_3.9B_Unpacked/ --outfile /somewhere/TriLM-3.9B-TQ1_0-big.gguf --outtype tq1_0
$ ./build/bin/llama-quantize /somewhere/TriLM-3.9B-TQ1_0-big.gguf /somewhere/TriLM-3.9B-TQ1_0.gguf tq1_0

If you want to try TQ2_0, which is faster (but bigger) than TQ1_0 on compute-bound hardware, you can replace tq1_0 with tq2_0 in the above example, but it's also possible to convert from the TQ1_0 model file.

The two ternary formats hold the same values, so round-trip quantizing between the two should result in the same files.

$ ./build/bin/llama-quantize --allow-requantize /somewhere/TriLM-3.9B-TQ1_0.gguf /somewhere/TriLM-3.9B-TQ2_0.gguf tq2_0

Speed

TQ2_0 is twice as fast as Q4_K on my laptop. It's the fastest quant on compute-bound AVX2-capable computers.

This is a table of the float32-equivalent throughput of the vec_dot_q operation for each of these quant types.

CPU F16 Q8_0 Q4_K Q2_K TQ1_0 TQ2_0
Intel Core m3-8100Y (AVX2) 30.60 GB/s 67.03 GB/s 64.17 GB/s 81.73 GB/s 70.31 GB/s 141.83 GB/s
Arm Cortex A72 (NEON) 3.84 GB/s 9.51 GB/s 9.26 GB/s 9.79 GB/s 11.81 GB/s 15.78 GB/s
Arm Cortex A53 (NEON) 4.30 GB/s 5.87 GB/s 5.76 GB/s 5.84 GB/s 8.97 GB/s 10.29 GB/s
AWS t4g (NEON) 8.69 GB/s 22.35 GB/s 25.34 GB/s 22.84 GB/s 33.34 GB/s 44.80 GB/s
AWS t4g (DOTPROD) 49.17 GB/s 42.63 GB/s 45.40 GB/s 29.84 GB/s 40.44 GB/s 65.76 GB/s

From this, it's easy to see that TQ1_0 is usually slightly faster than Q4_K, and that TQ2_0 is by far the fastest quant on AVX2.

Note

There might be a way to make a similar type as TQ2_0 like some sort of Q2_1, which could be almost as fast but still usable by non-ternary models, but this will probably require something like LQER to help with keeping some precision.

Raw data (click to expand)

Intel Core m3-8100Y:

$ for t in bf16 f16 q8_0 q4_0 q4_K q2_K tq1_0 tq2_0; do ./bin/test-quantize-perf --op vec_dot_q -i 10000000 --type "$t"; done
bf16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      4.28
      avg cycles/32 vals   :      4.72
      float32 throughput   :     37.89 GB/s
      quantized throughput :     18.95 GB/s

f16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      5.52
      avg cycles/32 vals   :      5.93
      float32 throughput   :     30.60 GB/s
      quantized throughput :     15.30 GB/s

q8_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      2.27
      avg cycles/32 vals   :      2.56
      float32 throughput   :     67.03 GB/s
      quantized throughput :     17.81 GB/s

q4_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      3.04
      avg cycles/32 vals   :      3.38
      float32 throughput   :     52.20 GB/s
      quantized throughput :      7.34 GB/s

q4_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      2.22
      avg cycles/32 vals   :      2.61
      float32 throughput   :     64.17 GB/s
      quantized throughput :      9.02 GB/s

q2_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      1.77
      avg cycles/32 vals   :      1.99
      float32 throughput   :     81.73 GB/s
      quantized throughput :      6.70 GB/s

tq1_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      2.12
      avg cycles/32 vals   :      2.33
      float32 throughput   :     70.31 GB/s
      quantized throughput :      3.71 GB/s

tq2_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.85
      avg cycles/32 vals   :      0.97
      float32 throughput   :    141.83 GB/s
      quantized throughput :      9.14 GB/s

Arm Cortex A72 (Raspberry Pi 4):

$ for t in f16 q8_0 q4_K q2_K tq1_0 tq2_0; do ./bin/test-quantize-perf --op vec_dot_q -i 2000000 --type "$t"; done                                                                                        
f16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      3.84 GB/s
      quantized throughput :      1.92 GB/s

q8_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      9.51 GB/s
      quantized throughput :      2.53 GB/s

q4_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      9.26 GB/s
      quantized throughput :      1.30 GB/s

q2_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      9.79 GB/s
      quantized throughput :      0.80 GB/s

tq1_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     11.81 GB/s
      quantized throughput :      0.62 GB/s

tq2_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     15.78 GB/s
      quantized throughput :      1.02 GB/s

Arm Cortex A53 (Some Android phone from 2017):

$ for t in f16 q8_0 q4_K q2_K tq1_0 tq2_0; do ./bin/test-quantize-perf --op vec_dot_q -i 2000000 --type "$t"; done
f16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      4.30 GB/s
      quantized throughput :      2.15 GB/s

q8_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      5.87 GB/s
      quantized throughput :      1.56 GB/s

q4_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      5.76 GB/s
      quantized throughput :      0.81 GB/s

q2_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      5.84 GB/s
      quantized throughput :      0.48 GB/s

tq1_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      8.97 GB/s
      quantized throughput :      0.47 GB/s

tq2_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     10.29 GB/s
      quantized throughput :      0.66 GB/s

AWS t4g.small instance (Arm Neoverse N1) using NEON:

$ for t in f16 q8_0 q4_K q2_K tq1_0 tq2_0; do ./bin/test-quantize-perf --op vec_dot_q -i 2000000 --type "$t"; done
f16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      8.69 GB/s
      quantized throughput :      4.35 GB/s

q8_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     22.35 GB/s
      quantized throughput :      5.94 GB/s

q4_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     25.34 GB/s
      quantized throughput :      3.56 GB/s

q2_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     22.84 GB/s
      quantized throughput :      1.87 GB/s

tq1_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     33.34 GB/s
      quantized throughput :      1.76 GB/s

tq2_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     44.80 GB/s
      quantized throughput :      2.89 GB/s

AWS t4g.small (Arm Neoverse N1) with -march=native:

$ for t in f16 q8_0 q4_K q2_K tq1_0 tq2_0; do ./tests/test-quantize-perf --op vec_dot_q -i 2000000 --type "$t"; done
f16
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     49.17 GB/s
      quantized throughput :     24.59 GB/s

q8_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     42.63 GB/s
      quantized throughput :     11.32 GB/s

q4_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     45.40 GB/s
      quantized throughput :      6.38 GB/s

q2_K
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     29.84 GB/s
      quantized throughput :      2.45 GB/s

tq1_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     40.44 GB/s
      quantized throughput :      2.13 GB/s

tq2_0
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :     65.76 GB/s
      quantized throughput :      4.24 GB/s

Size

The token embeddings are kept at Q4_K and the output projection at Q6_K, which means the smaller models might be slightly bigger than 2 bits per weight.

All of the TriLM models should work, because their row sizes are multiples of 256. I did not try them all yet, but those I tried are in the table below.

The BitNet b1.58 models from the 1bitLLM team however are not all compatible; only the 700M model has dimensions divisible by 256. The others are not supported (yet), unless when padding them.

Model F16 TQ1_0 TQ2_0
https://huggingface.co/1bitLLM/bitnet_b1_58-large (728.84 M) 1391.26 MiB 176.65 MiB 207.03 MiB
https://huggingface.co/SpectraSuite/TriLM_390M_Unpacked 750.39 MiB 128.04 MiB 140.98 MiB
https://huggingface.co/SpectraSuite/TriLM_1.5B_Unpacked 2892.09 MiB 401.54 MiB 460.04 MiB
https://huggingface.co/SpectraSuite/TriLM_2.4B_Unpacked 4696.86 MiB 603.59 MiB 703.26 MiB
https://huggingface.co/SpectraSuite/TriLM_3.9B_Unpacked 7616.43 MiB 948.16 MiB 1112.70 MiB

Note

The 1.3B BitNet b1.58 model has a FFN size of 5460 which factors into 2 2 3 5 7 13, which is not convenient for any block-wise types based on powers of 2, so these tensors are kept as F16. My hypothesis is that 5460 was a typo for 5440 (factors into 2 2 2 2 2 2 5 17), but it was kept for some reason, and reproduced by the 1bitLLM team. If anyone training ternary models reads this, PLEASE DON'T USE 5460 FOR THE FFN SIZE! Please use multiples of 256 for your row sizes.

Perplexity

Quality seems good. I don't have a powerful machine, so my tests only include the first 16 chunks of wikitext-2-raw with https://huggingface.co/SpectraSuite/TriLM_390M_Unpacked.

The tests below use Q4_K token embeddings and Q6_K output tensor for TQ1_0 and TQ2_0, while F16 token embeddings and output tensor is used in TQ1_0_L and TQ2_0_L.

chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p
TQ1_0 16 23.6336 ± 1.0765 0.00463 ± 0.00141 0.00187 ± 0.00002 0.860 ± 0.020 % 97.279 ± 0.255 %
TQ2_0 16 23.6336 ± 1.0765 0.00463 ± 0.00141 0.00187 ± 0.00002 0.860 ± 0.020 % 97.279 ± 0.255 %
TQ1_0_L 16 23.5758 ± 1.0746 0.00218 ± 0.00112 0.00034 ± 0.00001 0.405 ± 0.012 % 98.971 ± 0.158 %
TQ2_0_L 16 23.5758 ± 1.0746 0.00218 ± 0.00112 0.00034 ± 0.00001 0.405 ± 0.012 % 98.971 ± 0.158 %

From this it seems like there is no significant quality loss for the ternary quants for TriLM models (I think the difference with pure f16 comes from the 8-bit activations), and that TQ1_0 and TQ2_0 are completely equivalent in quality (and they should be, because lossless conversion between the two is possible).

Structure of TQ1_0

This type relies on the fact that 3^5 == 243 < 256 == 2^8.

In a block of 256 elements, there are 240 elements encoded in 5 elements per byte, while the last 16 elements are encoded in 4 elements per byte.

This means (240 / 5) + (16 / 4) == 48 + 4 == 52 bytes are used to pack 256 ternary weights (this is 1.625 bits per weight).

But there is also one float16 scale per block, so the size of a block is 54 bytes making it a 1.6875 bpw type. Even though it's not ideal, this is still 1.6875 / (log(3) / log(2)) ≈ 94% of the best ternary packing efficiency.

In the table below I'm describing the order of the elements within the bytes. I'm using ranges to make this shorter, with the notation start..end where the start is inclusive and the end is exclusive. (So 0..3 is {0, 1, 2})

Read this as if the ranges of a row are zipped together. A byte never contains more than 5 ternary values.

The ternary values are stored unsigned, so {-1, 0, 1} is stored as {0, 1, 2}.

byte x * 3-1 x * 3-2 x * 3-3 x * 3-4 x * 3-5
0..32 0..32 32..64 64..96 96..128 128..160
32..48 160..176 176..192 192..208 208..224 224..240
48..52 240..244 244..248 248..252 252..256 N/A

And then byte 52 and 53 contain the float16 scale in little-endian.

Values are stored in fixed point to allow extracting the most significant digit first. This is explained in https://compilade.net/blog/ternary-packing.

Structure of TQ2_0

This type was originally inspired by the Q2_2 type made by @Eddie-Wang1120, but the block size, the order, and the mapping of the values are different.

TQ2_0 started as an experiment to see how fast a 2-bit type can be compared to a 1.6-bit type on compute-bound hardware.

This packs each ternary value in 2 bits, which means each byte contains 4 values.

The ternary values are stored unsigned, so {-1, 0, 1} is stored as {0, 1, 2}.

Again, the ranges use the start..end notation where the start is inclusive and the end is exclusive, and the ranges of a row should be read as being zipped together (they advance in parallel in lockstep).

byte x << 6 x << 4 x << 2 x << 0
0..32 96..128 64..96 32..64 0..32
32..64 224..256 192..224 160..192 128..160

And then byte 64 and 65 contain the float16 scale in little-endian.

TODO

  • Implement Numpy (de)quantization for TQ1_0 and TQ2_0
  • Allow convert_hf_to_gguf.py to directly convert a ternary model to a ternary encoding
    • Using f16 for the token embeddings and output tensor because Q4_K and Q6_K quantization is not yet supported by gguf-py. This means llama-quantize needs to be used to quantize these tensors.
    • Make it more obvious that the models should go through llama-quantize afterwards.
      • Maybe use other type names, like TQ1_0_L or something?
  • Decide whether the float16 scale should be before or after the packed weights
    • I'd prefer it after because I feel like the scales are read after the weights in dot products, but the convention with the other types (except for Q2_K, Q3_K and Q6_K) is to keep the scale before.
    • Okay, I've decided the scales should stay at the end.
  • More graceful fallback conversion with llama-quantize
    • Using Q4_0 as a fallback type, because the smallest symmetric quant type is Q8_0 but it's a bit big, so Q4_0 it is (even though it's not ideal). Only relevant when row sizes are not multiples of 256.
  • Unify the __ARM_FEATURE_DOTPROD variants of the dot products of TQ1_0 and TQ2_0 with their bare __ARM_NEON variants to reduce code duplication.
  • Test TQ1_0 and TQ2_0 for correctness on an ARM CPU which supports dot product instructions
    • Tested on an AWS t4g.small instance.
    • Also test relative performance for fun
  • Should TQ1_0's first 48 bytes be divided in 3 sub-blocks of 16 bytes (80 elements) instead of one of 32 bytes (140 elements) and one of 16 bytes?
    • I've done the 32-16 split to use 256-bit registers on AVX2 for the pow3 shifts for at least the 32 byte part, but 16-16-16 would be more regular, although it would require using 128-bit registers for all the ternary shifts. Not sure if there's a performance difference.
  • Rename references to "BitNet 1.58b" to "BitNet b1.58". The "b" comes before in the paper.
  • Find a naming convention for BitNet quants and rename Q1_3 and Q2_2
    • They were renamed and redesigned as TQ1_0 and TQ2_0.
  • Decide to keep or to remove the optimization for ggml_mul when the broadcasted tensor only has a single element
  • Fix Android CI build issues.
    • It was apparently a problem with Arm 32-bit. Fixed in 8fbd593

compilade ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b
bd807499
compilade ggml-quants : faster 1.625 bpw AVX2 vec_dot
7ef4254a
compilade ggml-quants : substract 1 when back in epi8
48b73b84
compilade ggml-quants : Q2_2 now faster than Q4_K on with AVX2
ef1e345c
compilade ggml-quants : cleanup Q1_3 code formatting
638ad52f
compilade ggml-quants : ARM NEON vec_dot for q2_2 and q1_3
9465ec6e
compilade ggml-quants : use ceiling division when quantizing q1_3
89dc3b25
compilade convert-hf : simplify BitNet pre-quantization
961e2938
compilade convert-hf : allow converting the weird BitNet 1.3B
09961499
compilade compilade added enhancement
compilade compilade added python
compilade compilade added Review Complexity : High
compilade compilade added ggml
compilade compilade added Tensor Encoding Scheme
compilade compilade force pushed from 4522ed78 to 09961499 1 year ago
github-actions github-actions added testing
github-actions github-actions added examples
Eddie-Wang1120
Eddie-Wang11201 year ago (edited 1 year ago)👍 5

Wonderful job! I'm wondering if this PR can merge into the master branch, it would be so good if users of llama.cpp can use Q2_2 and Q1_3 conveniently.

compilade compilade changed the title ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b ggml-quants : 1.625 bpw ternary packing for BitNet b1.58 1 year ago
compilade bitnet : replace 1.58b with b1.58, as in the paper
bfd2f21f
compilade ggml-quants : fix build failure on Windows
ec50944b
compilade
compilade commented on 2024-06-29
examples/quantize/quantize.cpp
2626 { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
2727 { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
2828 { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
29
{ "Q1_3", LLAMA_FTYPE_MOSTLY_Q1_3, " 1.63 bpw for BitNet b1.58", },
30
{ "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2.00 bpw for BitNet b1.58", },
compilade1 year ago (edited 1 year ago)👍 1

Regarding the names of the new quant types, since these are quite specific to BitNet models, I was thinking to name them with something starting with QB, a bit like suggested in #5761 (comment).

I'll first be describing what I want from the naming scheme, then I'll attempt to make it work.

The naming scheme should have room for:

  • Ternary types in {-1, 0, 1}
    • 1.625 bpw quant with a block size of 64, with 13 bytes per block
      • To make the smallest possible lossless BitNet b1.58 model files
      • Uses Q8_0 as its vec_dot_type (for the activations)
      • (It's technically possible to store a float16 scale in the leftover bits in the last byte of 16 consecutive blocks (this means 1024 elements minimum per row), although it can't really be extracted with SIMD)
    • 2.000 bpw quant with a block size of 32, with 8 bytes per block
      • For maximal performance
      • Uses Q8_0 as its vec_dot_type (for the activations)
    • 2.000 bpw quant with a block size of 64, with 16 bytes per block, and a float16 scale
      • Values would be packed similarly to the 1.625 bpw type, but with an extra byte and a row-wise float16 scale duplicated in each block.
    • 2.000 bpw quant with a block size of 4, with 1 byte per block
      • For weirdly-shaped models like the 1.3B BitNet b1.58 model
      • Needs a compatible vec_dot_type
        • float types are slower than integer types for this
  • Binary types in {-1, 1}
    • 1 bpw type
  • Binary types in {0, 1}
    • Are there models which use this?
  • 8-bit activation with a row-wise scale
    • 8.5 bpw like Q8_0, but all the scales of a row are the same
      • Would allow reducing the number of float32 operations in the vec_dot of the above types.
    • 10 bpw, 5 bytes per block of 4 elements, with a weird layout which only uses blocks to get a big enough buffer, with a single float32 scale and some padding before all row elements, aligned and contiguous.
      • For use with the weird 2.000 bpw type, and also maybe the other ones for best performance.

So the naming scheme could be:

  • QB<x>_<y>
    • where <x> is the floor of the expected bpw of the type
    • where <y> is
      • 0 binary type, {0, 1}
        • except for QB8_0 which is like Q8_0 but with a guaranteed duplicated row-wise scale
      • 1 binary type, {-1, 1}
      • 2 ternary type using some kind of binary-coded ternary
      • 3 ternary type with fixed-point packed values
      • 4 weird type with a block size of 4

Which for the previously-mentioned possible BitNet types would mean:

proposed name Range bits per weight block size bytes row-wise scale current name
QB1_3 {-1, 0, 1} 1.625 64 13 1.0f Q1_3
QB2_2 {-2, -1, 0, 1} 2.000 32 8 1.0f Q2_2
QB2_3 {-1, 0, 1} 2.000 64 16 f16
QB2_4 {-2, -1, 0, 1} 2.000 4 1 1.0f
QB1_1 {-1, 1} 1.000 ? ?/8 1.0f
QB1_0 {0, 1} 1.000 ? ?/8 1.0f
QB8_0 [-127, 127] 8.5 32 34 f16
QB8_4 [-127, 127] 10 4 5 f32, weird layout

I'm not saying these should all exist, though, only that the naming scheme should not be too limiting for possible future extensions (which might not exist anyway due to lack of time).

So I think I'll rename Q1_3 to QB1_3, and Q2_2 to QB2_2. Anyone has comments on this? Or a better naming scheme for the new BitNet quant types?

candre231 year ago👍 1

If it were me, considering this only works with bitnet models and nothing else, I'd want the designations to be exceptionally clear that they are different and shouldn't be used on just anything. "QB" is good, but I'd take it a step further and remove the Q entirely. As bitnet is being colloquially referred to as a "1-bit" model, B1 makes more sense. Considering the plausible range for weights, I'd cut it off at tenths and ditch the decimal. This leaves plenty of room for variations, while making the native BPW very clear. I feel this is superior to the arbitrary "_2" and "_3" subtypes.

So what I would propose is:

1.625bpw = B1_16
2.000bpw = B1_20

compilade ggml-quants : attempt to fix Arm 32-bit support
8fbd5930
Green-Sky
Green-Sky1 year ago (edited 1 year ago)👀 1

@compilade and @Eddie-Wang1120 continuing the race to the bottom 🥳 , glorious.

Did some quick testing with the 3B model and it looks very good.

model size params backend threads test t/s
bitnet 3B Q1_3 - 1.625 bpw for BitNet b1.58 729.64 MiB 3.32 B BLAS 12 pp512 78.40 ± 0.27
bitnet 3B Q1_3 - 1.625 bpw for BitNet b1.58 729.64 MiB 3.32 B BLAS 12 tg128 38.16 ± 0.04
bitnet 3B Q2_2 - 2.000 bpw for BitNet b1.58 873.65 MiB 3.32 B BLAS 12 pp512 73.35 ± 6.23
bitnet 3B Q2_2 - 2.000 bpw for BitNet b1.58 873.65 MiB 3.32 B BLAS 12 tg128 36.86 ± 0.12

What surprises me a little, after reading about q2_2 being faster, is that q1_3 seems to be faster with the setup I used here. Will investigate further.

edit: also updated the files at https://huggingface.co/Green-Sky/bitnet_b1_58-3B-GGUF , for anyone else willing to test.

netrunnereve
netrunnereve1 year ago👍 3

Did a bit of testing myself, it runs and generates well but unfortunately it's the undertrained models rather than our implementation that's holding back BitNet adoption. For me Q1_3 is slower but this computer is CPU rather than memory bound.

model size params backend threads test t/s
bitnet 3B Q1_3 - 1.625 bpw for BitNet 1.58b 729.64 MiB 3.32 B CPU 4 pp512 15.15 ± 0.07
bitnet 3B Q1_3 - 1.625 bpw for BitNet 1.58b 729.64 MiB 3.32 B CPU 4 tg128 9.87 ± 0.65
bitnet 3B Q2_2 - 2.000 bpw for BitNet 1.58b 873.65 MiB 3.32 B CPU 4 pp512 19.25 ± 0.44
bitnet 3B Q2_2 - 2.000 bpw for BitNet 1.58b 873.65 MiB 3.32 B CPU 4 tg128 13.07 ± 0.28
bitnet 3B Q4_0 1.79 GiB 3.32 B CPU 4 pp512 18.44 ± 0.40
bitnet 3B Q4_0 1.79 GiB 3.32 B CPU 4 tg128 5.87 ± 0.12

I wonder if Q2_2 could be made faster if we used a block size of say 256 like the K-quants so that we can handle more than 64 bits of Q2_2 quants in each dot product loop. Aside from that I can't find any further way to improve that AVX implementation, and while it's ironic that we're using a madds instruction there when BitNet technically doesn't require multiplication that looks like the fastest way to dot the activations and ternary weights.

compilade
compilade1 year ago

I wonder if Q2_2 could be made faster if we used a block size of say 256 like the K-quants

Can't go with bigger blocks than 64 elements or else the 3B model won't be fully quantizable. (Its FFN size is 8640 (which factors into 2 2 2 2 2 2 3 3 3 5))

Its current block size is 32, which is the same as its vec_dot_type, Q8_0.

What would also help with performance would be to somehow use an 8-bit vec_dot_type having a single float scale per row. Might be interesting to explore later, but ggml does not have row-wise quant types yet, although this could still be done with a block quant.

it's ironic that we're using a madds instruction

Yeah, with AVX2, there are no good widening addition instructions like on ARM NEON, so _mm256_maddubs_epi16 is used for that.

Meanwhile, NEON doesn't have the equivalent of _mm_sign_epi8, so it needs to use multiplications or conditional masks, which are both slower than a dedicated instruction doing zeroing and sign flipping like in SSSE3.

ggerganov
ggerganov commented on 2024-07-07
ggml/src/ggml-quants.c
672 int8_t x2 = (int8_t)x[i*qk + 2*qk/4 + j];
673 int8_t x3 = (int8_t)x[i*qk + 3*qk/4 + j];
674
675
const uint8_t xi0 = x0 < 0 ? 1 : x0 == 0 ? 2 : 3;
676
const uint8_t xi1 = x1 < 0 ? 1 : x1 == 0 ? 2 : 3;
677
const uint8_t xi2 = x2 < 0 ? 1 : x2 == 0 ? 2 : 3;
678
const uint8_t xi3 = x3 < 0 ? 1 : x3 == 0 ? 2 : 3;
ggerganov1 year ago (edited 1 year ago)

As proposed, the type utilizes only 3 of the 4 possible values. I was thinking that the Q2_2 type would work the same as Q4_0, but assumes amax == 1.0f:

void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict y, int64_t k) {
    static const int qk = QK2_2;

    assert(k % qk == 0);

    const int nb = k / qk;

    for (int i = 0; i < nb; i++) {
        float amax = 0.0f; // absolute max
        float max  = 0.0f;

        for (int j = 0; j < qk; j++) {
            const float v = x[i*qk + j];
            if (amax < fabsf(v)) {
                amax = fabsf(v);
                max  = v;
            }
        }

        // assume amax = 1.0f
        max /= amax;

        const float d  = max / -2;
        const float id = d ? 1.0f/d : 0.0f;

        for (int j = 0; j < qk/4; ++j) {
            const float x0 = x[i*qk + 0*qk/4 + j]*id;
            const float x1 = x[i*qk + 1*qk/4 + j]*id;
            const float x2 = x[i*qk + 2*qk/4 + j]*id;
            const float x3 = x[i*qk + 3*qk/4 + j]*id;

            const uint8_t xi0 = MIN(3, (int8_t)(x0 + 2.5f));
            const uint8_t xi1 = MIN(3, (int8_t)(x1 + 2.5f));
            const uint8_t xi2 = MIN(3, (int8_t)(x2 + 2.5f));
            const uint8_t xi3 = MIN(3, (int8_t)(x3 + 2.5f));

            y[i].qs[j]  = xi0;
            y[i].qs[j] |= xi1 << 2;
            y[i].qs[j] |= xi2 << 4;
            y[i].qs[j] |= xi3 << 6;
        }
    }
}

(not tested, just pattern matching the existing quantize_row_q4_0_reference())

Edit: just realized the above would not work. We have assume that max == 1.0f, not amax, so:

const float max = 1.0f;
...
ggerganov
ggerganov commented on 2024-07-07
Conversation is marked as resolved
Show resolved
ggml/src/ggml.c
1004110065 GGML_ASSERT( nb0 == sizeof(float));
1004210066 GGML_ASSERT(nb00 == sizeof(float));
1004310067
10044 if (nb10 == sizeof(float)) {
10068
if (ggml_nelements(src1) == 1) {
10069
float scale = ((float *) src1->data)[0];
10070
for (int64_t ir = ith; ir < nr; ir += nth) {
10071
if (dst->data != src0->data) {
10072
// src0 is same shape as dst => same indices
10073
memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float));
10074
}
10075
ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);
10076
}
ggerganov1 year ago👍 4

It's ok to keep this optimization

compilade ggml : add some informative comments in q1_3 vec_dot
dd3e62a7
compilade Merge branch 'master' into compilade/bitnet-ternary
79a278e9
compilade
compilade1 year ago🎉 5

Whew, it has been a month since I last touched this, I got distracted for a bit.

(tl;dr at the end)

Now that new ternary models like TriLMs exist (https://arxiv.org/abs/2407.12327), which use multiple scales per tensors and which (fortunately) have all tensor dimensions divisible by 256 🎉, I think I should add a ternary type with 256 elements per block and a block-wise f16 scale. That would result in 1.6875 bpw, which sounds very reasonable to me.

Another ternary type with a scale but with a smaller block size (64) might be useful for compatibility with the BitNet b1.58 models from the 1bitLLM team (because their model dimensions are not divisible by 256), and would be 1.875 bpw or 2.0 bpw depending on whether padding 15 bytes of data to 16 bytes is better for performance.

These should have a similar inference speed as Q1_3, since they will use a similar packing scheme.

I'm not sure if it's worth it to keep the scale-less ternary quant types; I feel like they require too much special handling in the model graphs and in the convert script. It might be okay for BitNetForCausalLM, but not for some newer models like TriLMs which use LlamaForCausalLM, AKA not a ternary-specific architecture.

So I'll be proposing 4 (starting with 2) types, with yet another attempt at a naming scheme1 for ternary quants,
this time matching the regex TQ\d(_\dF?)?:

  • TQ1_0
    • ternary quant with 256 elements per block at 1.6875 bpw.
    • the packing would be similar to Q1_3, but repeated 4 times, and with a f16 scale.
    • its vec_dot_type could be Q8_K
  • TQ1_0F
  • (maybe) TQ2_0
    • ternary quant with 256 elements per block at 2.0625 bpw.
    • similar packing as Q2_2, so it should be performant, unless on platforms where the misalignment from the 2 bytes of the scale has some effect.
    • its vec_dot_type could be Q8_K
    • much simpler than IQ2_XSS, which can't even represent 0 unless the whole block is 0.
  • (maybe) TQ2_0F
    • same as TQ1_0F, but based on Q2_2 instead of Q1_3.
    • 2.25 bpw

Note that IQ2_XXS is already a 256-element type with similar properties as TQ2_0, although IQ2_XXS's packing scheme is much more complicated and I feel like its reliance on iq2xxs_grid makes it unnecessarily slower than it could be.2

I'll work on at least TQ1_0 and TQ1_0F in the next days, but I might get distracted. I'm doing this as a hobby in my free time, so it's possible that my priorities shift depending on external factors. This means anyone interested should feel free to ping me if I seem to have forgotten this again.

TL;DR: I think I'll replace the scale-less Q1_3 and Q2_2 with ternary types with a block-wise scale, which should allow supporting both BitNet b1.58 and TriLMs, while also simplifying the conversion for BitNet b1.58 because separate scale tensors won't be needed anymore.

Footnotes

  1. Some rationale for the naming scheme: using a special prefix to note that these are special-purpose, TQ stands for "ternary quant", not using QT to avoid confusion with https://www.qt.io/, and also because the IQ quants also prefix Q with a letter. I'm using _0 as suffix to mean that it has a scale similarly to Q8_0.

  2. Okay, I've read a bit about IQ2_XXS, and it seems slightly over-engineered and totally not intended for ternary models. Basically, it strongly relies on a lookup table (iq2xxs_grid), which contains 3 possible values in each byte: 0x08, 0x19 or 0x2b (8, 25, 45, respectively). This looks like where the absolute values of the elements comes from (before being scaled and signed). This means 0 is not representable unless the whole block is 0.

Green-Sky
Green-Sky364 days ago👍 1

@compilade Keep up the good work. You are a hero making living on the edge affordable 😄 .
Beside the others here of course... 😉

Not sure if anyone has noticed, but meta(facebook) changed the license for llama3.1 to allow training on outputs, which would allow for distillation.
So now I am waiting to see a bitnet distillation of the new 3.1 llamas pop up (hopefully).

mofosyne
mofosyne364 days ago

@compilade btw quick question regarding the packing structure of these encoding arrangements. Is there a consistent way to extract the bit pattern structure from the source code? It's a bit hard to grok the superblock, blocks and how bits are being packed for documentation. Ideally too I would like such documentation to be autogenerated as well, but until I can understand the basics from the C struct... it's a bit hard to get started.

ggerganov
ggerganov364 days ago

The plan sounds good. I wouldn't worry about the fallback types - we already have a workaround via padding for such kind of models, plus I doubt there will be much of those in the future.

compilade
compilade364 days ago

we already have a workaround via padding for such kind of models

@ggerganov While it mostly works, padding like in e9f2abf isn't correct with ggml_rms_norm, because the row size is used to calculate the mean.

https://github.com/ggerganov/llama.cpp/blob/75af08c475e285888f66556d0f459c533b7deb95/ggml/src/ggml.c#L11813

To make padding work properly, there would need to be some special handling to make it possible to use ne[0] values which are not multiples of the block size (like making ggml_row_size round up).

The GGUF file format should already support that, since the tensor offsets don't directly depend on their size.

But GGUFWriter would need to avoid assuming a lossless round-trip between shape and byte shape.

Quantization and dequantization would need to be adapted, because the functions currently assume ne[0] is a multiple of the block size. But the quantize_row_*_ref functions don't necessarily know ne[0] directly (they get the total element count in a chunk of rows), but that should be easy enough to adapt with doing one call per row when padding is needed, a bit like applying importance matrices is done one row at a time. Or padding could be handled outside, but this would (momentarily) use more memory for the padded f32 copies (unpadding can be done with views).

Dot products would need no change if the padding values are equivalent to zero (this won't work for IQ2_XXS and likely other IQ types which can't represent zero).

I wouldn't worry about the fallback types

Understood. I agree with adding fewer types. And using padding could even let the cursed https://huggingface.co/1bitLLM/bitnet_b1_58-xl be quantized with its weird FFN size of 5460 which factors into 2 2 3 5 7 13.

I'll start with not handling padding, because it would affect other types too (notably Q8_K), and might be more appropriate in a separate PR.

compilade
compilade364 days ago

Is there a consistent way to extract the bit pattern structure from the source code? It's a bit hard to grok the superblock, blocks and how bits are being packed for documentation. Ideally too I would like such documentation to be autogenerated as well, but until I can understand the basics from the C struct... it's a bit hard to get started.

@mofosyne

No, unfortunatley, I don't think this can be easily automated. Sometimes a single field in the structs stores multiple types of values, like in Q4_K where block_q4_K.scales stores 6-bit scales and mins in some pattern1. The easiest way to understand what the bits mean is to have a look at the respective dequantize_row function of each type.

Footnotes

  1. The 12 bytes in Q4_K .scales are packed a bit like this, where the uppercased letters are bits for the scales and lowercased letters are the bits of the mins:

     0: EEAAAAAA
     1: FFBBBBBB
     2: GGCCCCCC
     3: HHDDDDDD
     4: eeaaaaaa
     5: ffbbbbbb
     6: ggcccccc
     7: hhdddddd
     8: eeeeEEEE
     9: ffffFFFF
    10: ggggGGGG
    11: hhhhHHHH
    

    Source: https://github.com/ggerganov/llama.cpp/blob/75af08c475e285888f66556d0f459c533b7deb95/ggml/src/ggml-quants.c#L1891-L1898

mofosyne
mofosyne364 days ago (edited 363 days ago)

@compilade thanks for the explanation it is interesting to see that the bits are split into 2bit and 4bits and uses only bitwise actions. Is this because it's preferred over packing each 6bit scale in sequential order, because each access is aligned or is cheaper to use bitwise operations?


edit: Ah likely to be more friendlier for parallel processing in gpu etc...

ggerganov
ggerganov363 days ago

@ggerganov While it mostly works, padding like in e9f2abf isn't correct with ggml_rms_norm, because the row size is used to calculate the mean

Correct, the norm can be applied on a view having the original size though (1D tensors used for normalisations are never quantised).

compilade ggml : add TQ1_0 and TQ2_0 ternary quantization types
77b8f84a
compilade
compilade363 days ago (edited 363 days ago)👍 17😄 2🎉 2❤ 2🚀 4👀 3

I've made some preliminary performance (speed) tests with TQ1_0 and TQ2_0, and TQ1_0 is faster than Q1_3, now around the speed of Q8_0, while TQ2_0 got a very big perf boost and is twice as fast as TQ1_0, which makes it by far the fastest quant type (around 2x faster than Q8_0, and 1.7x faster than Q2_K)1, at least with AVX2 on my machine. Bigger block sizes do pay off!

(And Q8_K is a very good vec_dot_type, with a f32 scale and even pre-computed sums)

Note that this is about the vec_dot speed and not the overall speed, although it's usually where most of the compute time is spent.

The formats of TQ1_0 and TQ2_0 are a bit different than what I initially planned, to make the data more convenient to access in the AVX2 vec_dot. Something nice is that unlike Q1_3, TQ1_0 does not rely on reading past the buffer (Q1_3 has 13 byte blocks which were read in 16 byte chunks).

A possible future improvement for the AVX2 vec_dot of TQ1_0 would be to test if 16-bit multiplies and permutes are faster or not than more elaborate ways to shift 8-bit values by powers of 3 (AVX2 does not have non-widening 8-bit multiplies), but both approaches were mostly similar in performance on my machine, so I went with the 8-bit operations.

I'll port TQ1_0 and TQ2_0 to ARM NEON in the next days, and I'll remove Q1_3 and Q2_2 after making comparisons on low-end ARM devices.


Is this because it's preferred over packing each 6bit scale in sequential order, because each access is aligned or is cheaper to use bitwise operations?

@mofosyne I had no part in the decision of the scale packing in Q4_K, but I think it's like this because indexing is only done at the byte level, so packing and unpacking 6-bit values has to use bitwise operations. Pointers can only jump at a minimum of a byte at a time. Also when making the vec_dot of Q1_3 I've noticed that shuffles are surprisingly as fast as additions in SIMD.

Footnotes

  1. Proof:

    Output of test-quantize-perf (click to expand)
    $ for t in q4_0 q8_0 q4_K q2_K tq2_0 tq1_0 q1_3 q2_2; do ./bin/test-quantize-perf --op vec_dot_q --type $t -i 10000000; done
    q4_0
      vec_dot_q
        4096 values (0.02 MB)
          min cycles/32 vals   :      3.03
          avg cycles/32 vals   :      3.33
          float32 throughput   :     52.88 GB/s
          quantized throughput :      7.44 GB/s
    
    q8_0
      vec_dot_q
        4096 values (0.02 MB)
          min cycles/32 vals   :      2.24
          avg cycles/32 vals   :      2.51
          float32 throughput   :     68.26 GB/s
          quantized throughput :     18.13 GB/s
    
    q4_K
      vec_dot_q
        4096 values (0.02 MB)
          min cycles/32 vals   :      2.22
          avg cycles/32 vals   :      2.68
          float32 throughput   :     62.68 GB/s
          quantized throughput :      8.81 GB/s
    
    q2_K
      vec_dot_q
        4096 values (0.02 MB)
          min cycles/32 vals   :      1.75
          avg cycles/32 vals   :      1.99
          float32 throughput   :     81.82 GB/s
          quantized throughput :      6.71 GB/s
    
    tq2_0
      vec_dot_q
        4096 values (0.02 MB)
          min cycles/32 vals   :      0.83
          avg cycles/32 vals   :      0.95
          float32 throughput   :    144.50 GB/s
          quantized throughput :      9.31 GB/s
    
    tq1_0
      vec_dot_q
        4096 values (0.02 MB)
          min cycles/32 vals   :      2.11
          avg cycles/32 vals   :      2.29
          float32 throughput   :     71.35 GB/s
          quantized throughput :      3.76 GB/s
    
    q1_3
      vec_dot_q
        4096 values (0.02 MB)
          min cycles/32 vals   :      2.94
          avg cycles/32 vals   :      3.46
          float32 throughput   :     50.02 GB/s
          quantized throughput :      2.54 GB/s
    
    q2_2
      vec_dot_q
        4096 values (0.02 MB)
          min cycles/32 vals   :      2.12
          avg cycles/32 vals   :      2.33
          float32 throughput   :     73.31 GB/s
          quantized throughput :      4.58 GB/s
    
compilade ggml : even faster TQ2_0
560873f3
compilade ggml : also faster TQ1_0
e9719576
flatsiedatsie
flatsiedatsie362 days ago

(some discussion around compilade's improvement can be found on Reddit)

compilade ggml : fix build issues in certain environments
a6dd6994
compilade ggml : add NEON vec_dot implementation for TQ1_0 and TQ2_0
5417089a
compilade ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat
45719a24
compilade
compilade362 days ago (edited 361 days ago)👍 3🚀 1

I've tested that a round-trip quantization between TQ1_0 and TQ2_0 is lossless, which means one can always be made from the other.

$ ./build/bin/llama-quantize models/trilm-390M-f16.gguf models/trilm-390M-tq1_0.gguf tq1_0
$ ./build/bin/llama-quantize models/trilm-390M-f16.gguf models/trilm-390M-tq2_0.gguf tq2_0
$ ./build/bin/llama-quantize --allow-requantize models/trilm-390M-tq1_0.gguf models/trilm-390M-tq2_0-requant.gguf tq2_0
$ ./build/bin/llama-quantize --allow-requantize models/trilm-390M-tq2_0-requant.gguf models/trilm-390M-tq1_0-roundtrip.gguf tq1_0
$ cd models
$ sha256sum trilm-390M-tq*
e4c622fb10dcfa30d427eb94eb08ffdcbde8ef3683a2b43a1b1eac8ab6e3e67f  trilm-390M-tq1_0.gguf
e4c622fb10dcfa30d427eb94eb08ffdcbde8ef3683a2b43a1b1eac8ab6e3e67f  trilm-390M-tq1_0-roundtrip.gguf
4edaaa33f8d7ffeaac72d758bf0e253512128a4a872a9c428bf337abb21a64be  trilm-390M-tq2_0.gguf
4edaaa33f8d7ffeaac72d758bf0e253512128a4a872a9c428bf337abb21a64be  trilm-390M-tq2_0-requant.gguf

I've also added ARM NEON implementations of vec_dot for TQ1_0 and TQ2_0, but the relative speedup on a Raspberry Pi 4 B is less impressive than with AVX2 on my laptop. There might still be ways to optimize the use of ARM NEON in there.

Still, it's decent, at 1.6x the speed of Q8_0 for TQ2_0. But the RPi4 is very memory bound (with a bandwidth only around 3GB/s), so actual inference speed is relatively much better with smaller types.

But I'm happy that TQ1_0 is 1.75x as fast as Q1_3 on that machine. The gap between TQ1_0 and TQ2_0 is also smaller than with AVX2.

Output of test-quantize-perf on a RPi4 (click to expand)
$ for t in q4_0 q8_0 q4_K q2_K tq2_0 tq1_0 q1_3 q2_2; do ./bin/test-quantize-perf --op vec_dot_q --type $t -i 2000000; done                                                                                                         
q4_0                                                                                                                                                                                                                                                                       
  vec_dot_q                                                                                                                          
    4096 values (0.02 MB)                                                                                                            
      min cycles/32 vals   :      0.00                                                                                               
      avg cycles/32 vals   :      0.00                                                                                               
      float32 throughput   :      7.82 GB/s                                                                                          
      quantized throughput :      1.10 GB/s                                                                                          
                                                                                                                                     
q8_0                                                                                                                                 
  vec_dot_q                                                                                                                          
    4096 values (0.02 MB)                                                                                                            
      min cycles/32 vals   :      0.00                                                                                               
      avg cycles/32 vals   :      0.00                                                                                               
      float32 throughput   :      9.57 GB/s                                                                                          
      quantized throughput :      2.54 GB/s                                                                                          
                                                                                                                                     
q4_K                                                                                                                                 
  vec_dot_q                                                                                                                          
    4096 values (0.02 MB)                                                                                                            
      min cycles/32 vals   :      0.00                                                                                               
      avg cycles/32 vals   :      0.00                                                                                               
      float32 throughput   :      9.38 GB/s                                                                                          
      quantized throughput :      1.32 GB/s                                                                                          
                                                                                                                                     
q2_K                                                                                                                                 
  vec_dot_q                                                                                                                          
    4096 values (0.02 MB)                                                                                                            
      min cycles/32 vals   :      0.00                                                                                               
      avg cycles/32 vals   :      0.00                                                                                               
      float32 throughput   :      9.64 GB/s
      quantized throughput :      0.79 GB/s
                                                                  
tq2_0                                
  vec_dot_q                              
    4096 values (0.02 MB)                
      min cycles/32 vals   :      0.00                                             
      avg cycles/32 vals   :      0.00                                             
      float32 throughput   :     15.35 GB/s                                        
      quantized throughput :      0.99 GB/s                                        
                                                                  
tq1_0                                          
  vec_dot_q                                    
    4096 values (0.02 MB)                      
      min cycles/32 vals   :      0.00                                                         
      avg cycles/32 vals   :      0.00                                                         
      float32 throughput   :     11.82 GB/s                                                    
      quantized throughput :      0.62 GB/s
                                                                  
q1_3                                                              
  vec_dot_q                                                       
    4096 values (0.02 MB)                                         
      min cycles/32 vals   :      0.00                            
      avg cycles/32 vals   :      0.00
      float32 throughput   :      6.75 GB/s
      quantized throughput :      0.34 GB/s
                                                                  
q2_2                                                              
  vec_dot_q                                                       
    4096 values (0.02 MB)                                                                                                            
      min cycles/32 vals   :      0.00
      avg cycles/32 vals   :      0.00
      float32 throughput   :      9.14 GB/s
      quantized throughput :      0.57 GB/s

The next steps are to remove Q1_3 and Q2_2, and to adapt the convert script to let it convert directly to at least one of TQ1_0 or TQ2_0.

compilade compilade marked this pull request as draft 362 days ago
compilade ggml : remove q1_3 and q2_2
04eec581
compilade compilade changed the title ggml-quants : 1.625 bpw ternary packing for BitNet b1.58 ggml-quants : ternary packing for TriLMs and BitNet b1.58 360 days ago
Green-Sky
Green-Sky359 days ago (edited 359 days ago)

I saw compilade remove the old bitnet quants, so I decided it was time for another round of tests.

Since the large bitnet repro model does not work with the new quants (as explained in the OP), I switched to the TriLM_3.9B model.

quant ppl ppl@300 filesize
f16 11.1532 +/- 0.07854 11.0180 7.5G
q8_0 11.1489 +/- 0.07851 11.015 4.0G
q4_0 11.4797 +/- 0.08058 11.3249 2.2G
q4_k 11.1559 +/- 0.07854 11.0223 2.3G
tq2_0 11.1558 +/- 0.07853 11.0200 1.1G
tq1_0 11.1558 +/- 0.07853 11.0200 949M

I added ppl at step 300 for reference to speed up future ppl calculations.

Note: Offloading tq2_0 layers to vram(cuda) improved the time by ~20%. It was still 10x slower than q4_k though.

As always I used default settings calculating perplexity over 560 chunks, n_ctx=512, batch_size=2048, n_seq=4

edit: uploaded some quantized files again https://huggingface.co/Green-Sky/TriLM_3.9B-GGUF

Green-Sky
Green-Sky359 days ago (edited 359 days ago)❤ 5

Benchmarks:

model size params backend thrds/ngl test t/s
llama ?B TQ1_0 - 1.69 bpw ternary 946.45 MiB 3.99 B CPU 12 pp512 79.21 ± 0.15
llama ?B TQ1_0 - 1.69 bpw ternary 946.45 MiB 3.99 B CPU 12 tg128 38.56 ± 0.10
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B CPU 12 pp512 146.17 ± 0.71
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B CPU 12 tg128 33.89 ± 0.07
llama ?B Q8_0 3.95 GiB 3.99 B CPU 12 pp512 80.79 ± 2.35
llama ?B Q8_0 3.95 GiB 3.99 B CPU 12 tg128 10.30 ± 0.11
llama ?B Q4_0 2.13 GiB 3.99 B CPU 12 pp512 62.81 ± 4.24
llama ?B Q4_0 2.13 GiB 3.99 B CPU 12 tg128 17.89 ± 0.03
llama ?B Q4_K - Medium 2.26 GiB 3.99 B CPU 12 pp512 80.85 ± 0.20
llama ?B Q4_K - Medium 2.26 GiB 3.99 B CPU 12 tg128 16.93 ± 0.23
llama ?B TQ1_0 - 1.69 bpw ternary 946.45 MiB 3.99 B BLAS 12 pp512 57.12 ± 0.92
llama ?B TQ1_0 - 1.69 bpw ternary 946.45 MiB 3.99 B BLAS 12 tg128 38.05 ± 0.06
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B BLAS 12 pp512 55.40 ± 1.81
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B BLAS 12 tg128 33.35 ± 0.29
llama ?B Q8_0 3.95 GiB 3.99 B BLAS 12 pp512 47.88 ± 6.64
llama ?B Q8_0 3.95 GiB 3.99 B BLAS 12 tg128 10.09 ± 0.33
llama ?B Q4_0 2.13 GiB 3.99 B BLAS 12 pp512 51.64 ± 4.37
llama ?B Q4_0 2.13 GiB 3.99 B BLAS 12 tg128 18.01 ± 0.07
llama ?B Q4_K - Medium 2.26 GiB 3.99 B BLAS 12 pp512 63.96 ± 1.08
llama ?B Q4_K - Medium 2.26 GiB 3.99 B BLAS 12 tg128 17.49 ± 0.07
llama ?B TQ1_0 - 1.69 bpw ternary 946.45 MiB 3.99 B CUDA 0 pp512 78.49 ± 0.34
llama ?B TQ1_0 - 1.69 bpw ternary 946.45 MiB 3.99 B CUDA 0 tg128 38.48 ± 0.49
llama ?B TQ1_0 - 1.69 bpw ternary 946.45 MiB 3.99 B CUDA 99 pp512 82.15 ± 0.22
llama ?B TQ1_0 - 1.69 bpw ternary 946.45 MiB 3.99 B CUDA 99 tg128 11.65 ± 0.05
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B CUDA 0 pp512 143.04 ± 0.79
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B CUDA 0 tg128 34.32 ± 0.07
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B CUDA 99 pp512 155.20 ± 2.79
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B CUDA 99 tg128 9.73 ± 0.03
llama ?B Q8_0 3.95 GiB 3.99 B CUDA 0 pp512 833.59 ± 10.54
llama ?B Q8_0 3.95 GiB 3.99 B CUDA 0 tg128 10.28 ± 0.08
llama ?B Q8_0 3.95 GiB 3.99 B CUDA 99 pp512 2442.82 ± 12.50
llama ?B Q8_0 3.95 GiB 3.99 B CUDA 99 tg128 62.65 ± 0.99
llama ?B Q4_0 2.13 GiB 3.99 B CUDA 0 pp512 1121.50 ± 3.99
llama ?B Q4_0 2.13 GiB 3.99 B CUDA 0 tg128 17.76 ± 0.15
llama ?B Q4_0 2.13 GiB 3.99 B CUDA 99 pp512 2387.19 ± 106.81
llama ?B Q4_0 2.13 GiB 3.99 B CUDA 99 tg128 97.52 ± 0.87
llama ?B Q4_K - Medium 2.26 GiB 3.99 B CUDA 0 pp512 1054.77 ± 15.91
llama ?B Q4_K - Medium 2.26 GiB 3.99 B CUDA 0 tg128 17.30 ± 0.06
llama ?B Q4_K - Medium 2.26 GiB 3.99 B CUDA 99 pp512 2272.11 ± 13.74
llama ?B Q4_K - Medium 2.26 GiB 3.99 B CUDA 99 tg128 90.54 ± 0.19

The prompt processing speed of the TQ2_0 is very remarkable. Both with AVX2 and CUDA. But still pales in comparison to optimized CUDA code.

(Open?)BLAS generally just performing worse. A shame, since it's the default.

CPU: AMD Ryzen 9 PRO 3900 12-Core Processor
GPU: NVIDIA GeForce RTX 2070 (mobile but it's not mentioned anywhere)

CPU only with AVX2

system_info: n_threads = 12 / 24 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |

CPU with BLAS

system_info: n_threads = 12 / 24 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |

CUDA

system_info: n_threads = 12 / 24 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |`
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2070, compute capability 7.5, VMM: yes
compilade ggml-quants : rename fields of TQ1_0 and TQ2_0 structs for consistency
f034aa1b
mirek190
mirek190359 days ago (edited 359 days ago)

I see perplexity looks too good for tq1_0 and tq2_0 .... too good to be true ;)
I wonder how it will impact on reasoning and creativity.

compilade
compilade359 days ago

I see perplexity looks too good for tq1_0 and tq2_0 .... too good to be true ;)

Keep in mind these types were only tested on models which were trained to have ternary weights, like BitNet b1.58 and TriLMs.

TQ1_0 and TQ2_0 pretty much losslessly encode ternary weights, and the perplexity difference from F16 comes from the Q4_K token embeddings and Q6_K output tensor, as well as the 8-bit activations.

The perplexity on non-ternary models would be extremely bad.

mirek190
mirek190359 days ago

I see perplexity looks too good for tq1_0 and tq2_0 .... too good to be true ;)

Keep in mind these types were only tested on models which were trained to have ternary weights, like BitNet b1.58 and TriLMs.

TQ1_0 and TQ2_0 pretty much losslessly encode ternary weights, and the perplexity difference from F16 comes from the Q4_K token embeddings and Q6_K output tensor, as well as the 8-bit activations.

The perplexity on non-ternary models would be extremely bad.

Yes I understand that .
I'm curious how good such models can be in creativity and reasoning trained such way from the beginning.

Green-Sky
Green-Sky358 days ago

Yes I understand that . I'm curious how good such models can be in creativity and reasoning trained such way from the beginning.

You can checkout the paper that the TriLM is part of, it compares "normal" float based lms with eg. TriLMs.

And TriLMs seem to provide the best performance per bit.

image
(higher is better, lefter is better)
(they all look something like this)

https://blog.nolano.ai/Spectra-suite/

https://huggingface.co/papers/2407.12327

ggerganov
ggerganov358 days ago

Note: Offloading tq2_0 layers to vram(cuda) improved the time by ~20%. It was still 10x slower than q4_k though.

@Green-Sky There is no CUDA implementation yet, so there is no need to benchmark CUDA yet - it will be slow because the data will be moved to the CPU all the time

mirek190
mirek190358 days ago

Yes I understand that . I'm curious how good such models can be in creativity and reasoning trained such way from the beginning.

You can checkout the paper that the TriLM is part of, it compares "normal" float based lms with eg. TriLMs.

And TriLMs seem to provide the best performance per bit.

image (higher is better, lefter is better) (they all look something like this)

https://blog.nolano.ai/Spectra-suite/

https://huggingface.co/papers/2407.12327

I hope you are right ... because that is not working with a Diffusion models , there 8 bit is still quite close to original fp16 but lowering anything below is visually degrading quality.

Green-Sky
Green-Sky358 days ago

Note: Offloading tq2_0 layers to vram(cuda) improved the time by ~20%. It was still 10x slower than q4_k though.

@Green-Sky There is no CUDA implementation yet, so there is no need to benchmark CUDA yet - it will be slow because the data will be moved to the CPU all the time

While that is true, I did observe an improvement. Any tips on what this could be? One of my speculations is that my ram is so slow, that PCI-e is faster, but that would be very funny.

model size params backend thrds/ngl test t/s
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B CPU 12 pp512 146.17 ± 0.71
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B CUDA 0 pp512 143.04 ± 0.79
llama ?B TQ2_0 - 2.06 bpw ternary 1.08 GiB 3.99 B CUDA 99 pp512 155.20 ± 2.79
ggerganov
ggerganov358 days ago👀 1

Likely the KV cache ops in the attention (which are still in F16xF32 format) are much faster with CUDA and compensate for the PCI-e transfer overhead

BarfingLemurs
BarfingLemurs358 days ago👍 2👀 5

@mirek190
https://github.com/Lucky-Lance/TerDiT
These are diffusion models trained with ternary weights.

Most LLM quantization is optimized for LLM performance in terms of accuracy (GPTQ, GGUF). There is nothing like this in pytorch, what you might have seen/tried could be naive int4 quantization.

mirek190
mirek190357 days ago

@mirek190 https://github.com/Lucky-Lance/TerDiT These are diffusion models trained with ternary weights.

Most LLM quantization is optimized for LLM performance in terms of accuracy (GPTQ, GGUF). There is nothing like this in pytorch, what you might have seen/tried could be naive int4 quantization.

... That's crazy

So is possible something like a q4 good quality diffusion model ?
That would open diffusion community for a big models because right now absolutely limit for home PC is 12b model like Flux and rtx 3090.

compilade ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot
96b3d411
Hugi-R
Hugi-R353 days ago❤ 2🚀 6

I noticed that TriLM-99M TQ2_0 would fit in the L3 cache of my R7 5700X3D. So I tried it, and the result are impressive! Great work!

.\build\bin\Release\llama-bench.exe -m ..\llm\TriLM-99M-TQ2_0.gguf -p 1500 -n 500 -t 8,12,16
model size params backend threads test t/s
llama ?B TQ2_0 - 2.06 bpw ternary 45.89 MiB 99.76 M CPU 8 pp1500 1960.98 ± 35.75
llama ?B TQ2_0 - 2.06 bpw ternary 45.89 MiB 99.76 M CPU 8 tg500 786.31 ± 14.72
llama ?B TQ2_0 - 2.06 bpw ternary 45.89 MiB 99.76 M CPU 16 pp1500 2511.44 ± 65.75
llama ?B TQ2_0 - 2.06 bpw ternary 45.89 MiB 99.76 M CPU 16 tg500 605.09 ± 32.08

GPU for comparison:

.\llama\b3505\llama-bench.exe -m .\llm\TriLM-99M-Q8_0.gguf -p 1500 -n 500 -ngl 99

Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes

model size params backend ngl test t/s
llama ?B Q8_0 101.13 MiB 99.76 M CUDA 99 pp1500 46232.14 ± 796.75
llama ?B Q8_0 101.13 MiB 99.76 M CUDA 99 tg500 715.95 ± 12.58

A GPU beaten at token generation by a CPU, and with a much faster cold start 🤩

compilade Merge branch 'master' into compilade/bitnet-ternary
d911cd1f
compilade gguf-py : Numpy (de)quantization for TQ1_0 and TQ2_0
3a0bf17d
compilade convert : allow direct conversion to TQ1_0 and TQ2_0
895004f3
compilade ggml-quants : allow using ARM dot product instructions for TQ1_0
69f77268
compilade Merge branch 'master' into compilade/bitnet-ternary
82b24040
compilade ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support
35cc5567
compilade compilade marked this pull request as ready for review 349 days ago
ggerganov
ggerganov approved these changes on 2024-08-15
basavyr
basavyr341 days ago👀 1

I am trying to test the TriLM_3.9B_Unpacked with both TQ1_0 and TQ2_0 quants. Reading this discussion, I see that these two quantization methods are still supported on TriLM models (as opposed to the abandoned quantization for BitNet).

Using this exact pull request, I am building llama.cpp on a MacBook M3 Pro. The straightforward make -j n build command should build with Metal support by default (source). After building llama.cpp with success, I am firstly converting the HF model of TriLM_3.9B_Unpacked to f16-GGUF format, then finally quantizing with llama-quantize to the aforementioned formats. Everything works fine up until here.

The issue comes when I am trying to perform inference on the Apple GPU:

/path_to_built_llama/llama_cli -m quants/TriLM_3.9B_Unpacked_quant_TQ2_0.gguf -p "hey there"
Log start
main: build = 3610 (35cc5567)
main: built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.6.0
main: seed  = 1724234806
llama_model_loader: loaded meta data with 28 key-value pairs and 273 tensors from quants/TriLM_3.9B_Unpacked_quant_TQ2_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                         general.size_label str              = 4.0B
llama_model_loader: - kv   3:                            general.license str              = apache-2.0
llama_model_loader: - kv   4:                          llama.block_count u32              = 30
llama_model_loader: - kv   5:                       llama.context_length u32              = 2048
llama_model_loader: - kv   6:                     llama.embedding_length u32              = 3072
llama_model_loader: - kv   7:                  llama.feed_forward_length u32              = 9216
llama_model_loader: - kv   8:                 llama.attention.head_count u32              = 24
llama_model_loader: - kv   9:              llama.attention.head_count_kv u32              = 24
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  11:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  12:         llama.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  13:                          general.file_type u32              = 37
llama_model_loader: - kv  14:                           llama.vocab_size u32              = 50688
llama_model_loader: - kv  15:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  16:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - kv  17:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  18:                         tokenizer.ggml.pre str              = olmo
llama_model_loader: - kv  19:                      tokenizer.ggml.tokens arr[str,50688]   = ["<|endoftext|>", "<|padding|>", "!",...
llama_model_loader: - kv  20:                  tokenizer.ggml.token_type arr[i32,50688]   = [3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  21:                      tokenizer.ggml.merges arr[str,50009]   = ["Ġ Ġ", "Ġ t", "Ġ a", "h e", "i n...
llama_model_loader: - kv  22:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  23:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  24:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  25:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  26:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   61 tensors
llama_model_loader: - type q4_K:    1 tensors
llama_model_loader: - type q6_K:    1 tensors
llama_model_loader: - type tq2_0:  210 tensors
llm_load_vocab: special tokens cache size = 25
llm_load_vocab: token to piece cache size = 0.2984 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 50688
llm_load_print_meta: n_merges         = 50009
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 2048
llm_load_print_meta: n_embd           = 3072
llm_load_print_meta: n_layer          = 30
llm_load_print_meta: n_head           = 24
llm_load_print_meta: n_head_kv        = 24
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 3072
llm_load_print_meta: n_embd_v_gqa     = 3072
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 9216
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 2048
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = TQ2_0 - 2.06 bpw ternary
llm_load_print_meta: model params     = 3.99 B
llm_load_print_meta: model size       = 1.08 GiB (2.33 BPW)
llm_load_print_meta: general.name     = n/a
llm_load_print_meta: BOS token        = 0 '<|endoftext|>'
llm_load_print_meta: EOS token        = 0 '<|endoftext|>'
llm_load_print_meta: UNK token        = 0 '<|endoftext|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 0 '<|endoftext|>'
llm_load_print_meta: max token length = 1024
llm_load_tensors: ggml ctx size =    0.26 MiB
ggml_backend_metal_log_allocated_size: allocated buffer, size =  1027.47 MiB, ( 1027.55 / 12288.02)
llm_load_tensors: offloading 30 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 31/31 layers to GPU
llm_load_tensors:      Metal buffer size =  1027.46 MiB
llm_load_tensors:        CPU buffer size =    83.53 MiB
....................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Pro
ggml_metal_init: picking default device: Apple M3 Pro
ggml_metal_init: using embedded metal library
ggml_metal_init: GPU name:   Apple M3 Pro
ggml_metal_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction support   = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 12884.92 MB
llama_kv_cache_init:      Metal KV buffer size =   720.00 MiB
llama_new_context_with_model: KV self size  =  720.00 MiB, K (f16):  360.00 MiB, V (f16):  360.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.19 MiB
llama_new_context_with_model:      Metal compute buffer size =   124.00 MiB
llama_new_context_with_model:        CPU compute buffer size =    10.01 MiB
llama_new_context_with_model: graph nodes  = 966
llama_new_context_with_model: graph splits = 2
ggml/src/ggml-metal.m:1619: MUL MAT-MAT not implemented
ggml/src/ggml-metal.m:1619: MUL MAT-MAT not implemented
[1]    36927 abort      /Users/basavyr/Repos/external/llama.cpp/llama-cli -m  -p "hey there"

This error does not occur with the GPU inference explicitly disabled, via the --n-gpu-layers|-ngl 0 flag.

Q: Am I missing something ? Did anyone else try to test this on M1/2/3 GPUs?

flatsiedatsie
flatsiedatsie341 days ago🚀 1

@basavyr Could you share the quantified files on Huggingface? Then I'll happily give it a try on my Macbook Pro M1.

sorasoras
sorasoras341 days ago

I am trying to test the TriLM_3.9B_Unpacked with both TQ1_0 and TQ2_0 quants. Reading this discussion, I see that these two quantization methods are still supported on TriLM models (as opposed to the abandoned quantization for BitNet).

Using this exact pull request, I am building llama.cpp on a MacBook M3 Pro. The straightforward make -j n build command should build with Metal support by default (source). After building llama.cpp with success, I am firstly converting the HF model of TriLM_3.9B_Unpacked to f16-GGUF format, then finally quantizing with llama-quantize to the aforementioned formats. Everything works fine up until here.

The issue comes when I am trying to perform inference on the Apple GPU:

/path_to_built_llama/llama_cli -m quants/TriLM_3.9B_Unpacked_quant_TQ2_0.gguf -p "hey there"
Log start
main: build = 3610 (35cc5567)
main: built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.6.0
main: seed  = 1724234806
llama_model_loader: loaded meta data with 28 key-value pairs and 273 tensors from quants/TriLM_3.9B_Unpacked_quant_TQ2_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                         general.size_label str              = 4.0B
llama_model_loader: - kv   3:                            general.license str              = apache-2.0
llama_model_loader: - kv   4:                          llama.block_count u32              = 30
llama_model_loader: - kv   5:                       llama.context_length u32              = 2048
llama_model_loader: - kv   6:                     llama.embedding_length u32              = 3072
llama_model_loader: - kv   7:                  llama.feed_forward_length u32              = 9216
llama_model_loader: - kv   8:                 llama.attention.head_count u32              = 24
llama_model_loader: - kv   9:              llama.attention.head_count_kv u32              = 24
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  11:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  12:         llama.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  13:                          general.file_type u32              = 37
llama_model_loader: - kv  14:                           llama.vocab_size u32              = 50688
llama_model_loader: - kv  15:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  16:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - kv  17:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  18:                         tokenizer.ggml.pre str              = olmo
llama_model_loader: - kv  19:                      tokenizer.ggml.tokens arr[str,50688]   = ["<|endoftext|>", "<|padding|>", "!",...
llama_model_loader: - kv  20:                  tokenizer.ggml.token_type arr[i32,50688]   = [3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  21:                      tokenizer.ggml.merges arr[str,50009]   = ["Ġ Ġ", "Ġ t", "Ġ a", "h e", "i n...
llama_model_loader: - kv  22:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  23:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  24:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  25:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  26:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   61 tensors
llama_model_loader: - type q4_K:    1 tensors
llama_model_loader: - type q6_K:    1 tensors
llama_model_loader: - type tq2_0:  210 tensors
llm_load_vocab: special tokens cache size = 25
llm_load_vocab: token to piece cache size = 0.2984 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 50688
llm_load_print_meta: n_merges         = 50009
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 2048
llm_load_print_meta: n_embd           = 3072
llm_load_print_meta: n_layer          = 30
llm_load_print_meta: n_head           = 24
llm_load_print_meta: n_head_kv        = 24
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 3072
llm_load_print_meta: n_embd_v_gqa     = 3072
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 9216
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 2048
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = TQ2_0 - 2.06 bpw ternary
llm_load_print_meta: model params     = 3.99 B
llm_load_print_meta: model size       = 1.08 GiB (2.33 BPW)
llm_load_print_meta: general.name     = n/a
llm_load_print_meta: BOS token        = 0 '<|endoftext|>'
llm_load_print_meta: EOS token        = 0 '<|endoftext|>'
llm_load_print_meta: UNK token        = 0 '<|endoftext|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 0 '<|endoftext|>'
llm_load_print_meta: max token length = 1024
llm_load_tensors: ggml ctx size =    0.26 MiB
ggml_backend_metal_log_allocated_size: allocated buffer, size =  1027.47 MiB, ( 1027.55 / 12288.02)
llm_load_tensors: offloading 30 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 31/31 layers to GPU
llm_load_tensors:      Metal buffer size =  1027.46 MiB
llm_load_tensors:        CPU buffer size =    83.53 MiB
....................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M3 Pro
ggml_metal_init: picking default device: Apple M3 Pro
ggml_metal_init: using embedded metal library
ggml_metal_init: GPU name:   Apple M3 Pro
ggml_metal_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction support   = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 12884.92 MB
llama_kv_cache_init:      Metal KV buffer size =   720.00 MiB
llama_new_context_with_model: KV self size  =  720.00 MiB, K (f16):  360.00 MiB, V (f16):  360.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.19 MiB
llama_new_context_with_model:      Metal compute buffer size =   124.00 MiB
llama_new_context_with_model:        CPU compute buffer size =    10.01 MiB
llama_new_context_with_model: graph nodes  = 966
llama_new_context_with_model: graph splits = 2
ggml/src/ggml-metal.m:1619: MUL MAT-MAT not implemented
ggml/src/ggml-metal.m:1619: MUL MAT-MAT not implemented
[1]    36927 abort      /Users/basavyr/Repos/external/llama.cpp/llama-cli -m  -p "hey there"

This error does not occur with the GPU inference explicitly disabled, via the --n-gpu-layers|-ngl 0 flag.

Q: Am I missing something ? Did anyone else try to test this on M1/2/3 GPUs?

I don't think TQ packing support GPU inference yet

compilade
compilade341 days ago👍 3

I don't think TQ packing support GPU inference yet

It does not (yet). But for at least TQ2_0, it should be possible to adapt the Metal kernels from ikawrakow/ik_llama.cpp#13.

I'll see what I can do, but I can't test Metal kernels directly, so I'll likely postpone full support to a follow-up pull-request.

basavyr
basavyr340 days ago (edited 340 days ago)👍 2

@flatsiedatsie The two quantized models are available here. Feel free to try them :)

However, as per @sorasoras and @ggerganov, we might have to wait until support for Metal inference is officially confirmed.

@ggerganov I will give that a try and see if it works.

Thanks a lot guys for the support 🎉🙏


Edit: In the meantime, I have managed to perform Metal quantization (IQ2_TN) + inference on the same TriLM variant. You can play around with the .gguf added in this HF commit. This was possible through the PR that Georgi mentioned.

flatsiedatsie
flatsiedatsie340 days ago (edited 340 days ago)👍 2

Thanks for sharing the .gguf files @basavyr!

I ran a succesful test using Wllama, so this is 100% browser-based BitNet (running on the CPU):

Screenshot 2024-08-22 at 22 35 10
compilade Merge branch 'master' into compilade/bitnet-ternary
cb6d9962
compilade
compilade340 days ago (edited 340 days ago)

@basavyr

Can you test whether https://github.com/compilade/llama.cpp/tree/compilade/bitnet-ternary-metal allows you to run TQ2_0 models on Metal? (this is another branch)

  • Does it compile?
  • Does the output looks correct?
  • Is it faster than when not using Metal?

For the Metal kernels, I've mostly used the code from ikawrakow/ik_llama.cpp#13, but like I said, I can't test it directly because I don't have Apple hardware. If it does not work, then I'll leave that unimplemented here because debugging over comments would not really be convenient.

Also, I am not Georgi :)

flatsiedatsie
flatsiedatsie340 days ago (edited 340 days ago)

It runs and looks correct.

6) Update Your Website Design:
You need to update your website design on a regular basis in order to keep your website fresh and relevant. This includes adding new content
llama_print_timings:        load time =    5112.49 ms
llama_print_timings:      sample time =      13.70 ms /   400 runs   (    0.03 ms per token, 29197.08 tokens per second)
llama_print_timings: prompt eval time =     125.92 ms /    15 tokens (    8.39 ms per token,   119.13 tokens per second)
llama_print_timings:        eval time =    7764.58 ms /   399 runs   (   19.46 ms per token,    51.39 tokens per second)
llama_print_timings:       total time =    7939.94 ms /   414 tokens
ggml_metal_free: deallocating

FULL LOG

./llama-cli -m ./TriLM_3.9B_Unpacked_quant_TQ2_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e
Log start
main: build = 3647 (2f5e28f9)
main: built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.5.0
main: seed  = 1724390564
llama_model_loader: loaded meta data with 28 key-value pairs and 273 tensors from ./TriLM_3.9B_Unpacked_quant_TQ2_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                         general.size_label str              = 4.0B
llama_model_loader: - kv   3:                            general.license str              = apache-2.0
llama_model_loader: - kv   4:                          llama.block_count u32              = 30
llama_model_loader: - kv   5:                       llama.context_length u32              = 2048
llama_model_loader: - kv   6:                     llama.embedding_length u32              = 3072
llama_model_loader: - kv   7:                  llama.feed_forward_length u32              = 9216
llama_model_loader: - kv   8:                 llama.attention.head_count u32              = 24
llama_model_loader: - kv   9:              llama.attention.head_count_kv u32              = 24
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  11:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  12:         llama.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  13:                          general.file_type u32              = 37
llama_model_loader: - kv  14:                           llama.vocab_size u32              = 50688
llama_model_loader: - kv  15:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  16:            tokenizer.ggml.add_space_prefix bool             = false
llama_model_loader: - kv  17:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  18:                         tokenizer.ggml.pre str              = olmo
llama_model_loader: - kv  19:                      tokenizer.ggml.tokens arr[str,50688]   = ["<|endoftext|>", "<|padding|>", "!",...
llama_model_loader: - kv  20:                  tokenizer.ggml.token_type arr[i32,50688]   = [3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  21:                      tokenizer.ggml.merges arr[str,50009]   = ["Ġ Ġ", "Ġ t", "Ġ a", "h e", "i n...
llama_model_loader: - kv  22:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  23:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  24:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  25:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  26:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   61 tensors
llama_model_loader: - type q4_K:    1 tensors
llama_model_loader: - type q6_K:    1 tensors
llama_model_loader: - type tq2_0:  210 tensors
llm_load_vocab: special tokens cache size = 25
llm_load_vocab: token to piece cache size = 0.2984 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 50688
llm_load_print_meta: n_merges         = 50009
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 2048
llm_load_print_meta: n_embd           = 3072
llm_load_print_meta: n_layer          = 30
llm_load_print_meta: n_head           = 24
llm_load_print_meta: n_head_kv        = 24
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 3072
llm_load_print_meta: n_embd_v_gqa     = 3072
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 9216
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 2048
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = TQ2_0 - 2.06 bpw ternary
llm_load_print_meta: model params     = 3.99 B
llm_load_print_meta: model size       = 1.08 GiB (2.33 BPW) 
llm_load_print_meta: general.name     = n/a
llm_load_print_meta: BOS token        = 0 '<|endoftext|>'
llm_load_print_meta: EOS token        = 0 '<|endoftext|>'
llm_load_print_meta: UNK token        = 0 '<|endoftext|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 0 '<|endoftext|>'
llm_load_print_meta: max token length = 1024
llm_load_tensors: ggml ctx size =    0.26 MiB
ggml_backend_metal_log_allocated_size: allocated buffer, size =  1027.47 MiB, ( 1027.53 / 10922.67)
llm_load_tensors: offloading 30 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 31/31 layers to GPU
llm_load_tensors:      Metal buffer size =  1027.46 MiB
llm_load_tensors:        CPU buffer size =    83.53 MiB
....................................................................................
llama_new_context_with_model: n_ctx      = 2048
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M1 Pro
ggml_metal_init: picking default device: Apple M1 Pro
ggml_metal_init: using embedded metal library
ggml_metal_init: GPU name:   Apple M1 Pro
ggml_metal_init: GPU family: MTLGPUFamilyApple7  (1007)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction support   = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 11453.25 MB
llama_kv_cache_init:      Metal KV buffer size =   720.00 MiB
llama_new_context_with_model: KV self size  =  720.00 MiB, K (f16):  360.00 MiB, V (f16):  360.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.19 MiB
llama_new_context_with_model:      Metal compute buffer size =   124.00 MiB
llama_new_context_with_model:        CPU compute buffer size =    10.01 MiB
llama_new_context_with_model: graph nodes  = 966
llama_new_context_with_model: graph splits = 2

system_info: n_threads = 6 / 8 | AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 0 | NEON = 1 | SVE = 0 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 2048, n_batch = 2048, n_predict = 400, n_keep = 0


Building a website can be done in 10 simple steps:
Step 1: Create your website domain.
Step 2: Create your website content.
Step 3: Create your website pages.
Step 4: Create your website design.
Step 5: Build your website on your hosting server.
Step 6: Update your website design.
Step 7: Create your website login and password.
Step 8: Build your website with the help of an expert.
Step 9: Test your website and enjoy it.
Step 10: Build your website with the help of an expert.
If you want to build a website and make it successful, then follow these simple steps and you will be able to build a great website.
1) Create Your Website Domain:
The first step is to create your own website domain. You can use any domain name you like, but it should be something that's relevant to your business. You can also use a website builder like Weebly or Wix, which will help you create your website quickly and easily.
2) Create Your Website Content:
You need to create your website content in order to make it successful. This includes writing articles, blogging, and creating videos.
3) Create Your Website Pages:
Once you have created your website content, you need to create your website pages. This includes creating pages for your blog, product page, and other sections.
4) Create Your Website Design:
You need to create your website design in order to make it successful. This includes choosing a design that's relevant to your business and using a professional web design company to help you.
5) Build Your Website on Your Hosting Server:
Once you have created your website design, you need to build your website on your hosting server. This is the most important step because it is the foundation of your website.
6) Update Your Website Design:
You need to update your website design on a regular basis in order to keep your website fresh and relevant. This includes adding new content
llama_print_timings:        load time =    5112.49 ms
llama_print_timings:      sample time =      13.70 ms /   400 runs   (    0.03 ms per token, 29197.08 tokens per second)
llama_print_timings: prompt eval time =     125.92 ms /    15 tokens (    8.39 ms per token,   119.13 tokens per second)
llama_print_timings:        eval time =    7764.58 ms /   399 runs   (   19.46 ms per token,    51.39 tokens per second)
llama_print_timings:       total time =    7939.94 ms /   414 tokens
ggml_metal_free: deallocating
Log end

ggerganov
ggerganov339 days ago👍 2

For the Metal kernels, I've mostly used the code from ikawrakow/ik_llama.cpp#13, but like I said, I can't test it directly because I don't have Apple hardware

@compilade Don't worry about the Metal implementation. I can add this in a separate PR

basavyr
basavyr332 days ago (edited 328 days ago)🚀 1

@compilade Sorry for the late answer...

I have managed to compile your fork of llama.cpp and successfully run the TQ2_0 inference on Metal for SpectraSuite/TriLM_3.9B_Unpacked. It looks like inference with Metal acceleration for TriLM_3.9B_Unpacked model is now possible ✅

Moreover, I have also tried to quantize all three versions of bitnet models from 🤗 (i.e., 3B, xl, and large) with the same branch, however it is not working ❌ (as expected)

Answering your questions:

  • Does it compile? -> YES
  • Does the output looks correct? -> If by this you mean the model output during inference, It is not able to provide good prompt responses.
  • Is it faster than when not using Metal? See results below👇

GPU:

llama_print_timings:        load time =     209.00 ms
llama_print_timings:      sample time =       4.94 ms /   256 runs   (    0.02 ms per token, 51874.37 tokens per second)
llama_print_timings: prompt eval time =      64.34 ms /     7 tokens (    9.19 ms per token,   108.80 tokens per second)
llama_print_timings:        eval time =    3476.64 ms /   255 runs   (   13.63 ms per token,    73.35 tokens per second)
llama_print_timings:       total time =    3558.48 ms /   262 tokens
ggml_metal_free: deallocating

CPU (--n-gpu-layers 0):

llama_print_timings:        load time =     112.92 ms
llama_print_timings:      sample time =       5.99 ms /   256 runs   (    0.02 ms per token, 42766.46 tokens per second)
llama_print_timings: prompt eval time =      78.54 ms /     7 tokens (   11.22 ms per token,    89.13 tokens per second)
llama_print_timings:        eval time =    4509.33 ms /   255 runs   (   17.68 ms per token,    56.55 tokens per second)
llama_print_timings:       total time =    4608.16 ms /   262 tokens
flatsiedatsie
flatsiedatsie327 days ago👍 4

Is there any reason to not merge this? I'm already using it in a project built on Wllama, and now I have to take extra steps to compile it each time.

Don't let the perfect be the enemy of the good.

compilade Merge branch 'master' into compilade/bitnet-ternary
7f3a619c
compilade ggml ; remove unused ggml_mul special case
8d616076
compilade test-backend-ops : add TQ1_0 and TQ2_0 comments for later
75b3a096
compilade compilade force pushed from e4dc48a5 to 75b3a096 327 days ago
compilade
compilade327 days ago🎉 5❤ 2🚀 2

Is there any reason to not merge this?

Not really. It's pretty much ready (apart from support in other backends than CPU-only, and quantization to Q4_K and Q6_K not being supported in gguf-py yet, but I guess this can be fixed later once reference quantization is made platform independent (ref #8939 (comment))). I did have an hesitation with the order of the values in TQ1_0, but after making some experiments with indexing, its current structure (which ended up unchanged) should be good enough for future GPU implementations (hopefully).

(the indexing experiment)

Indices to extract 4 values per tid. This relies a lot on memory access coalescing, since each 5 consecutive tid will read the same 4 bytes (the last 4 bytes are only read by 4 tid, though).

This tests the read order.

tq1_0_pattern: list[list[int]] = [[32*v + b for v in range(5)] for b in range(32)] + [[16*v + b + 160 for v in range(5)] for b in range(16)] + [[4*v + b + 240 for v in range(4)] for b in range(4)]

print(f"{tq1_0_pattern=}")

for tid in range(64):
    n = tid // 40;  # 0 or 1
    nh = tid // 60; # 0 or 1
    il = tid // 5;  # 0..13
    ir = tid % 5;   # 0..5
    l = 32 - 16*n - 12*nh; # 32, 16 or 4

    q: list[int] = [4*il + j for j in range(4)]
    y: list[int] = [128*n + 64*nh + 4*il + l*ir + j for j in range(4)];

    status = "good"
    for a, b in zip(q, y):
        if a >= len(tq1_0_pattern) or b != tq1_0_pattern[a][ir]:
            status = "bad"
            break

    print(f"{tid=}, {q=}, {y=}, {status=}")

Support for other backends than CPU will be added in separate pull requests. The only non-CPU backends I can possibly implement (and test) with the hardware I have are Vulkan (and CUDA, but only in evenings and weekends). TQ2_0 should be relatively easy to port everywhere, while TQ1_0 will require more effort, but should still be implementable.

Don't let the perfect be the enemy of the good.

Right. And the recent commits (8d61607 and 75b3a09) should not be controversial (reducing the changes to ggml_mul, handling missing TQ1_0 and TQ2_0 cases in switch statements for some ggml operators with a dequantization-based variant (ggml_add, ggml_add1, etc.), and adding some comments in tests/test-backend-ops.cpp about TQ1_0 and TQ2_0).

I will merge this soon, either today or tomorrow if I forget.

compilade compilade added merge ready
compilade compilade merged 9bc6db28 into master 326 days ago
WenguoLi
WenguoLi317 days ago👀 2

Thank you very much for this great job. Do you have any plans to further support risc-v devices?

rhjdvsgsgks
rhjdvsgsgks76 days ago

There might be a way to make a similar type as TQ2_0 like some sort of Q2_1, which could be almost as fast but still usable by non-ternary models, but this will probably require something like LQER to help with keeping some precision.

someone discovered a way to fine tune a non bitnet model to bitnet while keeping precision by adding a rmsnorm layer (huggingface/transformers#38087). is that helpful to make tq2_0 avaliable to non bitnet/ternary model?

Login to write a write a comment.

Login via GitHub

Reviewers
Assignees
No one assigned
Labels
Milestone