Forked from veritatisquaesitoressumus/ggml_turboquant.c
Created
April 20, 2026 22:22
-
-
Save renesugar/de24b851e8b263851b9c3a8f0adbbea1 to your computer and use it in GitHub Desktop.
TurboQuant KV Cache Compression for llama.cpp (Zandieh et al., ICLR 2026) — 3-bit, 4.9x compression, 18/18 tests passing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* | |
| * TurboQuant: CPU Reference Implementation | |
| * ========================================== | |
| * Implements Algorithm 1 (TurboQuant_mse) from Zandieh et al., ICLR 2026. | |
| * | |
| * This is the portable C implementation that runs on CPU. | |
| * The CUDA implementation (ggml_turboquant.cu) mirrors this logic | |
| * with GPU-optimized kernels. | |
| * | |
| * Authors: Jim Sullivan / Claude collaboration | |
| * Date: 2026-03-25 | |
| */ | |
| #include "ggml_turboquant.h" | |
| #include <math.h> | |
| #include <string.h> | |
| #include <float.h> | |
| /* ========================================================================= | |
| * Section 1: Seeded RNG (xoshiro256** for reproducible rotation matrices) | |
| * | |
| * We need a portable, high-quality PRNG that produces identical sequences | |
| * across platforms given the same seed. xoshiro256** fits perfectly. | |
| * ========================================================================= */ | |
/* xoshiro256** PRNG state: four 64-bit words, expanded from a single
 * seed by tq_rng_seed (SplitMix64). Must never be all-zero. */
typedef struct {
    uint64_t s[4];   /* generator state words */
} tq_rng;
/*
 * Rotate x left by k bits (k in [0, 63]).
 *
 * Fix: masking the right-shift count with 63 keeps the expression defined
 * for k == 0 — the previous `x >> (64 - k)` shifts by the full type width,
 * which is undefined behavior in C. For k in 1..63 the result is identical
 * to the classic (x << k) | (x >> (64 - k)).
 */
static inline uint64_t tq_rng_rotl(uint64_t x, int k) {
    return (x << k) | (x >> ((64 - k) & 63));
}
/*
 * Advance the generator one step and return a 64-bit output.
 * This is the reference xoshiro256** update (Blackman & Vigna); the exact
 * statement order is load-bearing and must not be rearranged.
 */
static uint64_t tq_rng_next(tq_rng * rng) {
    /* Output scrambler: rotl(s[1] * 5, 7) * 9 — the "**" in xoshiro256**. */
    const uint64_t result = tq_rng_rotl(rng->s[1] * 5, 7) * 9;
    const uint64_t t = rng->s[1] << 17;
    /* Linear xor-shift state transition. */
    rng->s[2] ^= rng->s[0];
    rng->s[3] ^= rng->s[1];
    rng->s[1] ^= rng->s[2];
    rng->s[0] ^= rng->s[3];
    rng->s[2] ^= t;
    rng->s[3] = tq_rng_rotl(rng->s[3], 45);
    return result;
}
| static void tq_rng_seed(tq_rng * rng, uint64_t seed) { | |
| /* SplitMix64 to expand a single seed into 4 state words */ | |
| for (int i = 0; i < 4; i++) { | |
| seed += 0x9e3779b97f4a7c15ULL; | |
| uint64_t z = seed; | |
| z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ULL; | |
| z = (z ^ (z >> 27)) * 0x94d049bb133111ebULL; | |
| rng->s[i] = z ^ (z >> 31); | |
| } | |
| } | |
| /* Convert uint64 to standard normal via Box-Muller */ | |
| static float tq_rng_normal(tq_rng * rng) { | |
| /* Generate two uniform [0,1) values */ | |
| double u1 = (double)(tq_rng_next(rng) >> 11) / (double)(1ULL << 53); | |
| double u2 = (double)(tq_rng_next(rng) >> 11) / (double)(1ULL << 53); | |
| /* Clamp to avoid log(0) */ | |
| if (u1 < 1e-15) u1 = 1e-15; | |
| return (float)(sqrt(-2.0 * log(u1)) * cos(2.0 * 3.14159265358979323846 * u2)); | |
| } | |
| /* ========================================================================= | |
| * Section 2: QR Decomposition (Modified Gram-Schmidt) | |
| * | |
| * Generates the random orthogonal rotation matrix Π from Algorithm 1, Line 2. | |
| * We fill a d×d matrix with standard normals, then orthogonalize via | |
| * modified Gram-Schmidt. The result is stored row-major in ctx->rotation. | |
| * ========================================================================= */ | |
| static void tq_generate_rotation(float * R, int d, uint64_t seed) { | |
| tq_rng rng; | |
| tq_rng_seed(&rng, seed); | |
| /* Fill with standard normals (row-major: R[i*d + j]) */ | |
| for (int i = 0; i < d * d; i++) { | |
| R[i] = tq_rng_normal(&rng); | |
| } | |
| /* Modified Gram-Schmidt orthogonalization (column-wise) */ | |
| /* Column j of R is at R[row*d + j] for row in [0, d) */ | |
| for (int j = 0; j < d; j++) { | |
| /* Subtract projections of column j onto all previous columns */ | |
| for (int k = 0; k < j; k++) { | |
| float dot = 0.0f; | |
| float norm_k_sq = 0.0f; | |
| for (int i = 0; i < d; i++) { | |
| dot += R[i * d + j] * R[i * d + k]; | |
| norm_k_sq += R[i * d + k] * R[i * d + k]; | |
| } | |
| if (norm_k_sq > 1e-15f) { | |
| float scale = dot / norm_k_sq; | |
| for (int i = 0; i < d; i++) { | |
| R[i * d + j] -= scale * R[i * d + k]; | |
| } | |
| } | |
| } | |
| /* Normalize column j */ | |
| float norm = 0.0f; | |
| for (int i = 0; i < d; i++) { | |
| norm += R[i * d + j] * R[i * d + j]; | |
| } | |
| norm = sqrtf(norm); | |
| if (norm > 1e-15f) { | |
| float inv_norm = 1.0f / norm; | |
| for (int i = 0; i < d; i++) { | |
| R[i * d + j] *= inv_norm; | |
| } | |
| } | |
| } | |
| } | |
| /* ========================================================================= | |
| * Section 3: Context Initialization | |
| * ========================================================================= */ | |
| int tq_context_init(tq_context * ctx, int bits, uint64_t seed) { | |
| if (!ctx) return -1; | |
| if (bits < 2 || bits > 4) return -1; | |
| ctx->d = TQ_HEAD_DIM; | |
| ctx->bits = bits; | |
| ctx->n_levels = 1 << bits; | |
| /* Select codebook */ | |
| switch (bits) { | |
| case 2: ctx->codebook = TQ_CODEBOOK_2; break; | |
| case 3: ctx->codebook = TQ_CODEBOOK_3; break; | |
| case 4: ctx->codebook = TQ_CODEBOOK_4; break; | |
| default: return -1; | |
| } | |
| /* Generate rotation matrix */ | |
| tq_generate_rotation(ctx->rotation, ctx->d, seed); | |
| return 0; | |
| } | |
| /* ========================================================================= | |
| * Section 4: Bit-packing | |
| * | |
| * Pack/unpack arrays of b-bit values into/from byte arrays. | |
| * For b=3: 128 values × 3 bits = 384 bits = 48 bytes | |
| * For b=4: 128 values × 4 bits = 512 bits = 64 bytes | |
| * ========================================================================= */ | |
/*
 * Bit-pack n_values codebook indices at `bits` bits each into `packed`.
 * Bits are laid out sequentially, LSB-first within each value and within
 * each output byte. The output buffer must hold at least
 * (n_values * bits + 7) / 8 bytes; it is cleared before packing.
 */
void tq_pack_indices(const uint8_t * indices, uint8_t * packed,
                     int n_values, int bits) {
    const int total_bits = n_values * bits;
    memset(packed, 0, (total_bits + 7) / 8);
    for (int i = 0; i < n_values; i++) {
        const int base = i * bits;
        for (int b = 0; b < bits; b++) {
            if ((indices[i] >> b) & 1u) {
                const int pos = base + b;
                packed[pos >> 3] |= (uint8_t)(1u << (pos & 7));
            }
        }
    }
}
/*
 * Inverse of tq_pack_indices: extract n_values indices of `bits` bits
 * each from the LSB-first packed stream into `indices`.
 */
