← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_quant.h
Go to the documentation of this file.
1 /**
2  * @file ckernel_quant.h
3  * @brief Quantization block structures for weight-only quantization
4  *
5  * Defines block structures for various quantization formats used in LLM inference.
6  * Primary focus on Q4_K_M which is commonly used for LLM weight compression.
7  *
8  * Block structures are compatible with llama.cpp/GGML for model loading.
9  */
10 
11 #ifndef CKERNEL_QUANT_H
12 #define CKERNEL_QUANT_H
13 
14 #include <stdint.h>
15 #include <stddef.h>
16 #include "ckernel_dtype.h"
17 
18 #ifdef __cplusplus
19 extern "C" {
20 #endif
21 
/* ============================================================================
 * Half-Precision Type (FP16 - IEEE 754)
 * ============================================================================ */

/* Raw 16-bit storage for an IEEE 754 binary16 value (1 sign bit, 5 exponent
 * bits, 10 mantissa bits). Use ck_fp16_to_fp32()/ck_fp32_to_fp16() below to
 * convert; this typedef carries no arithmetic semantics of its own. */
typedef uint16_t ck_half;
27 
28 /* ============================================================================
29  * Q4_0: Simple 4-bit Quantization
30  * - 32 weights per block
31  * - 1 FP16 scale per block
32  * - 18 bytes per 32 weights = 4.5 bits/weight
33  * ============================================================================ */
34 
#define QK4_0 32

/* Layout must match GGML's block_q4_0 byte-for-byte (models are loaded
 * directly from GGUF files) — do not reorder or pad members. */
typedef struct {
    ck_half d;             /* 2 bytes: scale (delta) */
    uint8_t qs[QK4_0 / 2]; /* 16 bytes: 32 x 4-bit weights (2 per byte).
                            * NOTE(review): presumably low nibble = weight i,
                            * high nibble = weight i+16 as in GGML — confirm
                            * against the dequant kernels. */
} block_q4_0;
/* Total: 18 bytes per 32 weights */
42 
43 /* ============================================================================
44  * Q4_1: Simple 4-bit Quantization with Min
45  * - 32 weights per block
46  * - 2 FP16 values: scale (d) and min (m)
47  * - 20 bytes per 32 weights = 5.0 bits/weight
48  * ============================================================================ */
49 
#define QK4_1 32

/* Like Q4_0 but affine: dequantized weight = d * q + m instead of the
 * symmetric d * (q - offset). Layout must match GGML's block_q4_1. */
typedef struct {
    ck_half d;             /* 2 bytes: scale (delta) */
    ck_half m;             /* 2 bytes: minimum (additive offset) */
    uint8_t qs[QK4_1 / 2]; /* 16 bytes: 32 x 4-bit weights (2 per byte) */
} block_q4_1;
/* Total: 20 bytes per 32 weights */
58 
59 /* ============================================================================
60  * Q5_0: Simple 5-bit Quantization
61  * - 32 weights per block
62  * - 1 FP16 scale per block
63  * - Low 4 bits stored like Q4_0, high 1 bit packed separately
64  * - 22 bytes per 32 weights = 5.5 bits/weight
65  * ============================================================================ */
66 
#define QK5_0 32

/* 5-bit quantization: each weight's low 4 bits live in qs (like Q4_0) and
 * its 5th (high) bit is packed into the qh bitfield. Layout must match
 * GGML's block_q5_0. */
typedef struct {
    ck_half d;             /* 2 bytes: scale (delta) */
    uint8_t qh[4];         /* 4 bytes: high 1-bit of each weight (32 bits total) */
    uint8_t qs[QK5_0 / 2]; /* 16 bytes: low 4-bits of 32 weights (2 per byte) */
} block_q5_0;
/* Total: 22 bytes per 32 weights */
75 
76 /* ============================================================================
77  * Q5_1: Simple 5-bit Quantization with Min
78  * - 32 weights per block
79  * - 2 FP16 values: scale (d) and min (m)
80  * - Low 4 bits stored like Q4_1, high 1 bit packed separately
81  * - 24 bytes per 32 weights = 6.0 bits/weight
82  * ============================================================================ */
83 
#define QK5_1 32

/* Affine variant of Q5_0 (adds a per-block minimum, like Q4_1 vs Q4_0).
 * Layout must match GGML's block_q5_1. */
typedef struct {
    ck_half d;             /* 2 bytes: scale (delta) */
    ck_half m;             /* 2 bytes: minimum (additive offset) */
    uint8_t qh[4];         /* 4 bytes: high 1-bit of each weight (32 bits total) */
    uint8_t qs[QK5_1 / 2]; /* 16 bytes: low 4-bits of 32 weights (2 per byte) */
} block_q5_1;
/* Total: 24 bytes per 32 weights */
93 
94 /* ============================================================================
95  * Q8_0: Simple 8-bit Quantization
96  * - 32 weights per block
97  * - 1 FP16 scale per block
98  * - 34 bytes per 32 weights = 8.5 bits/weight
99  * ============================================================================ */
100 
#define QK8_0 32

/* 8-bit symmetric quantization; also used for quantized activations in the
 * Q5_0 x Q8_0 GEMM path below. Layout must match GGML's block_q8_0. */
typedef struct {
    ck_half d;        /* 2 bytes: scale */
    int8_t qs[QK8_0]; /* 32 bytes: 32 x 8-bit signed weights */
} block_q8_0;
/* Total: 34 bytes per 32 weights */
108 
109 /* ============================================================================
110  * Q4_K: K-Quant 4-bit with Nested Scales (Primary Target)
111  * - 256 weights per super-block
112  * - 8 sub-blocks of 32 weights each
113  * - Two-level scaling: super-block FP16 + sub-block 6-bit
114  * - 144 bytes per 256 weights = 4.5 bits/weight
115  *
116  * This is the format used by Q4_K_M, Q4_K_S, Q4_K_L variants.
117  * The M/S/L suffix indicates quantization aggressiveness, not structure.
118  * ============================================================================ */
119 
#define QK_K 256        /* weights per K-quant super-block */
#define K_SCALE_SIZE 12 /* packed bytes holding 8 x 6-bit scales + 8 x 6-bit mins */

/* K-quant super-block: 8 sub-blocks of 32 weights, each with its own 6-bit
 * scale and min (see unpack_q4_k_scales() for the packing). The FP16 d/dmin
 * pair rescales the 6-bit values. Layout must match GGML's block_q4_K. */
typedef struct {
    ck_half d;                    /* 2 bytes: super-block scale */
    ck_half dmin;                 /* 2 bytes: super-block minimum */
    uint8_t scales[K_SCALE_SIZE]; /* 12 bytes: 8 sub-block scales + 8 sub-block mins (6-bit packed) */
    uint8_t qs[QK_K / 2];         /* 128 bytes: 256 x 4-bit weights */
} block_q4_K;
/* Total: 144 bytes per 256 weights */
130 
131 /* ============================================================================
132  * Q5_K: K-Quant 5-bit with Nested Scales
133  * - 256 weights per super-block
134  * - 8 sub-blocks of 32 weights each
135  * - Two-level scaling: super-block FP16 + sub-block 6-bit
136  * - High bit per weight stored separately (1 bit each)
137  * - 176 bytes per 256 weights = 5.5 bits/weight
138  * ============================================================================ */
139 
/* Same two-level scaling scheme as block_q4_K (identical scales[] packing;
 * see unpack_q5_k_scales()), plus one extra high bit per weight in qh.
 * Layout must match GGML's block_q5_K. */
typedef struct {
    ck_half d;                    /* 2 bytes: super-block scale */
    ck_half dmin;                 /* 2 bytes: super-block minimum */
    uint8_t scales[K_SCALE_SIZE]; /* 12 bytes: 8 sub-block scales + 8 sub-block mins (6-bit packed) */
    uint8_t qh[QK_K / 8];         /* 32 bytes: high 1-bit for 256 weights */
    uint8_t qs[QK_K / 2];         /* 128 bytes: 256 x 4-bit weights */
} block_q5_K;
/* Total: 176 bytes per 256 weights */
148 
149 /* ============================================================================
150  * Q6_K: K-Quant 6-bit (per-16 scales)
151  * - 256 weights per block
152  * - 16 sub-blocks of 16 weights each
153  * - Stored as low 4 bits (ql) + high 2 bits (qh) + int8 scales
154  * ============================================================================ */
155 
/* 6-bit K-quant: 16 sub-blocks of 16 weights, one signed 8-bit scale each
 * (no per-sub-block min). Note d comes LAST here, unlike the other K-quant
 * blocks — the member order must match GGML's block_q6_K exactly. */
typedef struct {
    uint8_t ql[QK_K / 2];     /* 128 bytes: low 4 bits */
    uint8_t qh[QK_K / 4];     /* 64 bytes: high 2 bits */
    int8_t scales[QK_K / 16]; /* 16 bytes: 16 sub-block scales */
    ck_half d;                /* 2 bytes: super-block scale */
} block_q6_K;
/* Total: 210 bytes per 256 weights */
163 
164 /* ============================================================================
165  * Q8_K: K-Quant 8-bit (used for activations in some ops)
166  * - 256 weights per super-block
167  * - 1 FP32 scale per block (not FP16 like others!)
168  * ============================================================================ */
169 
/* 8-bit K-quant, primarily used for quantized activations (see
 * quantize_row_q8_k_sse / rmsnorm_q8_k_fused). Layout must match GGML's
 * block_q8_K. */
typedef struct {
    float d;                  /* 4 bytes: scale — FP32 here, not FP16 like the others! */
    int8_t qs[QK_K];          /* 256 bytes: 256 x 8-bit signed weights */
    int16_t bsums[QK_K / 16]; /* 32 bytes: sums of each 16-quant group, used by
                               * dot-product kernels as an optimization */
} block_q8_K;
/* Total: 292 bytes per 256 weights */
176 
177 /* ============================================================================
178  * Size Calculation Utilities
179  * ============================================================================ */
180 
181 /**
182  * @brief Get the block size (number of weights per block) for a quant type
183  */
184 static inline size_t ck_quant_block_size(int type) {
185  switch (type) {
186  case 0: return QK4_0; /* Q4_0 */
187  case 1: return QK8_0; /* Q8_0 */
188  case 2: return QK_K; /* Q4_K */
189  case 3: return QK_K; /* Q8_K */
190  case CK_DT_Q4_1: return QK4_1;
191  case CK_DT_Q5_0: return QK5_0;
192  case CK_DT_Q5_1: return QK5_1;
193  case CK_DT_Q5_K: return QK_K;
194  case CK_DT_Q6_K: return QK_K;
195  default: return 1;
196  }
197 }
198 
199 /**
200  * @brief Get the byte size per block for a quant type
201  */
202 static inline size_t ck_quant_type_size(int type) {
203  switch (type) {
204  case 0: return sizeof(block_q4_0);
205  case 1: return sizeof(block_q8_0);
206  case 2: return sizeof(block_q4_K);
207  case 3: return sizeof(block_q8_K);
208  case CK_DT_Q4_1: return sizeof(block_q4_1);
209  case CK_DT_Q5_0: return sizeof(block_q5_0);
210  case CK_DT_Q5_1: return sizeof(block_q5_1);
211  case CK_DT_Q5_K: return sizeof(block_q5_K);
212  case CK_DT_Q6_K: return sizeof(block_q6_K);
213  default: return 4; /* FP32 */
214  }
215 }
216 
/**
 * @brief Calculate total bytes needed for n_elements with given quant type
 *
 * Rounds UP to a whole number of blocks, so a row whose length is not an
 * exact multiple of the block size is never under-allocated (the previous
 * truncating division dropped the final partial block). For the usual case
 * where n_elements is a multiple of the block size the result is unchanged.
 *
 * @param type       quant type code (legacy 0-3 or CK_DT_*)
 * @param n_elements number of weights in the row; must be >= 0
 * @return size in bytes of the quantized row
 */
static inline size_t ck_quant_row_size(int type, int64_t n_elements) {
    size_t block_size = ck_quant_block_size(type);
    size_t type_size = ck_quant_type_size(type);
    size_t n_blocks = ((size_t)n_elements + block_size - 1) / block_size;
    return n_blocks * type_size;
}
225 
226 /* ============================================================================
227  * Q4_K Scale Unpacking Utilities
228  *
229  * The scales[12] array packs 8 scales and 8 mins in 6-bit format.
230  * Unpacking is non-trivial due to the bit packing.
231  * ============================================================================ */
232 
/**
 * @brief Unpack Q4_K sub-block scales and mins
 *
 * @param scales The packed scales[12] array from block_q4_K
 * @param sc Output: 8 unpacked scale values (multiply by super-block d)
 * @param m Output: 8 unpacked min values (multiply by super-block dmin)
 *
 * This matches llama.cpp's get_scale_min_k4() function exactly.
 * The 12-byte scales array layout:
 * - bytes 0-3: 6-bit scales[0-3] (high 2 bits used for scales[4-7])
 * - bytes 4-7: 6-bit mins[0-3] (high 2 bits used for mins[4-7])
 * - bytes 8-11: low 4 bits for scales[4-7], high 4 bits for mins[4-7]
 */
static inline void unpack_q4_k_scales(const uint8_t *scales,
                                      uint8_t *sc, uint8_t *m) {
    for (int i = 0; i < 8; i++) {
        if (i < 4) {
            /* Entries 0-3: 6-bit values stored directly */
            sc[i] = (uint8_t)(scales[i] & 0x3F);
            m[i]  = (uint8_t)(scales[i + 4] & 0x3F);
        } else {
            /* Entries 4-7: low nibble from bytes 8-11; the two high bits
             * come from the top of bytes 0-3 (scales) / 4-7 (mins) */
            sc[i] = (uint8_t)((scales[i + 4] & 0x0F) | ((scales[i - 4] >> 6) << 4));
            m[i]  = (uint8_t)((scales[i + 4] >> 4) | ((scales[i] >> 6) << 4));
        }
    }
}
271 
/**
 * @brief Unpack Q5_K sub-block scales and mins
 *
 * @param scales The packed scales[12] array from block_q5_K
 * @param sc Output: 8 unpacked scale values (multiply by super-block d)
 * @param m Output: 8 unpacked min values (multiply by super-block dmin)
 *
 * Q5_K uses the same 6-bit packed format as Q4_K for scales/mins.
 * The 12-byte scales array layout is identical:
 * - bytes 0-3: 6-bit scales[0-3] (high 2 bits used for scales[4-7])
 * - bytes 4-7: 6-bit mins[0-3] (high 2 bits used for mins[4-7])
 * - bytes 8-11: low 4 bits for scales[4-7], high 4 bits for mins[4-7]
 */
static inline void unpack_q5_k_scales(const uint8_t *scales,
                                      uint8_t *sc, uint8_t *m) {
    int j;
    /* Entries 0-3: 6-bit values stored directly in bytes 0-3 / 4-7 */
    for (j = 0; j < 4; j++) {
        sc[j] = (uint8_t)(scales[j] & 0x3F);
        m[j]  = (uint8_t)(scales[j + 4] & 0x3F);
    }
    /* Entries 4-7: low nibble from bytes 8-11, high 2 bits from the top
     * of bytes 0-3 (scales) and 4-7 (mins) */
    for (j = 4; j < 8; j++) {
        sc[j] = (uint8_t)((scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4));
        m[j]  = (uint8_t)((scales[j + 4] >> 4) | ((scales[j] >> 6) << 4));
    }
}
290 
291 /* ============================================================================
292  * FP16 Conversion Utilities
293  *
294  * Three variants:
295  * _soft - Pure C bit manipulation (always available, portable)
296  * _simd - F16C hardware instruction (vcvtph2ps/vcvtps2ph, Ivy Bridge+)
297  * (default) - Auto-dispatches to best available at compile time
298  * ============================================================================ */
299 
/**
 * @brief Convert FP16 (ck_half) to FP32 — software implementation
 *
 * Pure bit manipulation, no FP arithmetic. Handles all binary16 encodings:
 * signed zero, subnormal, normal, infinity, and NaN (payload preserved in
 * the widened mantissa). The conversion is exact — every FP16 value is
 * representable in FP32.
 */
static inline float ck_fp16_to_fp32_soft(ck_half h) {
    const uint32_t sgn = (uint32_t)(h & 0x8000) << 16;
    uint32_t e = (h >> 10) & 0x1F;
    uint32_t frac = h & 0x03FF;
    uint32_t bits;

    if (e == 31) {
        /* Inf (frac == 0) or NaN (frac != 0): FP32 exponent all-ones */
        bits = sgn | 0x7F800000 | (frac << 13);
    } else if (e != 0) {
        /* Normal: rebias exponent from 15 to 127 (127 - 15 = 112) */
        bits = sgn | ((e + 112) << 23) | (frac << 13);
    } else if (frac == 0) {
        /* Signed zero */
        bits = sgn;
    } else {
        /* Subnormal in FP16 is normal in FP32: renormalize the fraction.
         * frac <= 0x3FF here, so the loop shifts at least once. */
        e = 1;
        do {
            frac <<= 1;
            e--;
        } while ((frac & 0x400) == 0);
        bits = sgn | ((e + 112) << 23) | ((frac & 0x3FF) << 13);
    }

    union { uint32_t u; float f; } pun;
    pun.u = bits;
    return pun.f;
}
333 
/**
 * @brief Convert FP32 to FP16 (ck_half) — software implementation
 *
 * Fix: the previous version folded NaN inputs into the generic overflow
 * branch and returned +/-Infinity (0x7C00), silently destroying NaN-ness.
 * NaN inputs now stay NaN, with the quiet bit forced and the top 10 payload
 * bits preserved.
 *
 * Rounding of finite values is truncation (round toward zero), as before;
 * the F16C hardware path rounds to nearest even, so the two paths may
 * differ by 1 ulp. Out-of-range finite values overflow to infinity and
 * values below the subnormal range underflow to signed zero.
 */
static inline ck_half ck_fp32_to_fp16_soft(float f) {
    union { uint32_t u; float f; } u;
    u.f = f;

    uint32_t sign = (u.u >> 16) & 0x8000;
    uint32_t f_exp = (u.u >> 23) & 0xFF;
    uint32_t mant = (u.u >> 13) & 0x3FF; /* top 10 FP32 mantissa bits */

    if (f_exp == 0xFF) {
        if ((u.u & 0x007FFFFF) != 0) {
            /* NaN: force quiet bit (0x200) so the result is NaN even when
             * the truncated payload bits are all zero */
            return (ck_half)(sign | 0x7E00 | mant);
        }
        return (ck_half)(sign | 0x7C00); /* infinity */
    }

    int32_t exp = (int32_t)f_exp - 127 + 15;

    if (exp <= 0) {
        if (exp < -10) {
            return (ck_half)sign; /* too small even for a subnormal: +/-0 */
        }
        /* Subnormal: restore the implicit leading 1, shift into place */
        mant = (mant | 0x400) >> (1 - exp);
        return (ck_half)(sign | mant);
    }
    if (exp >= 31) {
        return (ck_half)(sign | 0x7C00); /* finite overflow -> infinity */
    }

    return (ck_half)(sign | ((uint32_t)exp << 10) | mant);
}
357 
/* --------------------------------------------------------------------------
 * F16C hardware conversion (requires Intel Ivy Bridge+ or AMD Piledriver+)
 * The scalar intrinsics _cvtsh_ss / _cvtss_sh compile to the F16C
 * instructions vcvtph2ps / vcvtps2ph operating on a single lane.
 * (Note: vcvtsh2ss/vcvtss2sh are AVX512-FP16 instructions, not F16C.)
 * -------------------------------------------------------------------------- */
#if defined(__F16C__)
#include <immintrin.h>

/**
 * @brief Convert FP16 to FP32 — F16C hardware (vcvtph2ps)
 */
static inline float ck_fp16_to_fp32_simd(ck_half h) {
    return _cvtsh_ss(h);
}

/**
 * @brief Convert FP32 to FP16 — F16C hardware (vcvtps2ph),
 *        round-to-nearest-even
 */
static inline ck_half ck_fp32_to_fp16_simd(float f) {
    return (ck_half)_cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
}
#endif /* __F16C__ */
379 
/* --------------------------------------------------------------------------
 * Default dispatch: selects hardware SIMD when available, else software.
 * Dispatch is at compile time via __F16C__ (set by -mf16c / -march=native).
 * NOTE(review): the two fp32->fp16 paths round differently — the soft
 * encoder truncates while F16C rounds to nearest even — so results can
 * differ by 1 ulp depending on build flags; confirm this is acceptable.
 * -------------------------------------------------------------------------- */
static inline float ck_fp16_to_fp32(ck_half h) {
#if defined(__F16C__)
    return ck_fp16_to_fp32_simd(h);
#else
    return ck_fp16_to_fp32_soft(h);
#endif
}

static inline ck_half ck_fp32_to_fp16(float f) {
#if defined(__F16C__)
    return ck_fp32_to_fp16_simd(f);
#else
    return ck_fp32_to_fp16_soft(f);
#endif
}
398 
/* Convenience macros */
#define CK_FP16_TO_FP32(x) ck_fp16_to_fp32(x)
#define CK_FP32_TO_FP16(x) ck_fp32_to_fp16(x)
/* NOTE: the _SIMD macros are only usable when compiled with __F16C__
 * defined (e.g. -mf16c); otherwise the underlying functions do not exist
 * and their use is a compile error. Prefer the auto-dispatch macros above. */
#define CK_FP16_TO_FP32_SIMD(x) ck_fp16_to_fp32_simd(x)
#define CK_FP32_TO_FP16_SIMD(x) ck_fp32_to_fp16_simd(x)
#define CK_FP16_TO_FP32_SOFT(x) ck_fp16_to_fp32_soft(x)
#define CK_FP32_TO_FP16_SOFT(x) ck_fp32_to_fp16_soft(x)

/* Legacy compatibility (for files that used the old GGML-style names) */
#define ggml_fp16_to_fp32 ck_fp16_to_fp32
#define ggml_fp32_to_fp16 ck_fp32_to_fp16
#define GGML_FP16_TO_FP32 CK_FP16_TO_FP32
#define GGML_FP32_TO_FP16 CK_FP32_TO_FP16
413 
/* ============================================================================
 * SSE Optimized Kernels (implemented in the corresponding .c files)
 *
 * NOTE(review): parameter conventions below — C = A * B^T (+ bias) with
 * A [M x K] activations and quantized weights B — are inferred from the
 * gemm_nt_* naming; confirm against the kernel implementations.
 * ============================================================================ */

/* FP32 activations x Q5_0 weights GEMM (SSE, v2) */
void gemm_nt_q5_0_sse_v2(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);
/* FP32 activations x Q6_K weights GEMM (SSE) */
void gemm_nt_q6_k_sse(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);
/* Scalar reference counterpart of gemm_nt_q6_k_sse */
void gemm_nt_q6_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);
/* GEMV: Q4_K weight matrix x Q8_K-quantized activation vector (SSE) */
void gemv_q4_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K);
/* Quantize k FP32 values into block_q8_K blocks (SSE) */
void quantize_row_q8_k_sse(const float *x, void *vy, int k);
/* RMSNorm fused with Q8_K quantization of the normalized output */
void rmsnorm_q8_k_fused(const float *input, const float *gamma, void *vy, int tokens, int d_model, int aligned_embed_dim, float eps);

