Wonderful job! I'm wondering if this PR can merge into the master branch, it would be so good if users of llama.cpp can use Q2_2 and Q1_3 conveniently.
26 | 26 | { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, | |
27 | 27 | { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, | |
28 | 28 | { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, | |
29 | { "Q1_3", LLAMA_FTYPE_MOSTLY_Q1_3, " 1.63 bpw for BitNet b1.58", }, | ||
30 | { "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2.00 bpw for BitNet b1.58", }, |
Regarding the names of the new quant types, since these are quite specific to BitNet models, I was thinking to name them with something starting with QB
, a bit like suggested in #5761 (comment).
I'll first be describing what I want from the naming scheme, then I'll attempt to make it work.
The naming scheme should have room for:
{-1, 0, 1}
1.625 bpw
quant with a block size of 64, with 13 bytes per block
Q8_0
as its vec_dot_type
(for the activations)float16
scale in the leftover bits in the last byte of 16 consecutive blocks (this means 1024 elements minimum per row), although it can't really be extracted with SIMD)2.000 bpw
quant with a block size of 32, with 8 bytes per block
Q8_0
as its vec_dot_type
(for the activations)2.000 bpw
quant with a block size of 64, with 16 bytes per block, and a float16 scale
1.625 bpw
type, but with an extra byte and a row-wise float16
scale duplicated in each block.2.000 bpw
quant with a block size of 4, with 1 byte per block
vec_dot_type
{-1, 1}
1 bpw
type{0, 1}
8.5 bpw
like Q8_0
, but all the scales of a row are the same
float32
operations in the vec_dot of the above types.10 bpw
, 5 bytes per block of 4 elements, with a weird layout which only uses blocks to get a big enough buffer, with a single float32 scale and some padding before all row elements, aligned and contiguous.
2.000 bpw
type, and also maybe the other ones for best performance.So the naming scheme could be:
QB<x>_<y>
<x>
is the floor of the expected bpw of the type<y>
is
0
binary type, {0, 1}
QB8_0
which is like Q8_0
but with a guaranteed duplicated row-wise scale1
binary type, {-1, 1}
2
ternary type using some kind of binary-coded ternary3
ternary type with fixed-point packed values4
weird type with a block size of 4Which for the previously-mentioned possible BitNet types would mean:
proposed name | Range | bits per weight | block size | bytes | row-wise scale | current name |
---|---|---|---|---|---|---|
QB1_3 |
{-1, 0, 1} |
1.625 | 64 | 13 | 1.0f | Q1_3 |
QB2_2 |
{-2, -1, 0, 1} |
2.000 | 32 | 8 | 1.0f | Q2_2 |
QB2_3 |
{-1, 0, 1} |
2.000 | 64 | 16 | f16 |
|
QB2_4 |
{-2, -1, 0, 1} |
2.000 | 4 | 1 | 1.0f | |
QB1_1 |
{-1, 1} |
1.000 | ? | ?/8 | 1.0f | |
QB1_0 |
{0, 1} |
1.000 | ? | ?/8 | 1.0f | |
QB8_0 |
[-127, 127] |
8.5 | 32 | 34 | f16 |
|
QB8_4 |
[-127, 127] |
10 | 4 | 5 | f32 , weird layout |
I'm not saying these should all exist, though, only that the naming scheme should not be too limiting for possible future extensions (which might not exist anyway due to lack of time).
So I think I'll rename Q1_3
to QB1_3
, and Q2_2
to QB2_2
. Anyone has comments on this? Or a better naming scheme for the new BitNet quant types?
If it were me, considering this only works with bitnet models and nothing else, I'd want the designations to be exceptionally clear that they are different and shouldn't be used on just anything. "QB" is good, but I'd take it a step further and remove the Q entirely. As bitnet is being colloquially referred to as a "1-bit" model, B1 makes more sense. Considering the plausible range for weights, I'd cut it off at tenths and ditch the decimal. This leaves plenty of room for variations, while making the native BPW very clear. I feel this is superior to the arbitrary "_2" and "_3" subtypes.
So what I would propose is:
1.625bpw = B1_16
2.000bpw = B1_20
@compilade and @Eddie-Wang1120 continuing the race to the bottom 🥳 , glorious.
Did some quick testing with the 3B model and it looks very good.
model | size | params | backend | threads | test | t/s |
---|---|---|---|---|---|---|
bitnet 3B Q1_3 - 1.625 bpw for BitNet b1.58 | 729.64 MiB | 3.32 B | BLAS | 12 | pp512 | 78.40 ± 0.27 |
bitnet 3B Q1_3 - 1.625 bpw for BitNet b1.58 | 729.64 MiB | 3.32 B | BLAS | 12 | tg128 | 38.16 ± 0.04 |
bitnet 3B Q2_2 - 2.000 bpw for BitNet b1.58 | 873.65 MiB | 3.32 B | BLAS | 12 | pp512 | 73.35 ± 6.23 |
bitnet 3B Q2_2 - 2.000 bpw for BitNet b1.58 | 873.65 MiB | 3.32 B | BLAS | 12 | tg128 | 36.86 ± 0.12 |
What surprises me a little, after reading about q2_2
being faster, is that q1_3
seems to be faster with the setup I used here. Will investigate further.
edit: also updated the files at https://huggingface.co/Green-Sky/bitnet_b1_58-3B-GGUF , for anyone else willing to test.
Did a bit of testing myself, it runs and generates well but unfortunately it's the undertrained models rather than our implementation that's holding back BitNet adoption. For me Q1_3 is slower but this computer is CPU rather than memory bound.
model | size | params | backend | threads | test | t/s |
---|---|---|---|---|---|---|
bitnet 3B Q1_3 - 1.625 bpw for BitNet 1.58b | 729.64 MiB | 3.32 B | CPU | 4 | pp512 | 15.15 ± 0.07 |
bitnet 3B Q1_3 - 1.625 bpw for BitNet 1.58b | 729.64 MiB | 3.32 B | CPU | 4 | tg128 | 9.87 ± 0.65 |
bitnet 3B Q2_2 - 2.000 bpw for BitNet 1.58b | 873.65 MiB | 3.32 B | CPU | 4 | pp512 | 19.25 ± 0.44 |
bitnet 3B Q2_2 - 2.000 bpw for BitNet 1.58b | 873.65 MiB | 3.32 B | CPU | 4 | tg128 | 13.07 ± 0.28 |
bitnet 3B Q4_0 | 1.79 GiB | 3.32 B | CPU | 4 | pp512 | 18.44 ± 0.40 |
bitnet 3B Q4_0 | 1.79 GiB | 3.32 B | CPU | 4 | tg128 | 5.87 ± 0.12 |
I wonder if Q2_2 could be made faster if we used a block size of say 256 like the K-quants so that we can handle more than 64 bits of Q2_2 quants in each dot product loop. Aside from that I can't find any further way to improve that AVX implementation, and while it's ironic that we're using a madds
instruction there when BitNet technically doesn't require multiplication that looks like the fastest way to dot the activations and ternary weights.
I wonder if Q2_2 could be made faster if we used a block size of say 256 like the K-quants
Can't go with bigger blocks than 64 elements or else the 3B model won't be fully quantizable. (Its FFN size is 8640 (which factors into 2 2 2 2 2 2 3 3 3 5
))
Its current block size is 32, which is the same as its vec_dot_type, Q8_0.
What would also help with performance would be to somehow use an 8-bit vec_dot_type having a single float scale per row. Might be interesting to explore later, but ggml
does not have row-wise quant types yet, although this could still be done with a block quant.
it's ironic that we're using a
madds
instruction
Yeah, with AVX2, there are no good widening addition instructions like on ARM NEON, so _mm256_maddubs_epi16
is used for that.
Meanwhile, NEON doesn't have the equivalent of _mm_sign_epi8
, so it needs to use multiplications or conditional masks, which are both slower than a dedicated instruction doing zeroing and sign flipping like in SSSE3.
672 | int8_t x2 = (int8_t)x[i*qk + 2*qk/4 + j]; | ||
673 | int8_t x3 = (int8_t)x[i*qk + 3*qk/4 + j]; | ||
674 | |||
675 | const uint8_t xi0 = x0 < 0 ? 1 : x0 == 0 ? 2 : 3; | ||
676 | const uint8_t xi1 = x1 < 0 ? 1 : x1 == 0 ? 2 : 3; | ||
677 | const uint8_t xi2 = x2 < 0 ? 1 : x2 == 0 ? 2 : 3; | ||
678 | const uint8_t xi3 = x3 < 0 ? 1 : x3 == 0 ? 2 : 3; |
As proposed, the type utilizes only 3 of the 4 possible values. I was thinking that the Q2_2
type would work the same as Q4_0
, but assumes amax == 1.0f
:
void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict y, int64_t k) {
static const int qk = QK2_2;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
}
// assume amax = 1.0f
max /= amax;
const float d = max / -2;
const float id = d ? 1.0f/d : 0.0f;
for (int j = 0; j < qk/4; ++j) {
const float x0 = x[i*qk + 0*qk/4 + j]*id;
const float x1 = x[i*qk + 1*qk/4 + j]*id;
const float x2 = x[i*qk + 2*qk/4 + j]*id;
const float x3 = x[i*qk + 3*qk/4 + j]*id;
const uint8_t xi0 = MIN(3, (int8_t)(x0 + 2.5f));
const uint8_t xi1 = MIN(3, (int8_t)(x1 + 2.5f));
const uint8_t xi2 = MIN(3, (int8_t)(x2 + 2.5f));
const uint8_t xi3 = MIN(3, (int8_t)(x3 + 2.5f));
y[i].qs[j] = xi0;
y[i].qs[j] |= xi1 << 2;
y[i].qs[j] |= xi2 << 4;
y[i].qs[j] |= xi3 << 6;
}
}
}
(not tested, just pattern matching the existing quantize_row_q4_0_reference()
)
Edit: just realized the above would not work. We have assume that max == 1.0f
, not amax
, so:
const float max = 1.0f;
...
Whew, it has been a month since I last touched this, I got distracted for a bit.
(tl;dr at the end)
Now that new ternary models like TriLMs exist (https://arxiv.org/abs/2407.12327), which use multiple scales per tensors and which (fortunately) have all tensor dimensions divisible by 256 🎉, I think I should add a ternary type with 256 elements per block and a block-wise f16
scale. That would result in 1.6875 bpw
, which sounds very reasonable to me.
Another ternary type with a scale but with a smaller block size (64) might be useful for compatibility with the BitNet b1.58 models from the 1bitLLM team (because their model dimensions are not divisible by 256), and would be 1.875 bpw
or 2.0 bpw
depending on whether padding 15 bytes of data to 16 bytes is better for performance.
These should have a similar inference speed as Q1_3
, since they will use a similar packing scheme.
I'm not sure if it's worth it to keep the scale-less ternary quant types; I feel like they require too much special handling in the model graphs and in the convert script. It might be okay for BitNetForCausalLM
, but not for some newer models like TriLMs which use LlamaForCausalLM
, AKA not a ternary-specific architecture.
So I'll be proposing 4 (starting with 2) types, with yet another attempt at a naming scheme1 for ternary quants,
this time matching the regex TQ\d(_\dF?)?
:
TQ1_0
1.6875 bpw
.Q1_3
, but repeated 4 times, and with a f16
scale.vec_dot_type
could be Q8_K
TQ1_0F
1.875 bpw
or 2.0 bpw
float
, but at least it should give a vague idea that this is slightly bigger than TQ1_0
.Q1_3
, but with a f16
scale.vec_dot_type
will be Q8_0
TQ2_0
2.0625 bpw
.Q2_2
, so it should be performant, unless on platforms where the misalignment from the 2 bytes of the scale has some effect.vec_dot_type
could be Q8_K
IQ2_XSS
, which can't even represent 0 unless the whole block is 0.TQ2_0F
TQ1_0F
, but based on Q2_2
instead of Q1_3
.2.25 bpw
Note that IQ2_XXS
is already a 256-element type with similar properties as TQ2_0
, although IQ2_XXS
's packing scheme is much more complicated and I feel like its reliance on iq2xxs_grid
makes it unnecessarily slower than it could be.2
I'll work on at least TQ1_0
and TQ1_0F
in the next days, but I might get distracted. I'm doing this as a hobby in my free time, so it's possible that my priorities shift depending on external factors. This means anyone interested should feel free to ping me if I seem to have forgotten this again.
TL;DR: I think I'll replace the scale-less Q1_3
and Q2_2
with ternary types with a block-wise scale, which should allow supporting both BitNet b1.58 and TriLMs, while also simplifying the conversion for BitNet b1.58 because separate scale tensors won't be needed anymore.
Some rationale for the naming scheme: using a special prefix to note that these are special-purpose, TQ
stands for "ternary quant", not using QT to avoid confusion with https://www.qt.io/, and also because the IQ
quants also prefix Q
with a letter. I'm using _0
as suffix to mean that it has a scale similarly to Q8_0
. ↩
Okay, I've read a bit about IQ2_XXS
, and it seems slightly over-engineered and totally not intended for ternary models. Basically, it strongly relies on a lookup table (iq2xxs_grid
), which contains 3 possible values in each byte: 0x08
, 0x19
or 0x2b
(8, 25, 45, respectively). This looks like where the absolute values of the elements comes from (before being scaled and signed). This means 0
is not representable unless the whole block is 0
. ↩
@compilade Keep up the good work. You are a hero making living on the edge affordable 😄 .
Beside the others here of course... 😉
Not sure if anyone has noticed, but meta(facebook) changed the license for llama3.1 to allow training on outputs, which would allow for distillation.
So now I am waiting to see a bitnet distillation of the new 3.1 llamas pop up (hopefully).
@compilade btw quick question regarding the packing structure of these encoding arrangements. Is there a consistent way to extract the bit pattern structure from the source code? It's a bit hard to grok the superblock, blocks and how bits are being packed for documentation. Ideally too I would like such documentation to be autogenerated as well, but until I can understand the basics from the C struct... it's a bit hard to get started.
The plan sounds good. I wouldn't worry about the fallback types - we already have a workaround via padding for such kind of models, plus I doubt there will be much of those in the future.
we already have a workaround via padding for such kind of models
@ggerganov While it mostly works, padding like in e9f2abf isn't correct with ggml_rms_norm
, because the row size is used to calculate the mean.
To make padding work properly, there would need to be some special handling to make it possible to use ne[0]
values which are not multiples of the block size (like making ggml_row_size
round up).
The GGUF file format should already support that, since the tensor offsets don't directly depend on their size.
But GGUFWriter
would need to avoid assuming a lossless round-trip between shape and byte shape.
Quantization and dequantization would need to be adapted, because the functions currently assume ne[0]
is a multiple of the block size. But the quantize_row_*_ref
functions don't necessarily know ne[0]
directly (they get the total element count in a chunk of rows), but that should be easy enough to adapt with doing one call per row when padding is needed, a bit like applying importance matrices is done one row at a time. Or padding could be handled outside, but this would (momentarily) use more memory for the padded f32
copies (unpadding can be done with views).
Dot products would need no change if the padding values are equivalent to zero (this won't work for IQ2_XXS
and likely other IQ types which can't represent zero).
I wouldn't worry about the fallback types
Understood. I agree with adding fewer types. And using padding could even let the cursed https://huggingface.co/1bitLLM/bitnet_b1_58-xl be quantized with its weird FFN size of 5460 which factors into 2 2 3 5 7 13
.
I'll start with not handling padding, because it would affect other types too (notably Q8_K
), and might be more appropriate in a separate PR.
Is there a consistent way to extract the bit pattern structure from the source code? It's a bit hard to grok the superblock, blocks and how bits are being packed for documentation. Ideally too I would like such documentation to be autogenerated as well, but until I can understand the basics from the C struct... it's a bit hard to get started.
No, unfortunatley, I don't think this can be easily automated. Sometimes a single field in the structs stores multiple types of values, like in Q4_K
where block_q4_K.scales
stores 6-bit scales and mins in some pattern1. The easiest way to understand what the bits mean is to have a look at the respective dequantize_row
function of each type.
The 12 bytes in Q4_K .scales
are packed a bit like this, where the uppercased letters are bits for the scales and lowercased letters are the bits of the mins:
0: EEAAAAAA
1: FFBBBBBB
2: GGCCCCCC
3: HHDDDDDD
4: eeaaaaaa
5: ffbbbbbb
6: ggcccccc
7: hhdddddd
8: eeeeEEEE
9: ffffFFFF
10: ggggGGGG
11: hhhhHHHH
Source: https://github.com/ggerganov/llama.cpp/blob/75af08c475e285888f66556d0f459c533b7deb95/ggml/src/ggml-quants.c#L1891-L1898 ↩
@compilade thanks for the explanation it is interesting to see that the bits are split into 2bit and 4bits and uses only bitwise actions. Is this because it's preferred over packing each 6bit scale in sequential order, because each access is aligned or is cheaper to use bitwise operations?
edit: Ah likely to be more friendlier for parallel processing in gpu etc...
@ggerganov While it mostly works, padding like in e9f2abf isn't correct with ggml_rms_norm, because the row size is used to calculate the mean
Correct, the norm can be applied on a view having the original size though (1D tensors used for normalisations are never quantised).
I've made some preliminary performance (speed) tests with TQ1_0
and TQ2_0
, and TQ1_0
is faster than Q1_3
, now around the speed of Q8_0
, while TQ2_0
got a very big perf boost and is twice as fast as TQ1_0
, which makes it by far the fastest quant type (around 2x faster than Q8_0
, and 1.7x faster than Q2_K
)1, at least with AVX2 on my machine. Bigger block sizes do pay off!
(And Q8_K
is a very good vec_dot_type
, with a f32
scale and even pre-computed sums)
Note that this is about the vec_dot
speed and not the overall speed, although it's usually where most of the compute time is spent.
The formats of TQ1_0
and TQ2_0
are a bit different than what I initially planned, to make the data more convenient to access in the AVX2 vec_dot
. Something nice is that unlike Q1_3
, TQ1_0
does not rely on reading past the buffer (Q1_3
has 13 byte blocks which were read in 16 byte chunks).
A possible future improvement for the AVX2 vec_dot of TQ1_0
would be to test if 16-bit multiplies and permutes are faster or not than more elaborate ways to shift 8-bit values by powers of 3 (AVX2 does not have non-widening 8-bit multiplies), but both approaches were mostly similar in performance on my machine, so I went with the 8-bit operations.
I'll port TQ1_0
and TQ2_0
to ARM NEON in the next days, and I'll remove Q1_3
and Q2_2
after making comparisons on low-end ARM devices.
Is this because it's preferred over packing each 6bit scale in sequential order, because each access is aligned or is cheaper to use bitwise operations?
@mofosyne I had no part in the decision of the scale packing in Q4_K
, but I think it's like this because indexing is only done at the byte level, so packing and unpacking 6-bit values has to use bitwise operations. Pointers can only jump at a minimum of a byte at a time. Also when making the vec_dot
of Q1_3
I've noticed that shuffles are surprisingly as fast as additions in SIMD.
Proof:
test-quantize-perf
(click to expand)$ for t in q4_0 q8_0 q4_K q2_K tq2_0 tq1_0 q1_3 q2_2; do ./bin/test-quantize-perf --op vec_dot_q --type $t -i 10000000; done
q4_0
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 3.03
avg cycles/32 vals : 3.33
float32 throughput : 52.88 GB/s
quantized throughput : 7.44 GB/s
q8_0
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 2.24
avg cycles/32 vals : 2.51
float32 throughput : 68.26 GB/s
quantized throughput : 18.13 GB/s
q4_K
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 2.22
avg cycles/32 vals : 2.68
float32 throughput : 62.68 GB/s
quantized throughput : 8.81 GB/s
q2_K
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 1.75
avg cycles/32 vals : 1.99
float32 throughput : 81.82 GB/s
quantized throughput : 6.71 GB/s
tq2_0
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 0.83
avg cycles/32 vals : 0.95
float32 throughput : 144.50 GB/s
quantized throughput : 9.31 GB/s
tq1_0
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 2.11
avg cycles/32 vals : 2.29
float32 throughput : 71.35 GB/s
quantized throughput : 3.76 GB/s
q1_3
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 2.94
avg cycles/32 vals : 3.46
float32 throughput : 50.02 GB/s
quantized throughput : 2.54 GB/s
q2_2
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 2.12
avg cycles/32 vals : 2.33
float32 throughput : 73.31 GB/s
quantized throughput : 4.58 GB/s
(some discussion around compilade's improvement can be found on Reddit)
I've tested that a round-trip quantization between TQ1_0
and TQ2_0
is lossless, which means one can always be made from the other.
$ ./build/bin/llama-quantize models/trilm-390M-f16.gguf models/trilm-390M-tq1_0.gguf tq1_0
$ ./build/bin/llama-quantize models/trilm-390M-f16.gguf models/trilm-390M-tq2_0.gguf tq2_0
$ ./build/bin/llama-quantize --allow-requantize models/trilm-390M-tq1_0.gguf models/trilm-390M-tq2_0-requant.gguf tq2_0
$ ./build/bin/llama-quantize --allow-requantize models/trilm-390M-tq2_0-requant.gguf models/trilm-390M-tq1_0-roundtrip.gguf tq1_0
$ cd models
$ sha256sum trilm-390M-tq*
e4c622fb10dcfa30d427eb94eb08ffdcbde8ef3683a2b43a1b1eac8ab6e3e67f trilm-390M-tq1_0.gguf
e4c622fb10dcfa30d427eb94eb08ffdcbde8ef3683a2b43a1b1eac8ab6e3e67f trilm-390M-tq1_0-roundtrip.gguf
4edaaa33f8d7ffeaac72d758bf0e253512128a4a872a9c428bf337abb21a64be trilm-390M-tq2_0.gguf
4edaaa33f8d7ffeaac72d758bf0e253512128a4a872a9c428bf337abb21a64be trilm-390M-tq2_0-requant.gguf
I've also added ARM NEON implementations of vec_dot
for TQ1_0
and TQ2_0
, but the relative speedup on a Raspberry Pi 4 B
is less impressive than with AVX2 on my laptop. There might still be ways to optimize the use of ARM NEON in there.
Still, it's decent, at 1.6x the speed of Q8_0
for TQ2_0
. But the RPi4 is very memory bound (with a bandwidth only around 3GB/s), so actual inference speed is relatively much better with smaller types.
But I'm happy that TQ1_0
is 1.75x as fast as Q1_3
on that machine. The gap between TQ1_0
and TQ2_0
is also smaller than with AVX2.
test-quantize-perf
on a RPi4 (click to expand)$ for t in q4_0 q8_0 q4_K q2_K tq2_0 tq1_0 q1_3 q2_2; do ./bin/test-quantize-perf --op vec_dot_q --type $t -i 2000000; done
q4_0
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 0.00
avg cycles/32 vals : 0.00
float32 throughput : 7.82 GB/s
quantized throughput : 1.10 GB/s
q8_0
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 0.00
avg cycles/32 vals : 0.00
float32 throughput : 9.57 GB/s
quantized throughput : 2.54 GB/s
q4_K
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 0.00
avg cycles/32 vals : 0.00
float32 throughput : 9.38 GB/s
quantized throughput : 1.32 GB/s
q2_K
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 0.00
avg cycles/32 vals : 0.00
float32 throughput : 9.64 GB/s
quantized throughput : 0.79 GB/s
tq2_0
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 0.00
avg cycles/32 vals : 0.00
float32 throughput : 15.35 GB/s
quantized throughput : 0.99 GB/s
tq1_0
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 0.00
avg cycles/32 vals : 0.00
float32 throughput : 11.82 GB/s
quantized throughput : 0.62 GB/s
q1_3
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 0.00
avg cycles/32 vals : 0.00
float32 throughput : 6.75 GB/s
quantized throughput : 0.34 GB/s
q2_2
vec_dot_q
4096 values (0.02 MB)
min cycles/32 vals : 0.00
avg cycles/32 vals : 0.00
float32 throughput : 9.14 GB/s
quantized throughput : 0.57 GB/s
The next steps are to remove Q1_3
and Q2_2
, and to adapt the convert script to let it convert directly to at least one of TQ1_0
or TQ2_0
.
I saw compilade remove the old bitnet quants, so I decided it was time for another round of tests.
Since the large bitnet repro model does not work with the new quants (as explained in the OP), I switched to the TriLM_3.9B
model.
quant | ppl | ppl@300 | filesize |
---|---|---|---|
f16 | 11.1532 +/- 0.07854 | 11.0180 | 7.5G |
q8_0 | 11.1489 +/- 0.07851 | 11.015 | 4.0G |
q4_0 | 11.4797 +/- 0.08058 | 11.3249 | 2.2G |
q4_k | 11.1559 +/- 0.07854 | 11.0223 | 2.3G |
tq2_0 | 11.1558 +/- 0.07853 | 11.0200 | 1.1G |
tq1_0 | 11.1558 +/- 0.07853 | 11.0200 | 949M |
I added ppl at step 300 for reference to speed up future ppl calculations.
Note: Offloading tq2_0
layers to vram(cuda) improved the time by ~20%. It was still 10x slower than q4_k though.
As always I used default settings calculating perplexity over 560 chunks, n_ctx=512, batch_size=2048, n_seq=4
edit: uploaded some quantized files again https://huggingface.co/Green-Sky/TriLM_3.9B-GGUF
Login to write a write a comment.
This adds
1.6875 bpw
and2.0625 bpw
quant types for TriLMs and BitNet b1.58 models. For now, these are namedTQ1_0
andTQ2_0
, respectively.I had given glimpses of this idea starting from #7931 (comment).
The
1.6875 bpw
type mostly relies on the fact that35 == 243 < 256 == 28
to pack 5 trits per byte.(I also made a blog post about ternary packing in an attempt to explain the core idea a bit more (storing the values in fixed-point to extract the most significant digit first with multiplications))
Huge thanks to @Eddie-Wang1120, who motivated this by adding initial BitNet b1.58 support in #7931.
How to try it
Using TriLM models is the easiest because all of their models have row sizes divisible by 256.
Important
To quantize the token embeddings and the output tensor to
Q4_K
andQ6_K
, you need to usellama-quantize
on the model files produced byconvert_hf_to_gguf.py --outtype tq1_0
(and also fortq2_0
). Otherwise these two tensors are kept asf16
and are responsible for most of the size of the models.If you want to try
TQ2_0
, which is faster (but bigger) thanTQ1_0
on compute-bound hardware, you can replacetq1_0
withtq2_0
in the above example, but it's also possible to convert from theTQ1_0
model file.The two ternary formats hold the same values, so round-trip quantizing between the two should result in the same files.
$ ./build/bin/llama-quantize --allow-requantize /somewhere/TriLM-3.9B-TQ1_0.gguf /somewhere/TriLM-3.9B-TQ2_0.gguf tq2_0
Speed
TQ2_0
is twice as fast asQ4_K
on my laptop. It's the fastest quant on compute-bound AVX2-capable computers.This is a table of the float32-equivalent throughput of the
vec_dot_q
operation for each of these quant types.t4g
(NEON)t4g
(DOTPROD)From this, it's easy to see that
TQ1_0
is usually slightly faster thanQ4_K
, and thatTQ2_0
is by far the fastest quant on AVX2.Note
There might be a way to make a similar type as
TQ2_0
like some sort ofQ2_1
, which could be almost as fast but still usable by non-ternary models, but this will probably require something like LQER to help with keeping some precision.Raw data (click to expand)
Intel Core m3-8100Y:
Arm Cortex A72 (Raspberry Pi 4):
Arm Cortex A53 (Some Android phone from 2017):
AWS
t4g.small
instance (Arm Neoverse N1) using NEON:AWS
t4g.small
(Arm Neoverse N1) with-march=native
:Size
The token embeddings are kept at
Q4_K
and the output projection atQ6_K
, which means the smaller models might be slightly bigger than 2 bits per weight.All of the TriLM models should work, because their row sizes are multiples of 256. I did not try them all yet, but those I tried are in the table below.
The BitNet b1.58 models from the 1bitLLM team however are not all compatible; only the 700M model has dimensions divisible by 256. The others are not supported (yet), unless when padding them.
Note
The 1.3B BitNet b1.58 model has a FFN size of 5460 which factors into
2 2 3 5 7 13
, which is not convenient for any block-wise types based on powers of 2, so these tensors are kept asF16
. My hypothesis is that 5460 was a typo for 5440 (factors into2 2 2 2 2 2 5 17
), but it was kept for some reason, and reproduced by the 1bitLLM team. If anyone training ternary models reads this, PLEASE DON'T USE5460
FOR THE FFN SIZE! Please use multiples of 256 for your row sizes.Perplexity
Quality seems good. I don't have a powerful machine, so my tests only include the first 16 chunks of wikitext-2-raw with https://huggingface.co/SpectraSuite/TriLM_390M_Unpacked.
The tests below use
Q4_K
token embeddings andQ6_K
output tensor forTQ1_0
andTQ2_0
, whileF16
token embeddings and output tensor is used inTQ1_0_L
andTQ2_0_L
.From this it seems like there is no significant quality loss for the ternary quants for TriLM models (I think the difference with pure
f16
comes from the 8-bit activations), and thatTQ1_0
andTQ2_0
are completely equivalent in quality (and they should be, because lossless conversion between the two is possible).Structure of
TQ1_0
This type relies on the fact that
3^5 == 243 < 256 == 2^8
.In a block of 256 elements, there are 240 elements encoded in 5 elements per byte, while the last 16 elements are encoded in 4 elements per byte.
This means
(240 / 5) + (16 / 4) == 48 + 4 == 52
bytes are used to pack 256 ternary weights (this is1.625
bits per weight).But there is also one
float16
scale per block, so the size of a block is 54 bytes making it a1.6875 bpw
type. Even though it's not ideal, this is still1.6875 / (log(3) / log(2)) ≈ 94%
of the best ternary packing efficiency.In the table below I'm describing the order of the elements within the bytes. I'm using ranges to make this shorter, with the notation
start..end
where the start is inclusive and the end is exclusive. (So0..3
is{0, 1, 2}
)Read this as if the ranges of a row are zipped together. A byte never contains more than 5 ternary values.
The ternary values are stored unsigned, so
{-1, 0, 1}
is stored as{0, 1, 2}
.0..32
0..32
32..64
64..96
96..128
128..160
32..48
160..176
176..192
192..208
208..224
224..240
48..52
240..244
244..248
248..252
252..256
And then byte
52
and53
contain thefloat16
scale in little-endian.Values are stored in fixed point to allow extracting the most significant digit first. This is explained in https://compilade.net/blog/ternary-packing.
Structure of
TQ2_0
This type was originally inspired by the
Q2_2
type made by @Eddie-Wang1120, but the block size, the order, and the mapping of the values are different.TQ2_0
started as an experiment to see how fast a 2-bit type can be compared to a 1.6-bit type on compute-bound hardware.This packs each ternary value in 2 bits, which means each byte contains 4 values.
The ternary values are stored unsigned, so
{-1, 0, 1}
is stored as{0, 1, 2}
.Again, the ranges use the
start..end
notation where the start is inclusive and the end is exclusive, and the ranges of a row should be read as being zipped together (they advance in parallel in lockstep).0..32
96..128
64..96
32..64
0..32
32..64
224..256
192..224
160..192
128..160
And then byte
64
and65
contain thefloat16
scale in little-endian.TODO
TQ1_0
andTQ2_0
convert_hf_to_gguf.py
to directly convert a ternary model to a ternary encodingf16
for the token embeddings and output tensor becauseQ4_K
andQ6_K
quantization is not yet supported bygguf-py
. This meansllama-quantize
needs to be used to quantize these tensors.llama-quantize
afterwards.TQ1_0_L
or something?float16
scale should be before or after the packed weightsQ2_K
,Q3_K
andQ6_K
) is to keep the scale before.llama-quantize
Q4_0
as a fallback type, because the smallest symmetric quant type isQ8_0
but it's a bit big, soQ4_0
it is (even though it's not ideal). Only relevant when row sizes are not multiples of 256.__ARM_FEATURE_DOTPROD
variants of the dot products ofTQ1_0
andTQ2_0
with their bare__ARM_NEON
variants to reduce code duplication.TQ1_0
andTQ2_0
for correctness on an ARM CPU which supports dot product instructionst4g.small
instance.TQ1_0
's first 48 bytes be divided in 3 sub-blocks of 16 bytes (80 elements) instead of one of 32 bytes (140 elements) and one of 16 bytes?Q1_3
andQ2_2
TQ1_0
andTQ2_0
.ggml_mul
when the broadcasted tensor only has a single element