How to pack ternary numbers in 8-bit bytes

with efficient SIMD-friendly unpacking



A digit of a ternary number has 3 possible values, which could actually be anything. For example, -1, 0, 1 or 0, 1, 2.

I've recently been nerd-sniped[1] into trying to pack the ternary weights of BitNet b1.58 into something close to the theoretical ideal of log(3) / log(2) bits[2] per ternary digit.

I'll be calling a "ternary digit" a "trit", like a "binary digit" is called a "bit".

Block size

Since the goal is to allow fast parallel unpacking, blocks of trits can't be arbitrarily large. A small "block" size needs to be found, ideally one that is both information-dense and convenient on current hardware.

To find a good block size, we'll need to find a power of 3 for which the next power of 2 is very close.

trits   3^trits   bits   2^bits   bits per trit
1       3         2      4        2
2       9         4      16       2
3       27        5      32       1.666...
4       81        7      128      1.75
5       243       8      256      1.6

It's very fortunate that 5 trits fit quite tightly into 8 bits, at 1.6 bits per trit. Compared to perfect packing at log(3) / log(2) ≈ 1.585 bits per trit, this is 99.06% efficient.
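The table can be reproduced with a few lines of Python. This is a minimal sketch: `block_stats` is a hypothetical helper name, and "bits" is taken to be the smallest b with 2^b >= 3^trits, computed exactly with integer `bit_length()` rather than floating-point logarithms.

```python
import math

# Ideal information content of one trit, in bits: log(3) / log(2)
IDEAL_BITS_PER_TRIT = math.log2(3)

def block_stats(n_trits: int) -> tuple[int, float, float]:
    """Return (bits needed, bits per trit, efficiency) for a block of n_trits."""
    n_values = 3 ** n_trits
    # Smallest number of bits b such that 2**b >= 3**n_trits
    bits = (n_values - 1).bit_length()
    bits_per_trit = bits / n_trits
    # Fraction of the theoretical ideal that this block size achieves
    efficiency = IDEAL_BITS_PER_TRIT / bits_per_trit
    return bits, bits_per_trit, efficiency

for n in range(1, 6):
    bits, bpt, eff = block_stats(n)
    print(f"trits={n}: {3**n:>3} values in {bits} bits ({2**bits:>3} codes), "
          f"{bpt:.3f} bits/trit, {eff:.2%} of ideal")
```

For 5 trits this gives 8 bits, 1.6 bits per trit and about 99.06% of the ideal, matching the numbers above.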

1.6 bits per trit

The basic idea with this packing scheme is simply to make a number out of the ternary digits.

def pack_number(digits: list[int], base: int) -> int:
    number = 0

    for digit in digits:
        assert 0 <= digit < base

        number = number * base
        number = number + digit

    return number

Packing trits into bytes works the same way with base 3, once the trits -1, 0, 1 are mapped to the digits 0, 1, 2.
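As a quick illustration, here is `pack_number` applied to one block of 5 trits. The helper `trits_to_digits` is a hypothetical name for the "add 1" mapping from -1, 0, 1 to 0, 1, 2; the trits are chosen so that the digits are 1, 1, 2, 0, 1, which is 127 (0x7F).

```python
def pack_number(digits: list[int], base: int) -> int:
    number = 0
    for digit in digits:
        assert 0 <= digit < base
        # Shift the previous digits up by one place, then append the new digit
        number = number * base + digit
    return number

def trits_to_digits(trits: list[int]) -> list[int]:
    """Hypothetical helper: map trits -1, 0, 1 to base-3 digits 0, 1, 2."""
    return [t + 1 for t in trits]

trits = [0, 0, 1, -1, 0]
digits = trits_to_digits(trits)   # [1, 1, 2, 0, 1]
packed = pack_number(digits, 3)   # 1*81 + 1*27 + 2*9 + 0*3 + 1 = 127
print(digits, packed, hex(packed))
```

Since the largest 5-digit base-3 number is 22222 = 242, every block fits in a byte.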

Fast multiplication unpacking

While repeated divisions and remainders can be used to extract the digits of a number, the problem is that integer division and modulo are usually not supported in SIMD instruction sets.
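For reference, this is the conventional division-based extraction, sketched in Python with a hypothetical `unpack_number` helper. It produces the least significant digit first, and each step needs a division and a remainder, which is exactly what is awkward in SIMD code.

```python
def unpack_number(number: int, base: int, n_digits: int) -> list[int]:
    """Extract n_digits digits of number using repeated division and modulo."""
    digits = []
    for _ in range(n_digits):
        # divmod gives both the quotient and the remainder (the lowest digit)
        number, digit = divmod(number, base)
        digits.append(digit)
    # Digits come out least significant first; reverse to match pack_number
    digits.reverse()
    return digits

print(unpack_number(127, 3, 5))
```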

A way around this is to view the numbers differently.

Wouldn't it be nice if, instead of extracting the least significant digit with modulo, we could extract the most significant digit with a multiplication?

Fixed point numbers to the rescue!

0x7F. (hexadecimal) is the same number as 11201. (base 3)
Divide by 243: .11201 (base 3), which rounded up is the same number as 0x0.86 (hexadecimal)
Multiply by 256: 0x86.
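The diagram's arithmetic can be checked with exact fractions. This sketch uses Python's standard `fractions` module: 127/243 is the base-3 fraction .11201, and rounding 127 * 256 / 243 (about 133.79) up gives 134, or 0x86.

```python
import math
from fractions import Fraction

n = 0x7F                     # 127, which is 11201 in base 3
exact = Fraction(n, 3**5)    # .11201 in base 3, i.e. 127/243
scaled = exact * 2**8        # the same value expressed in 256ths
q = math.ceil(scaled)        # round up to a whole number of 256ths
print(float(exact), float(scaled), q, hex(q))

# q / 256 is never below n / 243, and exceeds it by less than 1/256
assert Fraction(q, 256) >= exact
assert Fraction(q, 256) - exact < Fraction(1, 256)
```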

Tada!

Now each digit can be extracted from the top two bits of the 10-bit result of multiplying this 8-bit byte by 3 (at most 255 * 3 = 765). The low 8 bits of the product then hold the remaining digits.
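A small trace makes this concrete. The sketch below (with a hypothetical `trace_extraction` helper) starts from 0x86 and prints each 10-bit product in binary; the top two bits are the digit and the low 8 bits carry over to the next step.

```python
def trace_extraction(q: int, n_digits: int = 5) -> list[int]:
    digits = []
    for _ in range(n_digits):
        m = q * 3         # at most 255 * 3 = 765, which fits in 10 bits
        digit = m >> 8    # the top two bits of the 10-bit product
        q = m & 0xFF      # keep the low 8 bits for the next digit
        print(f"{m:010b} -> digit {digit}, remaining 0x{q:02X}")
        digits.append(digit)
    return digits

digits = trace_extraction(0x86)
```

The digits come out most significant first as 1, 1, 2, 0, 1, which is 11201 in base 3, the number we started from.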

This is much more convenient than modulo when unpacking with SIMD.
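To show why this suits SIMD, here is a plain-Python simulation where each list comprehension stands in for one vector instruction applied to every byte ("lane") at once. This is only an illustration of the data flow, not the AVX2 or NEON code from llama.cpp. Since the product needs 10 bits, a real implementation has to compute it in wider lanes or otherwise keep its high part; how the actual intrinsics do that isn't shown here.

```python
def unpack_lanes(packed: list[int]) -> list[list[int]]:
    """Simulate SIMD unpacking: every step applies the same multiply, shift
    and mask to all lanes at once, with no division or modulo."""
    lanes = list(packed)
    planes = []  # planes[k][i] is the k-th digit (0, 1, 2) of byte i
    for _ in range(5):
        products = [x * 3 for x in lanes]          # one vector multiply
        planes.append([p >> 8 for p in products])  # one vector shift
        lanes = [p & 0xFF for p in products]       # one vector AND
    return planes

planes = unpack_lanes([0x86, 0x00, 0xFF])
for k, plane in enumerate(planes):
    print(f"digit {k}: {plane}")
```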

The only place where divisions are needed in this scheme is when packing trits into bytes. This assumes that packing is done less often than unpacking, which is very true in the context of LLM weights.

# Take a list of values in -1, 0, 1 and pack them in bytes
def pack_trits(digits: list[int]) -> bytearray:
    assert len(digits) % 5 == 0  # padding isn't handled here

    n_bytes = len(digits) // 5
    packed = bytearray()

    for i in range(n_bytes):
        b = 0
        for j in range(5):
            digit = digits[5*i + j]
            digit = max(-1, min(digit, 1))  # clamp between -1 and 1
            digit += 1  # from -1, 0, 1 to 0, 1, 2
            b *= 3
            b += digit

        b = ((b * 256) + (243 - 1)) // 243

        packed.append(b)

    return packed

The relevant interesting line is this one:

        b = ((b * 256) + (243 - 1)) // 243

It does what the diagram above depicts, but multiplies before dividing because these are integer operations: dividing by 243 first would truncate every value to 0. The ceiling division is necessary to cancel the off-by-one error caused by truncation when the digits are extracted later.
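To see why the rounding direction matters, this sketch compares floor and ceiling division over all 243 possible blocks. The helper names are hypothetical; `base3_digits` is a modulo-based reference, and `extract_digits` is the multiply/shift/mask loop from `unpack_trits`. With floor division, n = 1 already unpacks as 00000 instead of 00001.

```python
def to_fixed_point(n: int, round_up: bool) -> int:
    """Scale a 5-trit number n (0..242) from 243ths to 256ths."""
    if round_up:
        return (n * 256 + (243 - 1)) // 243  # ceiling division
    return (n * 256) // 243                  # floor (truncating) division

def extract_digits(q: int) -> list[int]:
    digits = []
    for _ in range(5):
        q *= 3
        digits.append(q >> 8)
        q &= 0xFF
    return digits

def base3_digits(n: int) -> list[int]:
    # Most significant digit first, computed with division and modulo
    return [(n // 3**k) % 3 for k in range(4, -1, -1)]

floor_failures = [n for n in range(243)
                  if extract_digits(to_fixed_point(n, False)) != base3_digits(n)]
ceil_failures = [n for n in range(243)
                 if extract_digits(to_fixed_point(n, True)) != base3_digits(n)]
print("floor failures:", len(floor_failures), "ceil failures:", len(ceil_failures))
```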

To unpack without using the modulo operator:

def unpack_trits(packed: bytes) -> list[int]:
    trits: list[int] = []

    for byte in packed:
        b = byte
        for i in range(5):
            b = b * 3
            trit = b >> 8
            trits.append(trit - 1)  # 0, 1, 2 => -1, 0, 1
            b = b & 0xFF

    return trits
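Since `pack_trits` asserts that the length is a multiple of 5, one possible way to handle other lengths is to pad with zero trits and remember the original count. The sketch below repeats compact versions of `pack_trits` and `unpack_trits` so it runs on its own; `pack_trits_padded` and `unpack_trits_padded` are hypothetical helpers, and padding with 0 is an assumption, not something the scheme above prescribes.

```python
import random

def pack_trits(digits: list[int]) -> bytearray:
    # Same scheme as above: base-3 number, then ceiling-scaled to 256ths
    assert len(digits) % 5 == 0
    packed = bytearray()
    for i in range(0, len(digits), 5):
        b = 0
        for digit in digits[i:i + 5]:
            b = b * 3 + (max(-1, min(digit, 1)) + 1)
        packed.append((b * 256 + (243 - 1)) // 243)
    return packed

def unpack_trits(packed: bytes) -> list[int]:
    trits = []
    for b in packed:
        for _ in range(5):
            b *= 3
            trits.append((b >> 8) - 1)
            b &= 0xFF
    return trits

def pack_trits_padded(trits: list[int]) -> tuple[bytearray, int]:
    """Hypothetical helper: pad with 0 trits to a multiple of 5."""
    padding = -len(trits) % 5
    return pack_trits(trits + [0] * padding), len(trits)

def unpack_trits_padded(packed: bytes, count: int) -> list[int]:
    """Hypothetical helper: unpack, then drop the padding trits."""
    return unpack_trits(packed)[:count]

rng = random.Random(0)
trits = [rng.choice([-1, 0, 1]) for _ in range(23)]
packed, count = pack_trits_padded(trits)
print(len(trits), "trits ->", len(packed), "bytes")
```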

To convince myself that this works, I wrote a C program checking that this really is lossless:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    char s1[6] = {0};
    char s2[6] = {0};

    for (uint8_t i = 0; i < 243; ++i) {
        uint8_t n = i;
        // Get the number representation in base 3
        // by repeatedly extracting the least significant digit with modulo
        for (int j = 5; j-- > 0;) {
            s1[j] = (n % 3) + '0';
            n /= 3;
        }
        // Turn that number into a fixed-point number smaller than 1
        uint8_t q = (((uint16_t) i) * 256 + (243 - 1)) / 243;
        // This extracts the most significant digit first
        for (int j = 0; j < 5; ++j) {
            uint16_t m = q * 3;
            s2[j] = (m >> 8) + '0';
            q = m & 0xFF;
        }
        printf("%s, %s: %s\n", s1, s2, strcmp(s1, s2) == 0 ? "\033[1;32mPASS\033[0m" : "\033[1;31mFAIL\033[0m");
    }

    return 0;
}

Compile and run with:

$ gcc ternary-packing.c -o ternary-packing
$ ./ternary-packing

And I'm getting PASS for each of the 243 ternary numbers which fit in 8 bits.
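For anyone who prefers Python, here is a sketch of the same exhaustive check. It also asserts two properties the C program relies on implicitly: the largest block (242) still maps to a value that fits in a byte, and no two blocks map to the same byte. The helper names are hypothetical.

```python
def to_byte(n: int) -> int:
    # Ceiling division: n/243 scaled to 256ths, as in pack_trits
    return (n * 256 + (243 - 1)) // 243

def from_byte(q: int) -> list[int]:
    # Most significant digit first, with multiply, shift and mask only
    digits = []
    for _ in range(5):
        q *= 3
        digits.append(q >> 8)
        q &= 0xFF
    return digits

def base3(n: int) -> list[int]:
    # Reference: least significant digit first via divmod, then reversed
    digits = []
    for _ in range(5):
        n, d = divmod(n, 3)
        digits.append(d)
    return digits[::-1]

packed_values = [to_byte(n) for n in range(243)]
assert max(packed_values) == 255        # fits in 8 bits
assert len(set(packed_values)) == 243   # no two blocks collide
assert all(from_byte(to_byte(n)) == base3(n) for n in range(243))
print("all 243 blocks round-trip")
```

Only 243 of the 256 byte values are ever produced, which is where the remaining ~1% of inefficiency goes.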

This is the technique used in the upcoming 1.625 bpw quant type for BitNet b1.58 in llama.cpp, whose pull request (https://github.com/ggerganov/llama.cpp/pull/8151) includes SIMD implementations for both AVX2 and ARM NEON.


  1. Obviously referring to https://xkcd.com/356/, but the initial motivation actually came from this review comment I posted on the initial BitNet b1.58 pull request for llama.cpp.

  2. log(3) / log(2) is also known as 1.584962500721156