/* Q5_K kernels (reference implementation) */
void gemv_q5_k_ref(float *y, const void *W, const float *x, int M, int K);
void gemm_nt_q5_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);
/* Q5_K entry points (non-_ref variants) */
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K);
void gemm_nt_q5_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);

/* INT8 activation batch GEMM kernels (Q5_0 weights x Q8_0 activations) */
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K);
/* AVX-unrolled variant of gemm_nt_q5_0_q8_0 */
void gemm_nt_q5_0_q8_0_unroll_avx(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K);
/* Dot product of n elements, result written to *s */
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy);
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy);
/* Quantize k FP32 values into block_q8_0 blocks (scalar reference) */
void quantize_row_q8_0(const float *x, void *vy, int k);
437 
438 #ifdef __cplusplus
439 }
440 #endif
441 
442 #endif /* CKERNEL_QUANT_H */
@ CK_DT_Q5_0
Definition: ckernel_dtype.h:44
@ CK_DT_Q5_K
Definition: ckernel_dtype.h:46
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
@ CK_DT_Q4_1
Definition: ckernel_dtype.h:39
@ CK_DT_Q5_1
Definition: ckernel_dtype.h:45
void gemm_nt_q5_0_sse_v2(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
#define QK5_0
Definition: ckernel_quant.h:67
void gemv_q5_k_ref(float *y, const void *W, const float *x, int M, int K)
static float ck_fp16_to_fp32_soft(ck_half h)
Convert FP16 (ck_half) to FP32 — software implementation.
#define K_SCALE_SIZE
uint16_t ck_half
Definition: ckernel_quant.h:26
void gemm_nt_q6_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K)
#define QK5_1
Definition: ckernel_quant.h:84
void gemm_nt_q6_k_sse(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
static size_t ck_quant_type_size(int type)
Get the byte size per block for a quant type.
void rmsnorm_q8_k_fused(const float *input, const float *gamma, void *vy, int tokens, int d_model, int aligned_embed_dim, float eps)
#define QK4_0
Definition: ckernel_quant.h:35
void gemm_nt_q5_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
static ck_half ck_fp32_to_fp16(float f)
#define QK4_1
Definition: ckernel_quant.h:50
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
ck_half ggml_half
void gemm_nt_q5_0_q8_0_unroll_avx(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
static float ck_fp16_to_fp32(ck_half h)
static void unpack_q5_k_scales(const uint8_t *scales, uint8_t *sc, uint8_t *m)
Unpack Q5_K sub-block scales and mins.
void quantize_row_q8_k_sse(const float *x, void *vy, int k)
static size_t ck_quant_block_size(int type)
Get the block size (number of weights per block) for a quant type.
static void unpack_q4_k_scales(const uint8_t *scales, uint8_t *sc, uint8_t *m)
Unpack Q4_K sub-block scales and mins.
void quantize_row_q8_0(const float *x, void *vy, int k)
Quantize FP32 to Q8_0 format (scalar reference)
static size_t ck_quant_row_size(int type, int64_t n_elements)
Calculate total bytes needed for n_elements with given quant type.
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.
static ck_half ck_fp32_to_fp16_soft(float f)
Convert FP32 to FP16 (ck_half) — software implementation.
#define QK8_0
void gemm_nt_q5_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemv_q4_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)
#define QK_K
#define C(color)
Definition: show_config.c:39
ck_half d
Definition: ckernel_quant.h:38
ck_half m
Definition: ckernel_quant.h:54
ck_half d
Definition: ckernel_quant.h:53
ck_half dmin
ck_half d
Definition: ckernel_quant.h:70
ck_half m
Definition: ckernel_quant.h:88
ck_half d
Definition: ckernel_quant.h:87
ck_half dmin