void tq_unpack_indices(const uint8_t * packed, uint8_t * indices,
                       int n_values, int bits) {
    const uint8_t mask = (uint8_t)((1 << bits) - 1);
    int pos = 0;
    for (int i = 0; i < n_values; i++) {
        uint8_t value = 0;
        for (int b = 0; b < bits; b++, pos++) {
            value |= (uint8_t)(((packed[pos >> 3] >> (pos & 7)) & 1u) << b);
        }
        indices[i] = value & mask;
    }
}
| /* ========================================================================= | |
| * Section 5: Quantize — Algorithm 1 (TurboQuant_mse) | |
| * ========================================================================= */ | |
| void tq_quantize(const tq_context * ctx, const float * src, void * dst) { | |
| const int d = ctx->d; | |
| const int bits = ctx->bits; | |
| const int n_levels = ctx->n_levels; | |
| const float * codebook = ctx->codebook; | |
| const float * R = ctx->rotation; | |
| /* Step 1: Compute L2 norm */ | |
| float norm = 0.0f; | |
| for (int i = 0; i < d; i++) { | |
| norm += src[i] * src[i]; | |
| } | |
| norm = sqrtf(norm); | |
| /* Temporary buffers (stack-allocated, d=128 so this is fine) */ | |
| float y[TQ_HEAD_DIM]; /* Rotated vector */ | |
| uint8_t indices[TQ_HEAD_DIM]; /* Codebook indices */ | |
| if (norm < 1e-15f) { | |
| /* Zero vector — store zero norm and zero indices */ | |
| if (bits == 3) { | |
| block_tq3 * blk = (block_tq3 *)dst; | |
| blk->norm = 0.0f; | |
| memset(blk->indices, 0, TQ3_INDEX_BYTES); | |
| } else if (bits == 4) { | |
| block_tq4 * blk = (block_tq4 *)dst; | |
| blk->norm = 0.0f; | |
| memset(blk->indices, 0, TQ4_INDEX_BYTES); | |
| } | |
| return; | |
| } | |
| float inv_norm = 1.0f / norm; | |
| /* Step 2-3: Normalize and rotate: y = Π · (x / ||x||) */ | |
| /* R is row-major: R[i*d + j] = element (i,j) */ | |
| /* y[i] = sum_j R[i*d + j] * x_unit[j] */ | |
| for (int i = 0; i < d; i++) { | |
| float sum = 0.0f; | |
| const float * row = R + i * d; | |
| for (int j = 0; j < d; j++) { | |
| sum += row[j] * src[j] * inv_norm; | |
| } | |
| y[i] = sum; | |
| } | |
| /* Step 4: Find nearest codebook centroid for each coordinate */ | |
| for (int i = 0; i < d; i++) { | |
| float best_dist = FLT_MAX; | |
| uint8_t best_idx = 0; | |
| for (int c = 0; c < n_levels; c++) { | |
| float dist = (y[i] - codebook[c]) * (y[i] - codebook[c]); | |
| if (dist < best_dist) { | |
| best_dist = dist; | |
| best_idx = (uint8_t)c; | |
| } | |
| } | |
| indices[i] = best_idx; | |
| } | |
| /* Step 5: Pack into output block */ | |
| if (bits == 3) { | |
| block_tq3 * blk = (block_tq3 *)dst; | |
| blk->norm = norm; | |
| tq_pack_indices(indices, blk->indices, d, bits); | |
| } else if (bits == 4) { | |
| block_tq4 * blk = (block_tq4 *)dst; | |
| blk->norm = norm; | |
| tq_pack_indices(indices, blk->indices, d, bits); | |
| } | |
| } | |
| /* ========================================================================= | |
| * Section 6: Dequantize — Algorithm 1 Inverse | |
| * ========================================================================= */ | |
| void tq_dequantize(const tq_context * ctx, const void * src, float * dst) { | |
| const int d = ctx->d; | |
| const int bits = ctx->bits; | |
| const float * codebook = ctx->codebook; | |
| const float * R = ctx->rotation; | |
| float norm; | |
| const uint8_t * packed; | |
| if (bits == 3) { | |
| const block_tq3 * blk = (const block_tq3 *)src; | |
| norm = blk->norm; | |
| packed = blk->indices; | |
| } else if (bits == 4) { | |
| const block_tq4 * blk = (const block_tq4 *)src; | |
| norm = blk->norm; | |
| packed = blk->indices; | |
| } else { | |
| memset(dst, 0, d * sizeof(float)); | |
| return; | |
| } | |
| if (fabsf(norm) < 1e-15f) { | |
| memset(dst, 0, d * sizeof(float)); | |
| return; | |
| } | |
| /* Step 1: Unpack indices */ | |
| uint8_t indices[TQ_HEAD_DIM]; | |
| tq_unpack_indices(packed, indices, d, bits); | |
| /* Step 2: Map indices to centroid values -> y_hat */ | |
| float y_hat[TQ_HEAD_DIM]; | |
| for (int i = 0; i < d; i++) { | |
| y_hat[i] = codebook[indices[i]]; | |
| } | |
| /* Step 3: Rotate back: x_hat = Π^T · y_hat */ | |
| /* Π^T row i, col j = R[j*d + i] (transpose of row-major R) */ | |
| for (int i = 0; i < d; i++) { | |
| float sum = 0.0f; | |
| for (int j = 0; j < d; j++) { | |
| sum += R[j * d + i] * y_hat[j]; | |
| } | |
| /* Step 4: Scale by original norm */ | |
| dst[i] = sum * norm; | |
| } | |
| } | |
| /* ========================================================================= | |
| * Section 7: Batch Operations | |
| * ========================================================================= */ | |
| void tq_quantize_batch(const tq_context * ctx, const float * src, | |
| void * dst, int n_vectors) { | |
| const int d = ctx->d; | |
| size_t blk_size = tq_block_size(ctx->bits); | |
| for (int v = 0; v < n_vectors; v++) { | |
| tq_quantize(ctx, src + v * d, (uint8_t *)dst + v * blk_size); | |
| } | |
| } | |
| void tq_dequantize_batch(const tq_context * ctx, const void * src, | |
| float * dst, int n_vectors) { | |
| const int d = ctx->d; | |
| size_t blk_size = tq_block_size(ctx->bits); | |
| for (int v = 0; v < n_vectors; v++) { | |
| tq_dequantize(ctx, (const uint8_t *)src + v * blk_size, dst + v * d); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* | |
| * TurboQuant: CUDA Kernel Implementation | |
| * ======================================== | |
| * GPU-accelerated quantize/dequantize for KV cache vectors. | |
| * | |
| * Architecture notes for RTX 3090 (SM 8.6, Ampere): | |
| * - 128 threads per block = 4 warps = good occupancy | |
| * - Each thread handles one coordinate of the head_dim=128 vector | |
| * - Rotation matrix loaded into shared memory (64 KB fits easily) | |
| * - Codebook in constant memory (max 64 floats = 256 bytes) | |
| * | |
| * The key insight: TurboQuant's rotation step is a matrix-vector multiply | |
| * (Π · x), which maps naturally to one thread per output element with | |
| * shared-memory reduction. For d=128 this is one thread per coordinate. | |
| * | |
| * Authors: Jim Sullivan / Claude collaboration | |
| * Date: 2026-03-25 | |
| */ | |
| #include "ggml_turboquant.h" | |
| #include <cuda_runtime.h> | |
| #include <cuda_fp16.h> | |
| /* ========================================================================= | |
| * Section 1: Constant Memory — Codebooks | |
| * | |
| * Codebooks are tiny (max 16 floats) and read-only, so constant memory | |
| * gives broadcast reads across all threads in a warp. | |
| * ========================================================================= */ | |
| __constant__ float d_codebook_3[8]; | |
| __constant__ float d_codebook_4[16]; | |
| /* ========================================================================= | |
| * Section 2: Device Helper — Find Nearest Centroid | |
| * ========================================================================= */ | |
/*
 * Return the index of the codebook centroid closest (in squared distance)
 * to `val`. The codebooks have at most 16 entries, so a straight linear
 * scan is used; on GPU this avoids the divergent branching a binary
 * search would introduce.
 */
__device__ __forceinline__
uint8_t tq_find_nearest(float val, const float * codebook, int n_levels) {
    uint8_t best_idx = 0;
    float best_dist = 1e30f;
    for (int c = 0; c < n_levels; c++) {
        const float diff = val - codebook[c];
        const float dist = diff * diff;
        if (dist < best_dist) {
            best_dist = dist;
            best_idx = (uint8_t)c;
        }
    }
    return best_idx;
}
| /* ========================================================================= | |
| * Section 3: Quantize Kernel | |
| * | |
| * One block per vector (token×head). 128 threads per block (one per dim). | |
| * | |
| * Workflow: | |
| * 1. Load input vector into shared memory | |
| * 2. Compute L2 norm via warp reduction | |
| * 3. Each thread computes one element of y = Π · (x/||x||) | |
| * by reading its row of Π from global memory | |
| * 4. Find nearest codebook centroid | |
| * 5. Bit-pack indices cooperatively | |
| * 6. Write output block | |
| * ========================================================================= */ | |
/*
 * 3-bit quantize kernel: one thread block per vector, one thread per
 * coordinate (launch with blockDim.x == TQ_HEAD_DIM == 128).
 *
 * Steps: stage input in shared memory → reduce squared norm → each
 * thread computes one rotated coordinate → nearest-centroid lookup →
 * cooperative 3-bit packing → thread 0 writes the block.
 *
 * Fix vs. previous version: s_packed is now declared __align__(4). The
 * packing step ORs bits in through `unsigned int *` atomics, which
 * require 4-byte-aligned addresses; a bare uint8_t shared array is only
 * guaranteed 1-byte alignment, so the old code risked misaligned atomic
 * access (the issue the old "Alternative:" comment alluded to).
 */
__global__ void tq_quantize_kernel_tq3(
    const float * __restrict__ src,      /* [n_vectors × 128] input */
    void * __restrict__ dst,             /* [n_vectors × block_tq3] out */
    const float * __restrict__ rotation, /* [128 × 128] rotation matrix */
    int n_vectors
) {
    const int vec_idx = blockIdx.x;
    if (vec_idx >= n_vectors) return;
    const int tid = threadIdx.x;   /* 0..127, one per coordinate */
    const int d = TQ_HEAD_DIM;     /* 128 */

    __shared__ float s_input[TQ_HEAD_DIM];
    __shared__ float s_norm_sq;
    __shared__ uint8_t s_indices[TQ_HEAD_DIM];
    /* 4-byte aligned so the 32-bit atomicOr below is legal. */
    __shared__ __align__(4) uint8_t s_packed[TQ3_INDEX_BYTES];

    /* Step 1: stage the input vector. */
    s_input[tid] = src[vec_idx * d + tid];
    __syncthreads();

    /* Step 2: squared-norm reduction — shuffle within each warp, then
     * thread 0 sums the four warp partials. */
    float val_sq = s_input[tid] * s_input[tid];
    for (int offset = 16; offset > 0; offset >>= 1) {
        val_sq += __shfl_down_sync(0xFFFFFFFF, val_sq, offset);
    }
    __shared__ float s_warp_sums[4];   /* 128 threads / 32 = 4 warps */
    if (tid % 32 == 0) {
        s_warp_sums[tid / 32] = val_sq;
    }
    __syncthreads();
    if (tid == 0) {
        s_norm_sq = s_warp_sums[0] + s_warp_sums[1] +
                    s_warp_sums[2] + s_warp_sums[3];
    }
    __syncthreads();
    float norm = sqrtf(s_norm_sq);

    /* Zero vector: emit a zero block. All threads read the same shared
     * s_norm_sq, so this early return is convergent (no thread is left
     * waiting at a later barrier). */
    if (norm < 1e-15f) {
        if (tid == 0) {
            block_tq3 * blk = (block_tq3 *)((uint8_t *)dst +
                                            vec_idx * sizeof(block_tq3));
            blk->norm = 0.0f;
            memset(blk->indices, 0, TQ3_INDEX_BYTES);
        }
        return;
    }
    float inv_norm = 1.0f / norm;

    /* Step 3: y[tid] = (row tid of Π) · (x / ||x||). */
    float y_val = 0.0f;
    const float * my_row = rotation + tid * d;
    for (int j = 0; j < d; j++) {
        y_val += my_row[j] * s_input[j] * inv_norm;
    }

    /* Step 4: nearest codebook centroid. */
    s_indices[tid] = tq_find_nearest(y_val, d_codebook_3, 8);
    __syncthreads();

    /* Step 5: cooperative 3-bit packing. Thread `tid` owns bit positions
     * [3*tid, 3*tid + 3). Each set bit is ORed into the aligned 32-bit
     * word containing it; on a little-endian word, bit (pos % 32) of
     * that word is exactly bit (pos % 8) of byte (pos / 8). */
    if (tid == 0) {
        for (int i = 0; i < TQ3_INDEX_BYTES; i++) s_packed[i] = 0;
    }
    __syncthreads();
    {
        int bit_start = tid * 3;
        uint8_t val = s_indices[tid];
        for (int b = 0; b < 3; b++) {
            int bit_pos = bit_start + b;
            if (val & (1 << b)) {
                atomicOr((unsigned int *)(s_packed + (bit_pos / 8) - (bit_pos / 8) % 4),
                         (unsigned int)(1 << (bit_pos % 32)));
            }
        }
    }
    __syncthreads();

    /* Step 6: thread 0 writes the finished block. */
    if (tid == 0) {
        block_tq3 * blk = (block_tq3 *)((uint8_t *)dst +
                                        vec_idx * sizeof(block_tq3));
        blk->norm = norm;
        for (int i = 0; i < TQ3_INDEX_BYTES; i++) {
            blk->indices[i] = s_packed[i];
        }
    }
}
| /* ========================================================================= | |
| * Section 4: Dequantize Kernel | |
| * | |
| * Critical path for flash attention: KV cache read → dequantize → attention. | |
| * Must be as fast as possible. | |
| * | |
| * Workflow: | |
| * 1. Thread 0 loads norm and packed indices | |
| * 2. Each thread unpacks its own index | |
| * 3. Each thread looks up codebook centroid → y_hat[tid] | |
| * 4. Each thread computes x_hat[tid] = (Π^T · y_hat)[tid] | |
| * = sum_j Π[j][tid] * y_hat[j] | |
| * 5. Scale by norm and write output | |
| * ========================================================================= */ | |
/*
 * 3-bit dequantize kernel: one thread block per vector, one thread per
 * coordinate (launch with blockDim.x == TQ_HEAD_DIM == 128).
 * Thread 0 stages the block, every thread unpacks its own index, looks
 * up its centroid, computes one element of Π^T · y_hat, and writes the
 * norm-scaled result.
 */
__global__ void tq_dequantize_kernel_tq3(
    const void * __restrict__ src,       /* [n_vectors × block_tq3] in */
    float * __restrict__ dst,            /* [n_vectors × 128] output */
    const float * __restrict__ rotation, /* [128 × 128] rotation matrix */
    int n_vectors
) {
    const int vec_idx = blockIdx.x;
    if (vec_idx >= n_vectors) return;
    const int tid = threadIdx.x;
    const int d = TQ_HEAD_DIM;
    __shared__ float s_y_hat[TQ_HEAD_DIM];
    __shared__ uint8_t s_packed[TQ3_INDEX_BYTES];
    __shared__ float s_norm;
    /* Step 1: thread 0 stages the block (norm + packed index bytes). */
    const block_tq3 * blk = (const block_tq3 *)((const uint8_t *)src +
                                                vec_idx * sizeof(block_tq3));
    if (tid == 0) {
        s_norm = blk->norm;
        for (int i = 0; i < TQ3_INDEX_BYTES; i++) {
            s_packed[i] = blk->indices[i];
        }
    }
    __syncthreads();
    /* Zero vector: every thread reads the same shared s_norm, so this
     * early return is convergent (no thread left waiting at a barrier). */
    if (fabsf(s_norm) < 1e-15f) {
        dst[vec_idx * d + tid] = 0.0f;
        return;
    }
    /* Step 2: each thread extracts its own 3-bit index, LSB-first, from
     * bit positions [3*tid, 3*tid + 3) of the packed stream. */
    uint8_t my_idx;
    {
        int bit_start = tid * 3;
        my_idx = 0;
        for (int b = 0; b < 3; b++) {
            int bit_pos = bit_start + b;
            if (s_packed[bit_pos / 8] & (1 << (bit_pos % 8))) {
                my_idx |= (1 << b);
            }
        }
    }
    /* Step 3: index -> centroid via constant memory (broadcast read). */
    s_y_hat[tid] = d_codebook_3[my_idx];
    __syncthreads();
    /* Step 4: x_hat[tid] = sum_j Π[j][tid] * y_hat[j]  (i.e. Π^T · y_hat). */
    float x_val = 0.0f;
    for (int j = 0; j < d; j++) {
        x_val += rotation[j * d + tid] * s_y_hat[j];
    }
    /* Step 5: restore the original magnitude and write out. */
    dst[vec_idx * d + tid] = x_val * s_norm;
}
| /* ========================================================================= | |
| * Section 5: Host-side Launch Wrappers | |
| * ========================================================================= */ | |
| /* Initialize constant memory with codebooks (call once at startup) */ | |
/* Copy the host Lloyd-Max codebooks into device constant memory.
 * Call once at startup, before launching any quantize/dequantize kernel.
 * NOTE(review): the cudaMemcpyToSymbol return codes are ignored — a
 * failed copy would silently leave the device codebooks zeroed; consider
 * propagating cudaError_t to the caller. */
extern "C"
void tq_cuda_init_codebooks(void) {
    cudaMemcpyToSymbol(d_codebook_3, TQ_CODEBOOK_3,
                       8 * sizeof(float), 0, cudaMemcpyHostToDevice);
    cudaMemcpyToSymbol(d_codebook_4, TQ_CODEBOOK_4,
                       16 * sizeof(float), 0, cudaMemcpyHostToDevice);
}
| /* Quantize n_vectors on GPU */ | |
/* Launch the 3-bit quantize kernel on `stream`: one 128-thread block per
 * vector. All pointers are device pointers. No-op for n_vectors <= 0
 * (avoids an invalid zero-sized grid). */
extern "C"
void tq_cuda_quantize_tq3(
    const float * d_src,      /* Device: [n_vectors × 128] */
    void * d_dst,             /* Device: [n_vectors × sizeof(block_tq3)] */
    const float * d_rotation, /* Device: [128 × 128] rotation matrix */
    int n_vectors,
    cudaStream_t stream
) {
    if (n_vectors <= 0) {
        return;
    }
    const dim3 grid(n_vectors);
    const dim3 block(TQ_HEAD_DIM);
    tq_quantize_kernel_tq3<<<grid, block, 0, stream>>>(
        d_src, d_dst, d_rotation, n_vectors);
}
| /* Dequantize n_vectors on GPU */ | |
/* Launch the 3-bit dequantize kernel on `stream`: one 128-thread block
 * per vector. All pointers are device pointers. No-op for
 * n_vectors <= 0 (avoids an invalid zero-sized grid). */
extern "C"
void tq_cuda_dequantize_tq3(
    const void * d_src,
    float * d_dst,
    const float * d_rotation,
    int n_vectors,
    cudaStream_t stream
) {
    if (n_vectors <= 0) {
        return;
    }
    const dim3 grid(n_vectors);
    const dim3 block(TQ_HEAD_DIM);
    tq_dequantize_kernel_tq3<<<grid, block, 0, stream>>>(
        d_src, d_dst, d_rotation, n_vectors);
}
| /* ========================================================================= | |
| * Section 6: Flash Attention Integration Kernel (Fused Dequantize + Dot) | |
| * | |
| * For maximum performance, instead of dequantizing KV cache vectors to | |
| * FP16 and then running flash attention, we can fuse the dequantize | |
| * directly into the attention dot product. | |
| * | |
| * This computes: dot(Q_vec, dequant(KV_block)) without materializing | |
| * the full dequantized vector in memory. | |
| * | |
| * The math: Q · x_hat = Q · (||x|| · Π^T · y_hat) | |
| * = ||x|| · (Q · Π^T) · y_hat | |
| * = ||x|| · Q_rotated · y_hat | |
| * | |
| * Where Q_rotated = Q · Π^T can be precomputed once per query. | |
| * Then the dot product with y_hat only needs codebook lookups. | |
| * | |
| * This is the ultimate optimization — reduces the dequant+dot to: | |
| * 1. One codebook lookup per coordinate (register-cached) | |
| * 2. One multiply-accumulate per coordinate | |
| * 3. One scalar multiply by norm | |
| * | |
| * NOTE: This kernel is a forward-looking design for when flash attention | |
| * integration is ready. The non-fused path (dequant → standard FA) works | |
| * as the initial integration point. | |
| * ========================================================================= */ | |
/*
 * Fused dequantize + dot-product kernel (3-bit).
 * Computes scores[q][kv] = q_rotated[q] · dequant(kv_blocks[kv]) without
 * materializing the dequantized vector: since the query is pre-rotated
 * (Q · Π^T), the dot reduces to codebook lookups, one MAC per
 * coordinate, and a final scale by the stored norm.
 * Launch: grid (n_kv, n_queries), block (TQ_HEAD_DIM = 128).
 */
__global__ void tq_fused_dot_tq3(
    const float * __restrict__ q_rotated, /* [n_queries × 128] = Q · Π^T */
    const void * __restrict__ kv_blocks,  /* [n_kv × block_tq3] */
    float * __restrict__ scores,          /* [n_queries × n_kv] output */
    int n_queries,
    int n_kv
) {
    /* Grid: (n_kv, n_queries), Block: (128) */
    const int kv_idx = blockIdx.x;
    const int q_idx = blockIdx.y;
    if (kv_idx >= n_kv || q_idx >= n_queries) return;
    const int tid = threadIdx.x;
    const int d = TQ_HEAD_DIM;
    /* Thread 0 stages the KV block (norm + packed indices). */
    const block_tq3 * blk = (const block_tq3 *)((const uint8_t *)kv_blocks +
                                                kv_idx * sizeof(block_tq3));
    __shared__ float s_norm;
    __shared__ uint8_t s_packed[TQ3_INDEX_BYTES];
    if (tid == 0) {
        s_norm = blk->norm;
        for (int i = 0; i < TQ3_INDEX_BYTES; i++) {
            s_packed[i] = blk->indices[i];
        }
    }
    __syncthreads();
    /* Each thread extracts its own 3-bit index (LSB-first) from bit
     * positions [3*tid, 3*tid + 3). */
    uint8_t my_idx;
    {
        int bit_start = tid * 3;
        my_idx = 0;
        for (int b = 0; b < 3; b++) {
            int bit_pos = bit_start + b;
            if (s_packed[bit_pos / 8] & (1 << (bit_pos % 8))) {
                my_idx |= (1 << b);
            }
        }
    }
    /* Centroid lookup and per-coordinate product with the pre-rotated
     * query element. */
    float y_hat_val = d_codebook_3[my_idx];
    float q_val = q_rotated[q_idx * d + tid];
    float partial = q_val * y_hat_val;
    /* Intra-warp reduction of the partial dot product. */
    for (int offset = 16; offset > 0; offset >>= 1) {
        partial += __shfl_down_sync(0xFFFFFFFF, partial, offset);
    }
    /* Cross-warp reduction: lane 0 of each of the 4 warps contributes. */
    __shared__ float s_warp_dots[4];
    if (tid % 32 == 0) {
        s_warp_dots[tid / 32] = partial;
    }
    __syncthreads();
    if (tid == 0) {
        float dot = s_warp_dots[0] + s_warp_dots[1] +
                    s_warp_dots[2] + s_warp_dots[3];
        /* Scale by the stored norm to undo the quantizer's normalization. */
        scores[q_idx * n_kv + kv_idx] = dot * s_norm;
    }
}
| /* Host wrapper for fused dot product */ | |
| extern "C" | |
| void tq_cuda_fused_dot_tq3( | |
| const float * d_q_rotated, | |
| const void * d_kv_blocks, | |
| float * d_scores, | |
| int n_queries, | |
| int n_kv, | |
| cudaStream_t stream | |
| ) { | |
| dim3 grid(n_kv, n_queries); | |
| dim3 block(TQ_HEAD_DIM); | |
| tq_fused_dot_tq3<<<grid, block, 0, stream>>>( | |
| d_q_rotated, d_kv_blocks, d_scores, n_queries, n_kv | |
| ); | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* | |
| * TurboQuant: Near-Optimal KV Cache Quantization for ik_llama.cpp | |
| * ================================================================ | |
| * Reference: Zandieh et al., ICLR 2026 (arXiv:2504.19874) | |
| * | |
| * Implements PolarQuant (Algorithm 1) + QJL error correction for | |
| * KV cache compression to 3-3.5 bits per value with near-zero | |
| * accuracy loss. | |
| * | |
| * Built for Nexus Grove (ng-01) — designed to integrate into | |
| * ik_llama.cpp's existing KV cache quantization pipeline alongside | |
| * the existing q4_0, q8_0, etc. types. | |
| * | |
| * Authors: Jim Sullivan / Claude collaboration | |
| * Date: 2026-03-25 | |
| */ | |
| #ifndef GGML_TURBOQUANT_H | |
| #define GGML_TURBOQUANT_H | |
| #include <stdint.h> | |
| #include <stddef.h> | |
| #ifdef __cplusplus | |
| extern "C" { | |
| #endif | |
| /* ========================================================================= | |
| * Section 1: Configuration Constants | |
| * ========================================================================= */ | |
| /* Head dimension — standard for modern transformers (Qwen, Llama, etc.) */ | |
| #define TQ_HEAD_DIM 128 | |
| /* Supported bit-widths */ | |
| #define TQ_BITS_2 2 | |
| #define TQ_BITS_3 3 | |
| #define TQ_BITS_4 4 | |
| /* Default operational bit-width (paper's quality-neutral sweet spot) */ | |
| #define TQ_DEFAULT_BITS 3 | |
| /* Number of quantization levels per bit-width */ | |
| #define TQ_LEVELS_2 4 | |
| #define TQ_LEVELS_3 8 | |
| #define TQ_LEVELS_4 16 | |
| /* Rotation matrix seed — deterministic for reproducibility across sessions */ | |
| #define TQ_ROTATION_SEED 42 | |
| /* ========================================================================= | |
| * Section 2: Pre-computed Lloyd-Max Codebooks (from turboquant_codebooks.json) | |
| * | |
| * These are optimal scalar quantizer centroids for the Beta distribution | |
| * induced by random rotation of unit vectors in R^128. | |
| * Computed via Lloyd-Max algorithm per Theorem 1 of the paper. | |
| * ========================================================================= */ | |
/* Each codebook below is sorted ascending and symmetric about zero
 * (codebook[i] == -codebook[n-1-i] up to float rounding), as visible in
 * the values themselves. Values must not be edited by hand — they come
 * from the Lloyd-Max fit referenced above. */
/* 2-bit codebook: 4 centroids */
static const float TQ_CODEBOOK_2[4] = {
    -0.13311451677280386f,
    -0.04002746648341520f,
     0.04002746648341517f,
     0.13311451677280380f
};
/* 3-bit codebook: 8 centroids (the default TQ_DEFAULT_BITS = 3 path) */
static const float TQ_CODEBOOK_3[8] = {
    -0.18904037194348838f,
    -0.11879501670185091f,
    -0.06702922184405663f,
    -0.02174971334976657f,
     0.02174971334976654f,
     0.06702922184405660f,
     0.11879501670185087f,
     0.18904037194348833f
};
/* 4-bit codebook: 16 centroids */
static const float TQ_CODEBOOK_4[16] = {
    -0.23961253307138700f,
    -0.18317108415643454f,
    -0.14430970076906538f,
    -0.11276586366299288f,
    -0.08507481024405737f,
    -0.05962130616889217f,
    -0.03539017687270855f,
    -0.01173284981923122f,
     0.01173284981923120f,
     0.03539017687270851f,
     0.05962130616889214f,
     0.08507481024405730f,
     0.11276586366299284f,
     0.14430970076906535f,
     0.18317108415643450f,
     0.23961253307138697f
};
| /* ========================================================================= | |
| * Section 3: Data Structures | |
| * ========================================================================= */ | |
| /* | |
| * TurboQuant quantized block — stores one quantized head vector. | |
| * | |
| * Memory layout for TQ3 (3-bit, d=128): | |
| * - norm: 4 bytes (float32, original L2 norm) | |
| * - indices: 48 bytes (128 values × 3 bits = 384 bits = 48 bytes) | |
| * Total: 52 bytes per vector | |
| * | |
| * Compare to FP16: 128 × 2 = 256 bytes per vector → 4.9x compression | |
| * | |
| * The indices are bit-packed: for b-bit quantization, each coordinate | |
| * index (0 to 2^b - 1) is stored in exactly b bits, packed sequentially. | |
| */ | |
/* Block size for TQ3: one head_dim vector */
#define TQ3_BLOCK_SIZE TQ_HEAD_DIM
#define TQ3_BITS_PER_VAL 3
#define TQ3_INDEX_BYTES ((TQ3_BLOCK_SIZE * TQ3_BITS_PER_VAL + 7) / 8) /* 48 */
/* 3-bit quantized block: 4-byte norm + 48 packed index bytes = 52 bytes
 * per vector (vs. 256 bytes FP16). Indices are packed LSB-first by
 * tq_pack_indices. */
typedef struct {
    float norm;                          /* Original L2 norm of the vector */
    uint8_t indices[TQ3_INDEX_BYTES];    /* Bit-packed codebook indices */
} block_tq3;
/* Verify size at compile time — also guarantees the compiler inserted no
 * padding between the fields (the caches store these blocks back-to-back). */
_Static_assert(sizeof(block_tq3) == 4 + TQ3_INDEX_BYTES,
               "block_tq3 size mismatch");
/* Block size for TQ4: one head_dim vector */
#define TQ4_BLOCK_SIZE TQ_HEAD_DIM
#define TQ4_BITS_PER_VAL 4
#define TQ4_INDEX_BYTES ((TQ4_BLOCK_SIZE * TQ4_BITS_PER_VAL + 7) / 8) /* 64 */
/* 4-bit quantized block: 4-byte norm + 64 packed index bytes = 68 bytes. */
typedef struct {
    float norm;                          /* Original L2 norm of the vector */
    uint8_t indices[TQ4_INDEX_BYTES];    /* Bit-packed codebook indices */
} block_tq4;
_Static_assert(sizeof(block_tq4) == 4 + TQ4_INDEX_BYTES,
               "block_tq4 size mismatch");
| /* | |
| * Rotation matrix context — generated once at KV cache init, | |
| * reused for all quantize/dequantize operations. | |
| * | |
| * The matrix is d×d orthogonal, generated via QR decomposition | |
| * of a seeded random Gaussian matrix (Algorithm 1, Line 2). | |
| * | |
| * For d=128, this is 128×128×4 = 64 KB per rotation context. | |
| * Two contexts are needed (one for K cache, one for V cache) = 128 KB total. | |
| * Negligible compared to the KV cache itself. | |
| */ | |
/* Quantization context: bit-width configuration, codebook pointer, and
 * the embedded 64 KB rotation matrix. Filled once by tq_context_init and
 * treated as read-only by all quantize/dequantize calls. */
typedef struct {
    int d;                  /* Dimension (TQ_HEAD_DIM) */
    int bits;               /* Bit-width (2, 3, or 4) */
    int n_levels;           /* 2^bits */
    const float * codebook; /* Pointer to static codebook (not owned; do not free) */
    float rotation[TQ_HEAD_DIM * TQ_HEAD_DIM]; /* Orthogonal rotation Π, row-major */
} tq_context;
/* =========================================================================
 * Section 4: Core API — CPU Reference Implementation
 * ========================================================================= */

/*
 * Initialize a TurboQuant context with a given bit-width and seed.
 * Generates the rotation matrix via QR decomposition of seeded Gaussian.
 *
 * @param ctx   Output context
 * @param bits  Bit-width (2, 3, or 4)
 * @param seed  RNG seed for rotation matrix (use TQ_ROTATION_SEED)
 * @return      0 on success, -1 on error (e.g. unsupported bit-width)
 */
int tq_context_init(tq_context * ctx, int bits, uint64_t seed);

/*
 * Quantize a single head-dimension vector (Algorithm 1).
 *
 * Steps:
 *   1. Store ||x||_2 as norm
 *   2. Normalize: x_unit = x / ||x||
 *   3. Rotate: y = Π · x_unit
 *   4. For each y_j, find nearest codebook centroid index
 *   5. Bit-pack indices into output block
 *
 * @param ctx  Initialized TQ context
 * @param src  Input vector, float[TQ_HEAD_DIM]
 * @param dst  Output block (block_tq3 or block_tq4 depending on ctx->bits,
 *             cast to void*); must have room for tq_block_size(ctx->bits)
 *             bytes
 */
void tq_quantize(const tq_context * ctx, const float * src, void * dst);

/*
 * Dequantize a single block back to float vector (Algorithm 1 inverse).
 *
 * Steps:
 *   1. Unpack bit-packed indices
 *   2. Map indices to codebook centroids: y_hat_j = codebook[idx_j]
 *   3. Rotate back: x_hat = Π^T · y_hat
 *   4. Scale by stored norm
 *
 * @param ctx  Initialized TQ context
 * @param src  Input block (block_tq3 or block_tq4, cast to const void*)
 * @param dst  Output vector, float[TQ_HEAD_DIM]
 */
void tq_dequantize(const tq_context * ctx, const void * src, float * dst);

/*
 * Quantize a batch of vectors (e.g., all KV heads for one token in one layer).
 *
 * @param ctx        Initialized TQ context
 * @param src        Input: n_vectors × TQ_HEAD_DIM floats (row-major)
 * @param dst        Output: n_vectors × sizeof(block_tqN) bytes, blocks
 *                   stored contiguously
 * @param n_vectors  Number of vectors to quantize
 */
void tq_quantize_batch(const tq_context * ctx, const float * src,
                       void * dst, int n_vectors);

/*
 * Dequantize a batch of vectors. Inverse of tq_quantize_batch(); dst must
 * hold n_vectors × TQ_HEAD_DIM floats.
 */
void tq_dequantize_batch(const tq_context * ctx, const void * src,
                         float * dst, int n_vectors);
| /* ========================================================================= | |
| * Section 5: Utility Functions | |
| * ========================================================================= */ | |
| /* | |
| * Returns the size in bytes of one quantized block for the given bit-width. | |
| */ | |
| static inline size_t tq_block_size(int bits) { | |
| switch (bits) { | |
| case 3: return sizeof(block_tq3); | |
| case 4: return sizeof(block_tq4); | |
| default: return 0; | |
| } | |
| } | |
| /* | |
| * Returns the compression ratio vs FP16 for the given bit-width. | |
| */ | |
| static inline float tq_compression_ratio(int bits) { | |
| size_t fp16_size = TQ_HEAD_DIM * 2; /* 256 bytes */ | |
| size_t tq_size = tq_block_size(bits); | |
| if (tq_size == 0) return 0.0f; | |
| return (float)fp16_size / (float)tq_size; | |
| } | |
/* =========================================================================
 * Section 6: Bit-packing Helpers
 * ========================================================================= */

/*
 * Pack an array of b-bit indices into a byte array.
 * indices[i] must be in range [0, 2^b - 1].
 * packed must have room for (n_values * bits + 7) / 8 bytes
 * (matches the TQ3_INDEX_BYTES / TQ4_INDEX_BYTES formulas above).
 */
void tq_pack_indices(const uint8_t * indices, uint8_t * packed,
                     int n_values, int bits);

/*
 * Unpack a byte array into an array of b-bit indices.
 * Inverse of tq_pack_indices(); indices must hold n_values entries.
 */
void tq_unpack_indices(const uint8_t * packed, uint8_t * indices,
                       int n_values, int bits);
| #ifdef __cplusplus | |
| } | |
| #endif | |
| #endif /* GGML_TURBOQUANT_H */ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* | |
| * TurboQuant: Test Harness | |
| * ========================= | |
| * Validates the C implementation against known-good values from | |
| * the Python prototype. Compile and run on ng-01 before touching | |
| * any ik_llama.cpp integration. | |
| * | |
| * Build: | |
| * gcc -O2 -o tq_test ggml_turboquant.c tq_test.c -lm | |
| * | |
| * Run: | |
| * ./tq_test | |
| * | |
| * Authors: Jim Sullivan / Claude collaboration | |
| * Date: 2026-03-25 | |
| */ | |
| #include "ggml_turboquant.h" | |
| #include <stdio.h> | |
| #include <stdlib.h> | |
| #include <math.h> | |
| #include <string.h> | |
| #include <time.h> | |
#define PASS "\033[32m✓ PASS\033[0m"
#define FAIL "\033[31m✗ FAIL\033[0m"

static int tests_passed = 0;
static int tests_failed = 0;

/* Record one test result: print a colored PASS/FAIL line with the test
 * name and bump the matching counter. */
static void check(const char * name, int condition) {
    const char * tag = condition ? PASS : FAIL;
    printf(" %s %s\n", tag, name);
    if (condition) {
        tests_passed++;
    } else {
        tests_failed++;
    }
}
| /* Simple seeded PRNG for test vectors (matches Python's numpy seed behavior | |
| * closely enough for MSE validation, though not bit-identical) */ | |
| static uint64_t test_rng_state = 0; | |
| static float test_randn(void) { | |
| /* xorshift64* */ | |
| test_rng_state ^= test_rng_state >> 12; | |
| test_rng_state ^= test_rng_state << 25; | |
| test_rng_state ^= test_rng_state >> 27; | |
| uint64_t r = test_rng_state * 0x2545F4914F6CDD1DULL; | |
| /* Convert to uniform [0,1) */ | |
| double u1 = (double)(r >> 11) / (double)(1ULL << 53); | |
| test_rng_state ^= test_rng_state >> 12; | |
| test_rng_state ^= test_rng_state << 25; | |
| test_rng_state ^= test_rng_state >> 27; | |
| r = test_rng_state * 0x2545F4914F6CDD1DULL; | |
| double u2 = (double)(r >> 11) / (double)(1ULL << 53); | |
| if (u1 < 1e-15) u1 = 1e-15; | |
| return (float)(sqrt(-2.0 * log(u1)) * cos(2.0 * 3.14159265358979323846 * u2)); | |
| } | |
| /* ========================================================================= */ | |
| int main(void) { | |
| printf("=========================================================\n"); | |
| printf("TurboQuant C Implementation — Validation Suite\n"); | |
| printf("=========================================================\n\n"); | |
| /* ----------------------------------------------------------------- | |
| * Test 1: Context initialization | |
| * ----------------------------------------------------------------- */ | |
| printf("[Test 1] Context Initialization\n"); | |
| printf("-------------------------------------------------\n"); | |
| tq_context ctx3, ctx4; | |
| int rc3 = tq_context_init(&ctx3, 3, TQ_ROTATION_SEED); | |
| int rc4 = tq_context_init(&ctx4, 4, TQ_ROTATION_SEED); | |
| check("TQ3 context init returns 0", rc3 == 0); | |
| check("TQ4 context init returns 0", rc4 == 0); | |
| check("TQ3 has 8 levels", ctx3.n_levels == 8); | |
| check("TQ4 has 16 levels", ctx4.n_levels == 16); | |
| check("TQ3 dimension is 128", ctx3.d == 128); | |
| /* Verify rotation matrix is orthogonal: Π^T · Π ≈ I */ | |
| float dot_00 = 0.0f, dot_01 = 0.0f; | |
| for (int k = 0; k < TQ_HEAD_DIM; k++) { | |
| dot_00 += ctx3.rotation[k * TQ_HEAD_DIM + 0] * | |
| ctx3.rotation[k * TQ_HEAD_DIM + 0]; | |
| dot_01 += ctx3.rotation[k * TQ_HEAD_DIM + 0] * | |
| ctx3.rotation[k * TQ_HEAD_DIM + 1]; | |
| } | |
| check("Rotation col 0 has unit norm", fabsf(dot_00 - 1.0f) < 1e-4f); | |
| check("Rotation cols 0,1 are orthogonal", fabsf(dot_01) < 1e-4f); | |
| /* Invalid bit-width should fail */ | |
| tq_context ctx_bad; | |
| int rc_bad = tq_context_init(&ctx_bad, 5, 0); | |
| check("Invalid bit-width returns -1", rc_bad == -1); | |
| /* ----------------------------------------------------------------- | |
| * Test 2: Bit-packing round-trip | |
| * ----------------------------------------------------------------- */ | |
| printf("\n[Test 2] Bit-packing Round-trip\n"); | |
| printf("-------------------------------------------------\n"); | |
| /* 3-bit packing */ | |
| uint8_t orig3[TQ_HEAD_DIM], unpacked3[TQ_HEAD_DIM]; | |
| uint8_t packed3[TQ3_INDEX_BYTES]; | |
| for (int i = 0; i < TQ_HEAD_DIM; i++) { | |
| orig3[i] = (uint8_t)(i % 8); /* 0-7 for 3-bit */ | |
| } | |
| tq_pack_indices(orig3, packed3, TQ_HEAD_DIM, 3); | |
| tq_unpack_indices(packed3, unpacked3, TQ_HEAD_DIM, 3); | |
| int pack3_ok = 1; | |
| for (int i = 0; i < TQ_HEAD_DIM; i++) { | |
| if (orig3[i] != unpacked3[i]) { pack3_ok = 0; break; } | |
| } | |
| check("3-bit pack/unpack round-trip", pack3_ok); | |
| /* 4-bit packing */ | |
| uint8_t orig4[TQ_HEAD_DIM], unpacked4[TQ_HEAD_DIM]; | |
| uint8_t packed4[TQ4_INDEX_BYTES]; | |
| for (int i = 0; i < TQ_HEAD_DIM; i++) { | |
| orig4[i] = (uint8_t)(i % 16); /* 0-15 for 4-bit */ | |
| } | |
| tq_pack_indices(orig4, packed4, TQ_HEAD_DIM, 4); | |
| tq_unpack_indices(packed4, unpacked4, TQ_HEAD_DIM, 4); | |
| int pack4_ok = 1; | |
| for (int i = 0; i < TQ_HEAD_DIM; i++) { | |
| if (orig4[i] != unpacked4[i]) { pack4_ok = 0; break; } | |
| } | |
| check("4-bit pack/unpack round-trip", pack4_ok); | |
| /* ----------------------------------------------------------------- | |
| * Test 3: Quantize/Dequantize round-trip MSE | |
| * ----------------------------------------------------------------- */ | |
| printf("\n[Test 3] Quantize/Dequantize Round-trip MSE\n"); | |
| printf("-------------------------------------------------\n"); | |
| /* Paper's expected MSE for d=128 (from Theorem 1): | |
| * b=3: ~0.034 | |
| * b=4: ~0.0093 | |
| */ | |
| test_rng_state = 12345; | |
| int n_test_vectors = 1000; | |
| for (int bits = 3; bits <= 4; bits++) { | |
| tq_context * ctx = (bits == 3) ? &ctx3 : &ctx4; | |
| float paper_mse = (bits == 3) ? 0.034f : 0.0093f; | |
| float total_mse = 0.0f; | |
| size_t blk_size = tq_block_size(bits); | |
| uint8_t block_buf[sizeof(block_tq4)]; /* Large enough for either */ | |
| for (int v = 0; v < n_test_vectors; v++) { | |
| /* Generate random unit vector */ | |
| float x[TQ_HEAD_DIM], x_hat[TQ_HEAD_DIM]; | |
| float norm = 0.0f; | |
| for (int j = 0; j < TQ_HEAD_DIM; j++) { | |
| x[j] = test_randn(); | |
| norm += x[j] * x[j]; | |
| } | |
| norm = sqrtf(norm); | |
| for (int j = 0; j < TQ_HEAD_DIM; j++) x[j] /= norm; | |
| tq_quantize(ctx, x, block_buf); | |
| tq_dequantize(ctx, block_buf, x_hat); | |
| float mse = 0.0f; | |
| for (int j = 0; j < TQ_HEAD_DIM; j++) { | |
| float diff = x[j] - x_hat[j]; | |
| mse += diff * diff; | |
| } | |
| total_mse += mse; | |
| } | |
| float avg_mse = total_mse / n_test_vectors; | |
| /* Allow 3x tolerance since our PRNG differs from numpy */ | |
| int mse_ok = (avg_mse < paper_mse * 3.0f) && (avg_mse > paper_mse * 0.3f); | |
| printf(" b=%d: Avg MSE = %.6f (paper ≈ %.4f) ratio = %.2f\n", | |
| bits, avg_mse, paper_mse, avg_mse / paper_mse); | |
| check(bits == 3 ? "TQ3 MSE within 3x of paper" : | |
| "TQ4 MSE within 3x of paper", mse_ok); | |
| } | |
| /* ----------------------------------------------------------------- | |
| * Test 4: Zero vector handling | |
| * ----------------------------------------------------------------- */ | |
| printf("\n[Test 4] Zero Vector Handling\n"); | |
| printf("-------------------------------------------------\n"); | |
| float zeros[TQ_HEAD_DIM]; | |
| float zeros_out[TQ_HEAD_DIM]; | |
| uint8_t zero_block[sizeof(block_tq3)]; | |
| memset(zeros, 0, sizeof(zeros)); | |
| tq_quantize(&ctx3, zeros, zero_block); | |
| tq_dequantize(&ctx3, zero_block, zeros_out); | |
| float zero_norm = 0.0f; | |
| for (int j = 0; j < TQ_HEAD_DIM; j++) { | |
| zero_norm += zeros_out[j] * zeros_out[j]; | |
| } | |
| check("Zero vector round-trips to zero", zero_norm < 1e-10f); | |
| /* ----------------------------------------------------------------- | |
| * Test 5: Norm preservation | |
| * ----------------------------------------------------------------- */ | |
| printf("\n[Test 5] Norm Preservation\n"); | |
| printf("-------------------------------------------------\n"); | |
| test_rng_state = 99999; | |
| float x_norm_test[TQ_HEAD_DIM], x_hat_norm[TQ_HEAD_DIM]; | |
| uint8_t norm_block[sizeof(block_tq3)]; | |
| /* Create vector with known norm = 3.7 */ | |
| float target_norm = 3.7f; | |
| float raw_norm = 0.0f; | |
| for (int j = 0; j < TQ_HEAD_DIM; j++) { | |
| x_norm_test[j] = test_randn(); | |
| raw_norm += x_norm_test[j] * x_norm_test[j]; | |
| } | |
| raw_norm = sqrtf(raw_norm); | |
| for (int j = 0; j < TQ_HEAD_DIM; j++) { | |
| x_norm_test[j] *= target_norm / raw_norm; | |
| } | |
| tq_quantize(&ctx3, x_norm_test, norm_block); | |
| tq_dequantize(&ctx3, norm_block, x_hat_norm); | |
| float recon_norm = 0.0f; | |
| for (int j = 0; j < TQ_HEAD_DIM; j++) { | |
| recon_norm += x_hat_norm[j] * x_hat_norm[j]; | |
| } | |
| recon_norm = sqrtf(recon_norm); | |
| printf(" Original norm: %.4f Reconstructed norm: %.4f\n", | |
| target_norm, recon_norm); | |
| check("Norm preserved within 10%", | |
| fabsf(recon_norm - target_norm) / target_norm < 0.10f); | |
| /* ----------------------------------------------------------------- | |
| * Test 6: Compression ratio verification | |
| * ----------------------------------------------------------------- */ | |
| printf("\n[Test 6] Compression Ratios\n"); | |
| printf("-------------------------------------------------\n"); | |
| size_t fp16_size = TQ_HEAD_DIM * 2; /* 256 bytes */ | |
| float ratio3 = tq_compression_ratio(3); | |
| float ratio4 = tq_compression_ratio(4); | |
| printf(" TQ3: %zu bytes → %.1fx vs FP16 (%zu bytes)\n", | |
| tq_block_size(3), ratio3, fp16_size); | |
| printf(" TQ4: %zu bytes → %.1fx vs FP16 (%zu bytes)\n", | |
| tq_block_size(4), ratio4, fp16_size); | |
| check("TQ3 compression > 4x", ratio3 > 4.0f); | |
| check("TQ4 compression > 3x", ratio4 > 3.0f); | |
| /* ----------------------------------------------------------------- | |
| * Test 7: Batch operations | |
| * ----------------------------------------------------------------- */ | |
| printf("\n[Test 7] Batch Quantize/Dequantize\n"); | |
| printf("-------------------------------------------------\n"); | |
| int batch_size = 8; /* 8 KV heads typical for GQA */ | |
| float batch_in[8 * TQ_HEAD_DIM]; | |
| float batch_out[8 * TQ_HEAD_DIM]; | |
| uint8_t batch_blocks[8 * sizeof(block_tq3)]; | |
| test_rng_state = 42424242; | |
| for (int i = 0; i < batch_size * TQ_HEAD_DIM; i++) { | |
| batch_in[i] = test_randn() * 0.1f; | |
| } | |
| tq_quantize_batch(&ctx3, batch_in, batch_blocks, batch_size); | |
| tq_dequantize_batch(&ctx3, batch_blocks, batch_out, batch_size); | |
| float batch_mse = 0.0f; | |
| for (int i = 0; i < batch_size * TQ_HEAD_DIM; i++) { | |
| float diff = batch_in[i] - batch_out[i]; | |
| batch_mse += diff * diff; | |
| } | |
| batch_mse /= batch_size; | |
| printf(" Batch MSE (8 vectors): %.6f\n", batch_mse); | |
| check("Batch round-trip MSE reasonable", batch_mse < 0.1f); | |
| /* ----------------------------------------------------------------- | |
| * Test 8: Speed benchmark | |
| * ----------------------------------------------------------------- */ | |
| printf("\n[Test 8] Speed Benchmark (10000 vectors)\n"); | |
| printf("-------------------------------------------------\n"); | |
| int bench_n = 10000; | |
| float * bench_in = (float *)malloc(bench_n * TQ_HEAD_DIM * sizeof(float)); | |
| uint8_t * bench_blocks = (uint8_t *)malloc(bench_n * sizeof(block_tq3)); | |
| float * bench_out = (float *)malloc(bench_n * TQ_HEAD_DIM * sizeof(float)); | |
| test_rng_state = 777; | |
| for (int i = 0; i < bench_n * TQ_HEAD_DIM; i++) { | |
| bench_in[i] = test_randn(); | |
| } | |
| clock_t t0 = clock(); | |
| tq_quantize_batch(&ctx3, bench_in, bench_blocks, bench_n); | |
| clock_t t1 = clock(); | |
| tq_dequantize_batch(&ctx3, bench_blocks, bench_out, bench_n); | |
| clock_t t2 = clock(); | |
| double quant_ms = (double)(t1 - t0) / CLOCKS_PER_SEC * 1000.0; | |
| double dequant_ms = (double)(t2 - t1) / CLOCKS_PER_SEC * 1000.0; | |
| printf(" Quantize: %.1f ms (%.0f vectors/sec)\n", | |
| quant_ms, bench_n / (quant_ms / 1000.0)); | |
| printf(" Dequantize: %.1f ms (%.0f vectors/sec)\n", | |
| dequant_ms, bench_n / (dequant_ms / 1000.0)); | |
| /* CPU speed check: should manage at least 1000 vec/s even unoptimized */ | |
| check("Quantize speed > 1000 vec/s", | |
| bench_n / (quant_ms / 1000.0) > 1000.0); | |
| free(bench_in); | |
| free(bench_blocks); | |
| free(bench_out); | |
| /* ----------------------------------------------------------------- | |
| * Summary | |
| * ----------------------------------------------------------------- */ | |
| printf("\n=========================================================\n"); | |
| printf("Results: %d passed, %d failed\n", tests_passed, tests_failed); | |
| printf("=========================================================\n"); | |
| return tests_failed > 0 ? 1 : 0; | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate | |
| ========================================================================= | |
| Python prototype implementation from the ICLR 2026 paper by Zandieh et al. | |
| (arXiv:2504.19874) | |
| Implements: | |
| - Algorithm 1: TurboQuant_mse (MSE-optimal quantization) | |
| - Algorithm 2: TurboQuant_prod (Unbiased inner-product-optimal quantization) | |
| - QJL: 1-bit Quantized Johnson-Lindenstrauss transform | |
| Built for Nexus Grove (ng-01) as a proof-of-concept prototype. | |
| Not a CUDA kernel — this is NumPy reference code for validating the math | |
| before integration into ik_llama.cpp. | |
| Author: Jim / Claude collaboration | |
| Date: 2026-03-25 | |
| """ | |
| import numpy as np | |
| from scipy.special import gamma as gamma_fn | |
| from scipy.optimize import minimize_scalar | |
| from typing import Tuple, Optional | |
| import time | |
| import json | |
# =============================================================================
# Section 1: Beta Distribution PDF for Hypersphere Coordinates (Lemma 1)
# =============================================================================
def beta_pdf(x: np.ndarray, d: int) -> np.ndarray:
    """
    Density of a single coordinate of a uniformly random point on S^{d-1}.

    Lemma 1 gives, for x in [-1, 1]:

        f_X(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1 - x^2)^((d-3)/2)

    In high dimensions this converges to N(0, 1/d).
    """
    if d <= 2:
        raise ValueError("Dimension d must be >= 3 for this distribution")
    normalizer = gamma_fn(d / 2.0) / (np.sqrt(np.pi) * gamma_fn((d - 1) / 2.0))
    # Keep arguments strictly inside (-1, 1) so the fractional power stays finite.
    safe_x = np.clip(x, -1 + 1e-15, 1 - 1e-15)
    exponent = (d - 3) / 2.0
    return normalizer * np.power(1.0 - safe_x ** 2, exponent)
# =============================================================================
# Section 2: Lloyd-Max Optimal Scalar Quantizer (Eq. 4)
# =============================================================================
def compute_lloyd_max_codebook(d: int, b: int, max_iter: int = 200,
                               tol: float = 1e-12) -> np.ndarray:
    """
    Compute the optimal Lloyd-Max codebook for the Beta distribution on
    [-1, 1] induced by random rotation of unit vectors in R^d.

    Solves the continuous 1D k-means problem from Eq. (4) by alternating:
      1. boundaries <- midpoints between consecutive centroids
      2. centroids  <- conditional mean of the density within each cell
    until the largest centroid movement drops below ``tol``.

    Args:
        d: Vector dimension
        b: Bit-width (number of bits per coordinate)
        max_iter: Maximum Lloyd-Max iterations
        tol: Convergence tolerance on centroid movement

    Returns:
        Sorted array of 2^b centroid values
    """
    n_levels = 2 ** b
    # For large d the coordinate distribution is ~N(0, 1/d); seed the
    # centroids evenly across +/- 3 standard deviations.
    sigma = 1.0 / np.sqrt(d)
    centroids = np.linspace(-3 * sigma, 3 * sigma, n_levels)

    # Fixed numerical-integration grid over (-1, 1).
    n_grid = 10000
    x_grid = np.linspace(-1 + 1e-10, 1 - 1e-10, n_grid)
    dx = x_grid[1] - x_grid[0]
    pdf_vals = beta_pdf(x_grid, d)

    for _ in range(max_iter):
        # Step 1: optimal cell boundaries are midpoints of adjacent centroids.
        boundaries = np.concatenate([[-1.0],
                                     0.5 * (centroids[:-1] + centroids[1:]),
                                     [1.0]])
        # Step 2: optimal centroid is the conditional mean inside each cell.
        updated = np.zeros(n_levels)
        for i in range(n_levels):
            lo, hi = boundaries[i], boundaries[i + 1]
            in_cell = (x_grid >= lo) & (x_grid < hi)
            if not np.any(in_cell):
                # Empty cell — keep the previous centroid.
                updated[i] = centroids[i]
                continue
            first_moment = np.sum(x_grid[in_cell] * pdf_vals[in_cell]) * dx
            mass = np.sum(pdf_vals[in_cell]) * dx
            updated[i] = centroids[i] if mass < 1e-20 else first_moment / mass
        movement = np.max(np.abs(updated - centroids))
        centroids = updated
        if movement < tol:
            break
    return np.sort(centroids)
def compute_codebook_mse(centroids: np.ndarray, d: int) -> float:
    """
    Compute the MSE cost C(f_X, b) for a given codebook, scaled by d:
    this is d * C(f_X, b) = expected ||x - x_hat||^2 per Theorem 1.
    """
    n_grid = 10000
    x_grid = np.linspace(-1 + 1e-10, 1 - 1e-10, n_grid)
    dx = x_grid[1] - x_grid[0]
    pdf_vals = beta_pdf(x_grid, d)
    # Assign every grid point to its closest centroid (broadcast distances
    # over a (n_grid, n_levels) matrix and pick the argmin per row).
    assignment = np.argmin(np.abs(x_grid[:, None] - centroids[None, :]), axis=1)
    quantized = centroids[assignment]
    # Per-coordinate MSE, integrated against the density.
    per_coord = np.sum((x_grid - quantized) ** 2 * pdf_vals) * dx
    # Scale up to the full d-dimensional vector.
    return d * per_coord
# =============================================================================
# Section 3: Random Rotation Matrix Generation
# =============================================================================
def generate_random_rotation(d: int, seed: Optional[int] = None) -> np.ndarray:
    """
    Draw a random orthogonal matrix (the Π of Algorithm 1, line 2) by
    QR-decomposing a seeded random Gaussian matrix.
    """
    rng = np.random.default_rng(seed)
    gaussian = rng.standard_normal((d, d))
    Q, R = np.linalg.qr(gaussian)
    # QR is only unique up to column signs; pin them with diag(R) so the
    # result is a proper, reproducible rotation.
    diag_signs = np.sign(np.diag(R))
    diag_signs[diag_signs == 0] = 1
    return Q * diag_signs[None, :]
# =============================================================================
# Section 4: QJL — Quantized Johnson-Lindenstrauss (Definition 1)
# =============================================================================
class QJL:
    """
    1-bit Quantized Johnson-Lindenstrauss transform (Definition 1):

        Q_qjl(x)      = sign(S · x)
        Q_qjl^{-1}(z) = sqrt(π/2) / d · S^T · z

    Provides unbiased inner product estimates with zero memory overhead.
    """

    def __init__(self, d: int, seed: Optional[int] = None):
        self.d = d
        # Fixed Gaussian projection matrix, reproducible via the seed.
        self.S = np.random.default_rng(seed).standard_normal((d, d))

    def quantize(self, x: np.ndarray) -> np.ndarray:
        """Project with S and keep only the signs: sign(S · x), as int8."""
        return np.sign(self.S @ x).astype(np.int8)

    def dequantize(self, z: np.ndarray, residual_norm: float) -> np.ndarray:
        """
        Reconstruct sqrt(π/2) / d · γ · S^T · z, where γ = ||residual||_2.
        """
        gain = np.sqrt(np.pi / 2.0) / self.d * residual_norm
        return gain * (self.S.T @ z.astype(np.float64))
# =============================================================================
# Section 5: TurboQuant_mse — Algorithm 1
# =============================================================================
class TurboQuantMSE:
    """
    MSE-optimal TurboQuant (Algorithm 1).

    Quantize:   rotate y = Π · (x / ||x||), then snap each coordinate to its
                nearest Lloyd-Max centroid.
    Dequantize: look up centroids, rotate back with Π^T, rescale by the norm.
    """

    def __init__(self, d: int, b: int, rotation_seed: Optional[int] = None,
                 codebook: Optional[np.ndarray] = None):
        """
        Args:
            d: Vector dimension
            b: Bit-width per coordinate
            rotation_seed: Seed for reproducible rotation matrix
            codebook: Pre-computed codebook (if None, computes via Lloyd-Max)
        """
        self.d = d
        self.b = b
        self.n_levels = 2 ** b
        # Algorithm 1, line 2: the fixed random rotation Π.
        self.Pi = generate_random_rotation(d, seed=rotation_seed)
        # Algorithm 1, line 3: scalar codebook, kept sorted ascending.
        if codebook is None:
            print(f" Computing Lloyd-Max codebook (d={d}, b={b})...")
            t0 = time.time()
            self.codebook = compute_lloyd_max_codebook(d, b)
            print(f" Codebook computed in {time.time()-t0:.3f}s: {self.codebook}")
        else:
            self.codebook = np.sort(codebook)

    def quantize(self, x: np.ndarray) -> Tuple[np.ndarray, float]:
        """
        Algorithm 1, lines 4-7: quantize one vector.

        Args:
            x: Input vector of shape (d,); need not be unit norm — the norm
               is returned alongside the indices.

        Returns:
            (indices, norm): per-coordinate centroid indices and the L2 norm
        """
        norm = np.linalg.norm(x)
        if norm < 1e-15:
            return np.zeros(self.d, dtype=np.int32), 0.0
        # Line 5: rotate the normalized input.
        rotated = self.Pi @ (x / norm)
        # Line 6: nearest centroid per coordinate via broadcast distances.
        gaps = np.abs(rotated[:, None] - self.codebook[None, :])
        return np.argmin(gaps, axis=1).astype(np.int32), norm

    def dequantize(self, indices: np.ndarray, norm: float) -> np.ndarray:
        """
        Algorithm 1, lines 8-11: centroid lookup, inverse rotation, rescale.

        Args:
            indices: Index array from quantize()
            norm: Original L2 norm

        Returns:
            Reconstructed vector of shape (d,)
        """
        # Lines 9-10: y_hat_j = c_{idx_j}, then x_hat = Π^T · y_hat.
        reconstructed = self.Pi.T @ self.codebook[indices]
        return reconstructed * norm

    def compress_size_bits(self) -> int:
        """Total bits used to store one quantized vector (excluding norm)."""
        return self.d * self.b
# =============================================================================
# Section 6: TurboQuant_prod — Algorithm 2
# =============================================================================
class TurboQuantProd:
    """
    Inner-product-optimal TurboQuant (Algorithm 2).

    Two stages:
      1. TurboQuant_mse at (b-1) bits captures most of the signal.
      2. QJL spends the final bit on the residual, which removes the inner
         product bias that MSE-optimal quantizers have.

    Total bit-width: b = (b-1) + 1.
    """

    def __init__(self, d: int, b: int, rotation_seed: Optional[int] = None,
                 qjl_seed: Optional[int] = None,
                 codebook: Optional[np.ndarray] = None):
        """
        Args:
            d: Vector dimension
            b: Total bit-width per coordinate (must be >= 2)
            rotation_seed: Seed for MSE quantizer rotation
            qjl_seed: Seed for QJL projection matrix
            codebook: Pre-computed codebook for the (b-1) MSE stage
        """
        if b < 2:
            raise ValueError("TurboQuant_prod requires b >= 2 (1 bit for MSE + 1 for QJL)")
        self.d = d
        self.b = b
        # Line 2: the MSE stage runs one bit below the total budget.
        self.mse_quant = TurboQuantMSE(d, b - 1, rotation_seed=rotation_seed,
                                       codebook=codebook)
        # Line 3: QJL projection for the residual.
        self.qjl = QJL(d, seed=qjl_seed)

    def quantize(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, float, float]:
        """
        Algorithm 2, lines 4-8: quantize one vector.

        Returns:
            (mse_indices, qjl_signs, residual_norm, original_norm)
        """
        original_norm = np.linalg.norm(x)
        if original_norm < 1e-15:
            return (np.zeros(self.d, dtype=np.int32),
                    np.zeros(self.d, dtype=np.int8),
                    0.0, 0.0)
        # Line 5: coarse MSE quantization.
        mse_indices, norm = self.mse_quant.quantize(x)
        # Line 6: residual left over after the MSE reconstruction.
        residual = x - self.mse_quant.dequantize(mse_indices, norm)
        residual_norm = np.linalg.norm(residual)
        # Line 7: QJL acts on the normalized residual direction.
        if residual_norm < 1e-15:
            qjl_signs = np.zeros(self.d, dtype=np.int8)
        else:
            qjl_signs = self.qjl.quantize(residual / residual_norm)
        # Line 8: emit (idx, qjl signs, ||r||_2) plus the original norm.
        return mse_indices, qjl_signs, residual_norm, original_norm

    def dequantize(self, mse_indices: np.ndarray, qjl_signs: np.ndarray,
                   residual_norm: float, original_norm: float) -> np.ndarray:
        """
        Algorithm 2, lines 9-12: MSE reconstruction plus QJL residual estimate.

        Returns:
            Reconstructed vector of shape (d,)
        """
        coarse = self.mse_quant.dequantize(mse_indices, original_norm)   # line 10
        correction = self.qjl.dequantize(qjl_signs, residual_norm)       # line 11
        return coarse + correction                                       # line 12

    def compress_size_bits(self) -> int:
        """Total bits per vector: (b-1)*d for MSE + d for QJL + 32 for norm."""
        return self.d * self.b + 32  # +32 for the residual norm float
# =============================================================================
# Section 7: Compression Ratio Calculator
# =============================================================================
def compression_report(d: int, b: int, original_bits: int = 16) -> dict:
    """
    Calculate compression ratio for KV cache quantization.

    Args:
        d: Head dimension (e.g., 128 for most modern transformers)
        b: TurboQuant bit-width
        original_bits: Original precision (16 for fp16/bf16)

    Returns:
        Dict with compression metrics
    """
    baseline_bits = d * original_bits
    quantized_bits = d * b + 32  # +32 bits for norm storage
    ratio = baseline_bits / quantized_bits
    return {
        "dimension": d,
        "bit_width": b,
        "original_bits_per_vector": baseline_bits,
        "turboquant_bits_per_vector": quantized_bits,
        "compression_ratio": f"{ratio:.2f}x",
        "memory_fraction": f"{quantized_bits/baseline_bits:.1%}",
    }
# =============================================================================
# Section 8: Test Suite
# =============================================================================
def run_tests():
    """
    Validate TurboQuant against the paper's theoretical predictions.

    Runs seven checks in order: codebook MSE vs. Theorem 1 (Test 1),
    MSE round-trip error (Test 2), inner-product unbiasedness (Test 3),
    compression-ratio reporting (Test 4), quantization throughput
    (Test 5), and two KV-cache memory projections (Tests 6-7).

    Results are only printed; the function returns nothing and never
    raises on a mismatch — this is a diagnostic report, not a hard
    pass/fail harness.
    """
    print("=" * 70)
    print("TurboQuant Prototype — Validation Suite")
    print("Reference: Zandieh et al., ICLR 2026 (arXiv:2504.19874)")
    print("=" * 70)
    # -------------------------------------------------------------------------
    # Test 1: Codebook quality — compare MSE to paper's Table (Theorem 1)
    # -------------------------------------------------------------------------
    print("\n[Test 1] Lloyd-Max Codebook Quality vs Paper's Theorem 1")
    print("-" * 50)
    # Paper's expected MSE values for b=1,2,3,4 (from Theorem 1):
    expected_mse = {1: 0.36, 2: 0.117, 3: 0.03, 4: 0.009}
    d_test = 128  # Typical transformer head dimension
    # Codebooks are cached here and reused by Tests 2, 3, and 5.
    codebooks = {}
    for b in [1, 2, 3, 4]:
        cb = compute_lloyd_max_codebook(d_test, b)
        mse = compute_codebook_mse(cb, d_test)
        codebooks[b] = cb
        paper_val = expected_mse[b]
        ratio = mse / paper_val
        # NOTE(review): the pass band is deliberately loose (within 2x of the
        # paper's value) since the table entries are rounded.
        status = "✓" if 0.5 < ratio < 2.0 else "✗"
        print(f"  b={b}: MSE={mse:.6f} (paper≈{paper_val}) ratio={ratio:.3f} {status}")
    # -------------------------------------------------------------------------
    # Test 2: TurboQuant_mse round-trip
    # -------------------------------------------------------------------------
    print("\n[Test 2] TurboQuant_mse Round-trip (d=128)")
    print("-" * 50)
    # Single seeded generator; Test 3 continues drawing from the same stream,
    # so the whole suite is reproducible end-to-end.
    rng = np.random.default_rng(42)
    n_vectors = 1000
    for b in [2, 3, 4]:
        quant = TurboQuantMSE(d_test, b, rotation_seed=42, codebook=codebooks[b])
        total_mse = 0.0
        for _ in range(n_vectors):
            x = rng.standard_normal(d_test)
            x = x / np.linalg.norm(x)  # Unit vector
            idx, norm = quant.quantize(x)
            x_hat = quant.dequantize(idx, norm)
            total_mse += np.sum((x - x_hat)**2)
        avg_mse = total_mse / n_vectors
        paper_val = expected_mse[b]
        print(f"  b={b}: Avg MSE={avg_mse:.6f} (paper≈{paper_val})")
    # -------------------------------------------------------------------------
    # Test 3: TurboQuant_prod unbiasedness
    # -------------------------------------------------------------------------
    print("\n[Test 3] TurboQuant_prod Inner Product Unbiasedness (d=128, b=3)")
    print("-" * 50)
    b = 3
    # codebooks[b - 1]: prod mode appears to spend 1 bit on the QJL sign,
    # leaving b-1 bits for the magnitude codebook — presumably intentional;
    # TODO confirm against the TurboQuantProd constructor.
    quant_prod = TurboQuantProd(d_test, b, rotation_seed=42, qjl_seed=99,
                                codebook=codebooks[b - 1])
    # Test: E[<y, x_hat>] should equal <y, x>
    n_trials = 500
    bias_samples = []
    for _ in range(n_trials):
        x = rng.standard_normal(d_test)
        x = x / np.linalg.norm(x)
        y = rng.standard_normal(d_test)
        true_ip = np.dot(y, x)
        # Average over multiple quantization rounds (randomness in rotation/QJL)
        # For a single instance, the rotation is fixed, so we just measure once
        idx, qjl_signs, r_norm, o_norm = quant_prod.quantize(x)
        x_hat = quant_prod.dequantize(idx, qjl_signs, r_norm, o_norm)
        est_ip = np.dot(y, x_hat)
        bias_samples.append(est_ip - true_ip)
    mean_bias = np.mean(bias_samples)
    std_bias = np.std(bias_samples)
    print(f"  Mean bias: {mean_bias:.6f} (should be ≈0)")
    print(f"  Std of error: {std_bias:.6f}")
    print(f"  |Mean bias| / Std: {abs(mean_bias)/std_bias:.4f} (should be small)")
    # -------------------------------------------------------------------------
    # Test 4: Compression ratios for ng-01 relevant scenarios
    # -------------------------------------------------------------------------
    print("\n[Test 4] Compression Ratios — KV Cache Scenarios")
    print("-" * 50)
    # Qwen3.5-27B: head_dim=128, typical for modern transformers
    # 70B models: also typically head_dim=128
    for b in [2, 3, 4]:
        report = compression_report(d=128, b=b, original_bits=16)
        print(f"  b={b}: {report['compression_ratio']} compression "
              f"({report['memory_fraction']} of original)")
    # -------------------------------------------------------------------------
    # Test 5: Quantization speed benchmark
    # -------------------------------------------------------------------------
    print("\n[Test 5] Quantization Speed (d=128, n=10000 vectors)")
    print("-" * 50)
    quant_mse = TurboQuantMSE(d_test, 3, rotation_seed=42, codebook=codebooks[3])
    vectors = rng.standard_normal((10000, d_test))
    vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    # MSE quantize
    # NOTE(review): time.perf_counter() would be a better timer here
    # (monotonic, higher resolution); time.time() can jump with clock changes.
    t0 = time.time()
    for i in range(len(vectors)):
        quant_mse.quantize(vectors[i])
    t_mse = time.time() - t0
    print(f"  TurboQuant_mse (b=3): {t_mse:.3f}s "
          f"({len(vectors)/t_mse:.0f} vectors/sec)")
    # -------------------------------------------------------------------------
    # Test 6: Simulated KV cache compression for Qwen3.5-27B
    # -------------------------------------------------------------------------
    print("\n[Test 6] Simulated KV Cache — Qwen3.5-27B Dense")
    print("-" * 50)
    # Qwen3.5-27B approximate KV cache params:
    #   - num_layers: ~32 (estimated for 27B dense)
    #   - num_kv_heads: ~8 (GQA typical)
    #   - head_dim: 128
    #   - For 8K context: 8192 tokens
    n_layers = 32
    n_kv_heads = 8
    head_dim = 128
    context_lengths = [4096, 8192, 16384, 32768]
    print(f"  Model: Qwen3.5-27B Dense (est. {n_layers}L, {n_kv_heads} KV heads, "
          f"d={head_dim})")
    print(f"  KV cache = 2 (K+V) × layers × kv_heads × seq_len × head_dim × precision")
    print()
    for ctx_len in context_lengths:
        # FP16 baseline
        fp16_bytes = 2 * n_layers * n_kv_heads * ctx_len * head_dim * 2  # 2 bytes per fp16
        fp16_gb = fp16_bytes / (1024**3)
        # TurboQuant at 3.5 bits (paper's quality-neutral setting)
        tq_bits_per_val = 3.5
        tq_bytes = 2 * n_layers * n_kv_heads * ctx_len * head_dim * tq_bits_per_val / 8
        tq_gb = tq_bytes / (1024**3)
        ratio = fp16_gb / tq_gb
        print(f"  {ctx_len:>6} tokens: FP16={fp16_gb:.2f} GB → TQ@3.5b={tq_gb:.2f} GB "
              f"({ratio:.1f}x compression)")
    # -------------------------------------------------------------------------
    # Test 7: 70B model projection across 3x RTX 3090
    # -------------------------------------------------------------------------
    print("\n[Test 7] Projected KV Cache — 70B Model on 3× RTX 3090 (72GB)")
    print("-" * 50)
    # 70B model approximate params (Llama-3.1-70B style):
    n_layers_70b = 80
    n_kv_heads_70b = 8  # GQA
    head_dim_70b = 128
    # Weight memory estimate: 70B params at Q4 ≈ ~35-40GB
    weight_gb = 38.0  # Conservative Q4 estimate
    total_vram = 72.0
    available_for_kv = total_vram - weight_gb
    print(f"  Model: 70B (Q4_K_M ≈ {weight_gb:.0f} GB weights)")
    print(f"  Total VRAM: {total_vram:.0f} GB")
    print(f"  Available for KV cache: {available_for_kv:.0f} GB")
    print()
    # 16.0 = FP16 baseline row; the float equality check below is safe because
    # the compared values are the exact literals in this list.
    for tq_bits in [16.0, 3.5, 2.5]:
        bytes_per_token = (2 * n_layers_70b * n_kv_heads_70b * head_dim_70b
                           * tq_bits / 8)
        gb_per_token = bytes_per_token / (1024**3)
        max_tokens = int(available_for_kv / gb_per_token)
        label = "FP16" if tq_bits == 16.0 else f"TQ@{tq_bits}b"
        print(f"  {label:>8}: {bytes_per_token:.0f} bytes/token → "
              f"~{max_tokens:,} tokens max context")
    print()
    print("=" * 70)
    print("All tests complete.")
    print("=" * 70)
# =============================================================================
# Section 9: Codebook Export (for future C/CUDA integration)
# =============================================================================
def export_codebooks(d: int, bit_widths: tuple = (1, 2, 3, 4),
                     output_path: str = "turboquant_codebooks.json") -> dict:
    """
    Precompute and export Lloyd-Max codebooks for use in C/CUDA implementations.

    The codebooks only depend on (d, b), so they can be computed once and
    embedded as constants in ik_llama.cpp.

    Args:
        d: Vector dimension the codebooks are optimized for.
        bit_widths: Iterable of bit-widths to generate codebooks for.
            (Immutable tuple default replaces the original mutable-list
            default — the classic shared-default pitfall; callers passing
            a list still work unchanged.)
        output_path: Destination JSON file path.

    Returns:
        Dict mirroring the exported JSON:
        {"dimension": d, "codebooks": {str(b): {...}, ...}}.
    """
    result = {"dimension": d, "codebooks": {}}
    for b in bit_widths:
        print(f"Computing codebook for d={d}, b={b}...")
        cb = compute_lloyd_max_codebook(d, b)
        mse = compute_codebook_mse(cb, d)
        result["codebooks"][str(b)] = {
            "bit_width": b,
            "n_levels": 2**b,
            "centroids": cb.tolist(),
            "expected_mse": float(mse),
        }
    # Explicit encoding keeps the JSON bytes identical across platforms/locales.
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    print(f"\nCodebooks exported to: {output_path}")
    return result
# =============================================================================
# Main
# =============================================================================
if __name__ == "__main__":
    # Run the full validation suite when executed as a script; importing the
    # module (e.g. to call export_codebooks) triggers nothing.
    run_tests()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "dimension": 128, | |
| "codebooks": { | |
| "1": { | |
| "bit_width": 1, | |
| "n_levels": 2, | |
| "centroids": [ | |
| -0.07066158769779213, | |
| 0.07066158769779213 | |
| ], | |
| "expected_mse": 0.36088832307541935 | |
| }, | |
| "2": { | |
| "bit_width": 2, | |
| "n_levels": 4, | |
| "centroids": [ | |
| -0.13311451677280386, | |
| -0.0400274664834152, | |
| 0.04002746648341517, | |
| 0.1331145167728038 | |
| ], | |
| "expected_mse": 0.11599992634243472 | |
| }, | |
| "3": { | |
| "bit_width": 3, | |
| "n_levels": 8, | |
| "centroids": [ | |
| -0.18904037194348838, | |
| -0.11879501670185091, | |
| -0.06702922184405663, | |
| -0.02174971334976657, | |
| 0.02174971334976654, | |
| 0.0670292218440566, | |
| 0.11879501670185087, | |
| 0.18904037194348833 | |
| ], | |
| "expected_mse": 0.03396922058758807 | |
| }, | |
| "4": { | |
| "bit_width": 4, | |
| "n_levels": 16, | |
| "centroids": [ | |
| -0.239612533071387, | |
| -0.18317108415643454, | |
| -0.14430970076906538, | |
| -0.11276586366299288, | |
| -0.08507481024405737, | |
| -0.05962130616889217, | |
| -0.035390176872708554, | |
| -0.01173284981923122, | |
| 0.0117328498192312, | |
| 0.035390176872708505, | |
| 0.05962130616889214, | |
| 0.0850748102440573, | |
| 0.11276586366299284, | |
| 0.14430970076906535, | |
| 0.1831710841564345, | |
| 0.23961253307138697 | |
| ], | |
| "expected_mse": 0.009329931321633787 | |
| } | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment