How to pack ternary numbers in 8-bit bytes
with efficient SIMD-friendly unpacking
There are 3 possible values in a digit of a ternary number. 3 possible values, which could actually be anything.
I've recently been nerd-sniped[1] into trying to pack the ternary weights of BitNet b1.58 into something close to the theoretical ideal of log(3) / log(2) bits[2] per ternary digit.
I'll be calling a "ternary digit" a "trit", like a "binary digit" is called a "bit".
Block size
Since the goal is to allow fast parallel unpacking, blocks of trits can't be infinitely big. A small "block" size needs to be found, ideally one that is both information-dense and convenient on current hardware.
To find a good block size, we'll need to find a power of 3 for which the next power of 2 is very close.
| trits | 3^trits | bits | 2^bits | bits per trit |
|---|---|---|---|---|
| 1 | 3 | 2 | 4 | 2 |
| 2 | 9 | 4 | 16 | 2 |
| 3 | 27 | 5 | 32 | 1.666... |
| 4 | 81 | 7 | 128 | 1.75 |
| 5 | 243 | 8 | 256 | 1.6 |
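The table can be reproduced with a short script. This is a sketch: `block_sizes` is a hypothetical helper name, and it finds the number of bits as the smallest b with 2^b ≥ 3^n using `int.bit_length()`.

```python
def block_sizes(max_trits: int) -> list[tuple[int, int, int, int, float]]:
    """For each block of n trits, find the smallest bit count that holds 3**n values."""
    rows = []
    for n in range(1, max_trits + 1):
        states = 3 ** n
        bits = (states - 1).bit_length()  # smallest b such that 2**b >= states
        rows.append((n, states, bits, 2 ** bits, bits / n))
    return rows

for n, states, bits, capacity, ratio in block_sizes(5):
    print(f"{n} | {states} | {bits} | {capacity} | {ratio:.3f}")
```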
It's very fortunate that 5 trits fit quite tightly into 8 bits, at 1.6 bits per trit. Compared to perfect packing, this is 99.06% efficient.
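The efficiency figure is the ratio of the ideal bits per trit to the 8 / 5 bits per trit actually used:

```latex
\frac{\log_2 3}{8 / 5} = \frac{1.58496\ldots}{1.6} \approx 0.9906
```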
1.6 bits per trit
The basic idea with this packing scheme is simply to make a number out of the ternary digits.
```python
def pack_number(digits: list[int], base: int) -> int:
    number = 0
    for digit in digits:
        assert digit < base
        number = number * base
        number = number + digit
    return number
```
Packing trits into bytes should be similar enough.
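As a quick sanity check, here is a sketch that reuses `pack_number` (with an extra `0 <= digit` check added for this example) on every possible block of 5 trits, each trit written as 0, 1 or 2. All 243 blocks map to distinct values from 0 to 242, so each block fits in a single byte.

```python
from itertools import product

def pack_number(digits: list[int], base: int) -> int:
    # Same idea as pack_number above: most significant digit first.
    number = 0
    for digit in digits:
        assert 0 <= digit < base
        number = number * base + digit
    return number

# Every block of 5 trits, with each trit written as 0, 1 or 2.
values = [pack_number(list(block), 3) for block in product(range(3), repeat=5)]
print(len(set(values)), min(values), max(values))  # 243 0 242
print(pack_number([2, 1, 0, 1, 2], 3))  # 194
```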
Fast multiplication unpacking
While repeated division and modulo can be used to extract the digits of a number, the problem is that integer division and modulo are usually not supported in SIMD instruction sets.
A way around this is, obviously, to view the numbers differently.
Wouldn't it be nice if, instead of extracting the least significant digit with modulo, we could extract the most significant digit with a multiplication?
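For contrast, here is the straightforward scalar way to undo `pack_number`, extracting the least significant digit first with `divmod`. `unpack_number` is a hypothetical name used only for this illustration; the point is that every step needs an integer division and a modulo.

```python
def unpack_number(number: int, base: int, n_digits: int) -> list[int]:
    # Peel off the least significant digit with modulo, then divide it away.
    digits = [0] * n_digits
    for i in reversed(range(n_digits)):
        number, digits[i] = divmod(number, base)
    return digits

print(unpack_number(194, 3, 5))  # [2, 1, 0, 1, 2]
```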
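Fixed-point numbers to the rescue! Instead of storing the packed number n (where 0 ≤ n < 243) directly, store it as an 8-bit fixed-point fraction approximating n / 243, i.e. about n × 256 / 243. Read in base 3, n / 243 is 0.d₁d₂d₃d₄d₅, so multiplying it by 3 moves the most significant trit d₁ into the integer part and leaves the remaining trits in the fractional part.
Tada!
Now each digit can easily be extracted from the top two bits of the resulting 10-bit number when multiplying this 8-bit byte by 3 (the product is at most 255 × 3 = 765, which fits in 10 bits).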
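As a worked example, take the trits 2, 1, 0, 1, 2, whose packed value is n = 194. Assuming the byte holds the value rounded up to q = ⌈n × 256 / 243⌉ (the rounding used by `pack_trits` below), repeatedly multiplying by 3 and keeping the low 8 bits yields the trits as the multiples of 256, most significant first:

```latex
\begin{aligned}
q &= \left\lceil \frac{194 \cdot 256}{243} \right\rceil = \left\lceil \frac{49664}{243} \right\rceil = 205 \\
205 \cdot 3 &= 615 = 2 \cdot 256 + 103 \\
103 \cdot 3 &= 309 = 1 \cdot 256 + 53 \\
53 \cdot 3 &= 159 = 0 \cdot 256 + 159 \\
159 \cdot 3 &= 477 = 1 \cdot 256 + 221 \\
221 \cdot 3 &= 663 = 2 \cdot 256 + 151
\end{aligned}
```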
This is much more convenient than modulo when unpacking with SIMD.
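To illustrate why this maps well to SIMD, here is a hypothetical lane-wise sketch in plain Python: each step applies the same multiply, shift and AND to every byte at once, the way a single vector instruction operates on all lanes. The name `unpack_trit_planes` and the "trit plane" output layout (all first trits, then all second trits, and so on) are assumptions for illustration, not necessarily the layout used in llama.cpp.

```python
def unpack_trit_planes(packed: bytes) -> list[list[int]]:
    # Each Python list stands in for a vector register with one 16-bit lane
    # per packed byte (16 bits are enough, since 255 * 3 = 765).
    lanes = list(packed)
    planes: list[list[int]] = []
    for _ in range(5):
        products = [q * 3 for q in lanes]          # lane-wise multiply by 3
        planes.append([p >> 8 for p in products])  # lane-wise shift: trits 0, 1, 2
        lanes = [p & 0xFF for p in products]       # lane-wise AND: keep the fraction
    return planes

print(unpack_trit_planes(bytes([205, 0])))
# [[2, 0], [1, 0], [0, 0], [1, 0], [2, 0]]
```

Here `planes[k][i]` is trit k of byte i, still in the 0, 1, 2 form; subtracting 1 gives -1, 0, 1.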
The only place where divisions appear in this scheme is when packing trits into bytes. This assumes that packing is done less often than unpacking, which is very true in the context of LLM weights.
```python
# Take a list of values in -1, 0, 1 and pack them in bytes
def pack_trits(digits: list[int]) -> bytearray:
    assert len(digits) % 5 == 0  # padding isn't handled here
    n_bytes = len(digits) // 5
    packed = bytearray()
    for i in range(n_bytes):
        b = 0
        for j in range(5):
            digit = digits[5*i + j]
            digit = max(-1, min(digit, 1))  # clamp between -1 and 1
            digit += 1  # from -1, 0, 1 to 0, 1, 2
            b *= 3
            b += digit
        b = ((b * 256) + (243 - 1)) // 243
        packed.append(b)
    return packed
```
The relevant interesting line is this one:
```python
b = ((b * 256) + (243 - 1)) // 243
```
It does what is depicted in the diagram above, but the multiplication by 256 is done before the division by 243 because these are integer operations. A ceiling division is necessary here to cancel the off-by-one error caused by truncation when extracting digits later.
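To see why the ceiling matters, here is a sketch that compares both roundings over all 243 values, using the same multiply-shift-mask extraction as `unpack_trits`. The helper names (`to_fixed_point`, `unpack_byte`, `base3_digits`, `count_mismatches`) are hypothetical and only exist for this comparison.

```python
def to_fixed_point(n: int, use_ceiling: bool) -> int:
    # Scale n in 0..242 to a byte approximating n * 256 / 243.
    if use_ceiling:
        return (n * 256 + (243 - 1)) // 243
    return (n * 256) // 243

def unpack_byte(q: int) -> list[int]:
    # Multiply-shift-mask extraction, most significant trit first.
    trits = []
    for _ in range(5):
        q *= 3
        trits.append(q >> 8)
        q &= 0xFF
    return trits

def base3_digits(n: int) -> list[int]:
    # Reference digits computed with divmod, most significant first.
    digits = [0] * 5
    for i in reversed(range(5)):
        n, digits[i] = divmod(n, 3)
    return digits

def count_mismatches(use_ceiling: bool) -> int:
    return sum(
        unpack_byte(to_fixed_point(n, use_ceiling)) != base3_digits(n)
        for n in range(243)
    )

print("ceiling:", count_mismatches(use_ceiling=True))  # 0
print("floor:", count_mismatches(use_ceiling=False))
```

With the ceiling division, the scaled value is never below n / 243 and overshoots it by less than 1 / 256, which is small enough (since 243 ≤ 256) that no truncation crosses a digit boundary. With floor division, the scaled value lands slightly below n / 243 for every non-zero n, so the truncations pull at least one trit down and mismatches should appear.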
To unpack without using the modulo operator:
```python
def unpack_trits(packed: bytes) -> list[int]:
    trits: list[int] = []
    for byte in packed:
        b = byte
        for _ in range(5):
            b = b * 3
            trit = b >> 8
            trits.append(trit - 1)  # 0, 1, 2 => -1, 0, 1
            b = b & 0xFF
    return trits
```
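`pack_trits` above asserts that the number of trits is a multiple of 5. One possible way to handle other lengths, sketched here as an assumption rather than what llama.cpp does, is to pad with 0 trits and have the caller remember the original length. `pack_trits_padded` and `unpack_trits_padded` are hypothetical helpers; `pack_trits` and `unpack_trits` are compacted copies of the functions above so the block runs on its own.

```python
def pack_trits(digits: list[int]) -> bytearray:
    # Same as above: 5 trits per byte, scaled with a ceiling division.
    assert len(digits) % 5 == 0
    packed = bytearray()
    for i in range(0, len(digits), 5):
        b = 0
        for digit in digits[i:i + 5]:
            b = b * 3 + (max(-1, min(digit, 1)) + 1)
        packed.append((b * 256 + (243 - 1)) // 243)
    return packed

def unpack_trits(packed: bytes) -> list[int]:
    # Same as above: multiply by 3, take the top bits, keep the low byte.
    trits: list[int] = []
    for b in packed:
        for _ in range(5):
            b *= 3
            trits.append((b >> 8) - 1)
            b &= 0xFF
    return trits

def pack_trits_padded(digits: list[int]) -> bytearray:
    # Hypothetical helper: pad with 0 trits up to a multiple of 5.
    padding = -len(digits) % 5
    return pack_trits(digits + [0] * padding)

def unpack_trits_padded(packed: bytes, n_trits: int) -> list[int]:
    # Hypothetical helper: the caller must remember the original trit count.
    return unpack_trits(packed)[:n_trits]

weights = [1, -1, 0, 0, 1, -1, 1]  # 7 trits, padded to 10, so 2 bytes
packed = pack_trits_padded(weights)
print(len(packed), unpack_trits_padded(packed, len(weights)) == weights)  # 2 True
```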
To convince myself that this works, I wrote a C program checking that this really is lossless:
```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    char s1[6] = {0};
    char s2[6] = {0};
    for (uint8_t i = 0; i < 243; ++i) {
        uint8_t n = i;
        // Get the number representation in base 3
        // by repeatedly extracting the least significant digit with modulo
        for (int j = 5; j-- > 0;) {
            s1[j] = (n % 3) + '0';
            n /= 3;
        }
        // Turn that number into a fixed-point number smaller than 1
        uint8_t q = (((uint16_t) i) * 256 + (243 - 1)) / 243;
        // This extracts the most significant digit first
        for (int j = 0; j < 5; ++j) {
            uint16_t m = q * 3;
            s2[j] = (m >> 8) + '0';
            q = m & 0xFF;
        }
        printf("%s, %s: %s\n", s1, s2, strcmp(s1, s2) == 0 ? "\033[1;32mPASS\033[0m" : "\033[1;31mFAIL\033[0m");
    }
    return 0;
}
```
Compile and run with:
```shell
$ gcc ternary-packing.c -o ternary-packing
$ ./ternary-packing
```
And I'm getting PASS for each of the 243 five-trit numbers, which all fit in 8 bits.
This is the technique used in the ternary types in llama.cpp for TriLMs and BitNet b1.58 (pull request: https://github.com/ggerganov/llama.cpp/pull/8151), with SIMD implementations for both AVX2 and ARM NEON.
[1] Obviously referring to https://xkcd.com/356/, but the initial motivation actually started from this review comment I posted on the initial BitNet b1.58 pull request for llama.cpp.
[2] log(3) / log(2) is also known as 1.584962500721156.