← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q5_1.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_q5_1.c
3  * @brief GEMM/GEMV kernels with Q5_1 quantized weights
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Q5_1 Format:
15  * - 32 weights per block
16  * - 1 FP16 scale (d) per block
17  * - 1 FP16 minimum (m) per block
18  * - Low 4-bits stored like Q4_1 (16 bytes)
19  * - High 1-bit packed separately (4 bytes)
20  * - 24 bytes per 32 weights = 6.0 bits/weight
21  *
22  * Dequantization: w = d * q5 + m
23  * where q5 = low4bit | (highbit << 4), giving values 0-31
24  *
25  * Operations:
26  * Forward: Y = W @ X (W is Q5_1, X and Y are FP32)
27  * Backward: dX = W^T @ dY (gradient w.r.t. input)
28  */
29 
30 #include <stdint.h>
31 #include <stddef.h>
32 #include <string.h>
33 #include "ckernel_quant.h"
34 
35 #ifdef __AVX512F__
36 #include <immintrin.h>
37 #endif
38 
39 /* ============================================================================
40  * Forward Pass: GEMV y = W @ x
41  * ============================================================================ */
42 
43 /**
44  * @brief Matrix-vector multiply with Q5_1 weights (scalar reference)
45  *
46  * @param y Output vector [M]
47  * @param W Weight matrix in Q5_1 format [M x K]
48  * @param x Input vector [K]
49  * @param M Number of output rows
50  * @param K Number of columns (must be multiple of 32)
51  */
52 void gemv_q5_1_ref(float *y,
53  const void *W,
54  const float *x,
55  int M, int K)
56 {
57  const block_q5_1 *blocks = (const block_q5_1 *)W;
58  const int blocks_per_row = K / QK5_1;
59 
60  for (int row = 0; row < M; row++) {
61  float sum = 0.0f;
62 
63  for (int b = 0; b < blocks_per_row; b++) {
64  const block_q5_1 *block = &blocks[row * blocks_per_row + b];
65  const float d = CK_FP16_TO_FP32(block->d);
66  const float m = CK_FP16_TO_FP32(block->m);
67  const float *xp = &x[b * QK5_1];
68 
69  /* Get high bits as 32-bit integer */
70  uint32_t qh;
71  memcpy(&qh, block->qh, sizeof(qh));
72 
73  /* GGML Q5_1 layout: weights 0-15 from LOW nibbles, 16-31 from HIGH nibbles.
74  * High bits: bits 0-15 of qh → first half, bits 16-31 → second half.
75  * NOT interleaved like Q4_0/Q4_1. */
76 
77  /* First 16 weights: low nibbles of qs[j], high bit from qh bits 0-15 */
78  for (int j = 0; j < QK5_1 / 2; j++) {
79  const int lo = (block->qs[j] & 0x0F);
80  const int hi = ((qh >> j) & 1) << 4;
81  const float w = d * (float)(lo | hi) + m;
82  sum += w * xp[j];
83  }
84 
85  /* Second 16 weights: high nibbles of qs[j], high bit from qh bits 16-31 */
86  for (int j = 0; j < QK5_1 / 2; j++) {
87  const int lo = (block->qs[j] >> 4);
88  const int hi = ((qh >> (j + 16)) & 1) << 4;
89  const float w = d * (float)(lo | hi) + m;
90  sum += w * xp[j + QK5_1 / 2];
91  }
92  }
93 
94  y[row] = sum;
95  }
96 }
97 
#ifdef __AVX512F__
/**
 * @brief Matrix-vector multiply with Q5_1 weights (AVX-512)
 *
 * GGML Q5_1 layout per block (32 weights):
 * - Weights 0-15: low nibbles of qs[0..15], high bits from qh bits 0-15
 * - Weights 16-31: high nibbles of qs[0..15], high bits from qh bits 16-31
 *
 * Strategy: zero-extend the 16 packed bytes to 16 x 32-bit lanes, split
 * into low/high nibbles (two vectors of 16 quants each), OR in the 5th
 * bit extracted lane-by-lane from qh, then dequantize both halves with a
 * single FMA each (w = d * q + m) and accumulate against the sequential
 * input.  The horizontal reduction happens once per output row.
 *
 * @param y Output vector [M]
 * @param W Weight matrix in Q5_1 format [M x K]
 * @param x Input vector [K]
 * @param M Number of output rows
 * @param K Number of columns (must be a multiple of 32)
 */
void gemv_q5_1_avx512(float *y,
                      const void *W,
                      const float *x,
                      int M, int K)
{
    const block_q5_1 *blocks = (const block_q5_1 *)W;
    const int blocks_per_row = K / QK5_1;
    const __m512i mask_lo = _mm512_set1_epi32(0x0F);

    for (int row = 0; row < M; row++) {
        __m512 acc = _mm512_setzero_ps();

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q5_1 *block = &blocks[row * blocks_per_row + b];
            const __m512 vscale = _mm512_set1_ps(CK_FP16_TO_FP32(block->d));
            const __m512 vmin = _mm512_set1_ps(CK_FP16_TO_FP32(block->m));
            const float *xp = &x[b * QK5_1];

            /* Load high bits; memcpy sidesteps alignment/aliasing issues
             * with the byte array qh[4]. */
            uint32_t qh;
            memcpy(&qh, block->qh, sizeof(qh));

            /* Load 16 bytes = 32 x 4-bit weights, zero-extend each byte
             * to a 32-bit lane (lane i holds qs[i]). */
            __m128i packed = _mm_loadu_si128((const __m128i *)block->qs);
            __m512i bytes = _mm512_cvtepu8_epi32(packed);

            /* Extract low nibbles (weights 0-15) and high nibbles (weights 16-31) */
            __m512i lo = _mm512_and_epi32(bytes, mask_lo);
            __m512i hi_shift = _mm512_srli_epi32(bytes, 4);

            /* 5th-bit contribution for weights 0-15: qh bits 0-15,
             * pre-shifted into bit position 4.  _mm512_set_epi32 takes
             * arguments highest-lane first, hence the descending order:
             * lane j receives bit j of qh. */
            __m512i qh_first = _mm512_set_epi32(
                ((qh >> 15) & 1) << 4, ((qh >> 14) & 1) << 4,
                ((qh >> 13) & 1) << 4, ((qh >> 12) & 1) << 4,
                ((qh >> 11) & 1) << 4, ((qh >> 10) & 1) << 4,
                ((qh >> 9) & 1) << 4, ((qh >> 8) & 1) << 4,
                ((qh >> 7) & 1) << 4, ((qh >> 6) & 1) << 4,
                ((qh >> 5) & 1) << 4, ((qh >> 4) & 1) << 4,
                ((qh >> 3) & 1) << 4, ((qh >> 2) & 1) << 4,
                ((qh >> 1) & 1) << 4, ((qh >> 0) & 1) << 4
            );

            /* 5th-bit contribution for weights 16-31: qh bits 16-31,
             * same descending lane convention (lane j gets bit j+16). */
            __m512i qh_second = _mm512_set_epi32(
                ((qh >> 31) & 1) << 4, ((qh >> 30) & 1) << 4,
                ((qh >> 29) & 1) << 4, ((qh >> 28) & 1) << 4,
                ((qh >> 27) & 1) << 4, ((qh >> 26) & 1) << 4,
                ((qh >> 25) & 1) << 4, ((qh >> 24) & 1) << 4,
                ((qh >> 23) & 1) << 4, ((qh >> 22) & 1) << 4,
                ((qh >> 21) & 1) << 4, ((qh >> 20) & 1) << 4,
                ((qh >> 19) & 1) << 4, ((qh >> 18) & 1) << 4,
                ((qh >> 17) & 1) << 4, ((qh >> 16) & 1) << 4
            );

            /* Combine low 4 bits + 5th bit -> full 5-bit quant (0-31) */
            __m512i q_first = _mm512_or_epi32(lo, qh_first);
            __m512i q_second = _mm512_or_epi32(hi_shift, qh_second);

            /* Dequantize: w = d * q + m */
            __m512 w_first = _mm512_fmadd_ps(_mm512_cvtepi32_ps(q_first), vscale, vmin);
            __m512 w_second = _mm512_fmadd_ps(_mm512_cvtepi32_ps(q_second), vscale, vmin);

            /* Input is consumed sequentially: first 16 elements pair with
             * the low-nibble weights, next 16 with the high-nibble weights. */
            __m512 x_first = _mm512_loadu_ps(&xp[0]);
            __m512 x_second = _mm512_loadu_ps(&xp[16]);

            acc = _mm512_fmadd_ps(w_first, x_first, acc);
            acc = _mm512_fmadd_ps(w_second, x_second, acc);
        }

        /* Horizontal sum of the 16 partial accumulators */
        y[row] = _mm512_reduce_add_ps(acc);
    }
}
#endif
180 
181 /**
182  * @brief Auto-dispatch GEMV
183  */
/**
 * @brief Auto-dispatch GEMV: uses the AVX-512 kernel when the file is
 *        compiled with AVX-512F support, the scalar reference otherwise.
 *        Selection happens at compile time; both paths compute the same
 *        result y = W @ x.
 */
void gemv_q5_1(float *y,
               const void *W,
               const float *x,
               int M, int K)
{
#if defined(__AVX512F__)
    gemv_q5_1_avx512(y, W, x, M, K);
#else
    gemv_q5_1_ref(y, W, x, M, K);
#endif
}
195 
196 /* ============================================================================
197  * Forward Pass: GEMM Y = W @ X
198  * ============================================================================ */
199 
200 /**
201  * @brief Matrix-matrix multiply with Q5_1 weights
202  */
/**
 * @brief Matrix-matrix multiply with Q5_1 weights: Y = W @ X
 *
 * Treats each of the N columns as an independent GEMV: column n of X
 * (K floats) produces column n of Y (M floats).  X and Y are stored
 * column-contiguous, so both pointers simply step by their column size.
 */
void gemm_q5_1(float *Y,
               const void *W,
               const float *X,
               int M, int N, int K)
{
    const float *x_col = X;
    float *y_col = Y;

    for (int col = 0; col < N; col++, x_col += K, y_col += M) {
        gemv_q5_1(y_col, W, x_col, M, K);
    }
}
212 
213 /* ============================================================================
214  * Backward Pass: Gradient w.r.t. Input
215  * ============================================================================ */
216 
217 /**
218  * @brief Backward pass: compute input gradient
219  *
220  * @param dX Output gradient w.r.t. input [K]
221  * @param W Weight matrix in Q5_1 format [M x K]
222  * @param dY Gradient w.r.t. output [M]
223  * @param M Number of output rows
224  * @param K Number of columns (input dimension)
225  */
226 void gemv_q5_1_backward_ref(float *dX,
227  const void *W,
228  const float *dY,
229  int M, int K)
230 {
231  const block_q5_1 *blocks = (const block_q5_1 *)W;
232  const int blocks_per_row = K / QK5_1;
233 
234  /* Zero output gradient */
235  memset(dX, 0, K * sizeof(float));
236 
237  /* Accumulate: dX += W^T @ dY */
238  for (int row = 0; row < M; row++) {
239  const float dy = dY[row];
240 
241  for (int b = 0; b < blocks_per_row; b++) {
242  const block_q5_1 *block = &blocks[row * blocks_per_row + b];
243  const float d = CK_FP16_TO_FP32(block->d);
244  const float m = CK_FP16_TO_FP32(block->m);
245  float *dxp = &dX[b * QK5_1];
246 
247  /* Get high bits */
248  uint32_t qh;
249  memcpy(&qh, block->qh, sizeof(qh));
250 
251  /* First 16 weights: low nibbles, high bits from qh[0:15] */
252  for (int j = 0; j < QK5_1 / 2; j++) {
253  const int lo = (block->qs[j] & 0x0F);
254  const int hi = ((qh >> j) & 1) << 4;
255  const float w = d * (float)(lo | hi) + m;
256  dxp[j] += w * dy;
257  }
258 
259  /* Second 16 weights: high nibbles, high bits from qh[16:31] */
260  for (int j = 0; j < QK5_1 / 2; j++) {
261  const int lo = (block->qs[j] >> 4);
262  const int hi = ((qh >> (j + 16)) & 1) << 4;
263  const float w = d * (float)(lo | hi) + m;
264  dxp[j + QK5_1 / 2] += w * dy;
265  }
266  }
267  }
268 }
269 
270 /**
271  * @brief Auto-dispatch backward
272  */
/**
 * @brief Auto-dispatch backward pass: dX = W^T @ dY
 *
 * Currently always forwards to the scalar reference implementation
 * (no SIMD backward kernel exists yet); kept as a separate entry point
 * so callers are unaffected when a vectorized variant is added.
 *
 * @param dX Output gradient w.r.t. input [K]
 * @param W Weight matrix in Q5_1 format [M x K]
 * @param dY Gradient w.r.t. output [M]
 * @param M Number of output rows
 * @param K Number of columns (input dimension)
 */
void gemv_q5_1_backward(float *dX,
                        const void *W,
                        const float *dY,
                        int M, int K)
{
    gemv_q5_1_backward_ref(dX, W, dY, M, K);
}
280 
281 /**
282  * @brief Batched backward pass
283  */
/**
 * @brief Batched backward pass: one gemv backward per batch column.
 *
 * Column n of dY (M floats) produces column n of dX (K floats); both
 * buffers are column-contiguous, so the pointers step by the column
 * sizes between iterations.
 */
void gemm_q5_1_backward(float *dX,
                        const void *W,
                        const float *dY,
                        int M, int N, int K)
{
    float *dx_col = dX;
    const float *dy_col = dY;

    for (int col = 0; col < N; col++, dx_col += K, dy_col += M) {
        gemv_q5_1_backward(dx_col, W, dy_col, M, K);
    }
}
293 
294 /* ============================================================================
295  * GEMM NT (Non-Transpose A, Transpose B) - C = A @ B^T
296  * ============================================================================ */
297 
298 /**
299  * @brief GEMM with transposed Q5_1 weights: C = A @ B^T
300  *
301  * @param A Input activations [M x K], row-major FP32
302  * @param B Weight matrix in Q5_1 format [N x K], row-major quantized
303  * @param bias Optional bias [N], NULL if not used
304  * @param C Output [M x N], row-major FP32
305  * @param M Batch size (number of tokens)
306  * @param N Output dimension
307  * @param K Input dimension
308  */
309 void gemm_nt_q5_1(const float *A,
310  const void *B,
311  const float *bias,
312  float *C,
313  int M, int N, int K)
314 {
315  const block_q5_1 *blocks = (const block_q5_1 *)B;
316  const int blocks_per_row = K / QK5_1;
317 
318  for (int m = 0; m < M; m++) {
319  const float *a_row = &A[m * K];
320 
321  for (int n = 0; n < N; n++) {
322  float sum = 0.0f;
323 
324  for (int b = 0; b < blocks_per_row; b++) {
325  const block_q5_1 *block = &blocks[n * blocks_per_row + b];
326  const float d = CK_FP16_TO_FP32(block->d);
327  const float min = CK_FP16_TO_FP32(block->m);
328  const float *ap = &a_row[b * QK5_1];
329 
330  uint32_t qh;
331  memcpy(&qh, block->qh, sizeof(qh));
332 
333  /* First 16 weights: low nibbles, high bits from qh[0:15] */
334  for (int j = 0; j < QK5_1 / 2; j++) {
335  const int lo = (block->qs[j] & 0x0F);
336  const int hi = ((qh >> j) & 1) << 4;
337  sum += (d * (float)(lo | hi) + min) * ap[j];
338  }
339 
340  /* Second 16 weights: high nibbles, high bits from qh[16:31] */
341  for (int j = 0; j < QK5_1 / 2; j++) {
342  const int lo = (block->qs[j] >> 4);
343  const int hi = ((qh >> (j + 16)) & 1) << 4;
344  sum += (d * (float)(lo | hi) + min) * ap[j + QK5_1 / 2];
345  }
346  }
347 
348  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
349  }
350  }
351 }
352 
353 /* ============================================================================
354  * Dot Product Utility
355  * ============================================================================ */
356 
/**
 * @brief Dot product of one Q5_1-quantized weight row with an FP32 vector.
 *
 * Implemented as a 1-row GEMV so it automatically benefits from whatever
 * kernel gemv_q5_1 dispatches to.
 *
 * @param w_q5_1 Quantized weight row in Q5_1 format [K]
 * @param x FP32 input vector [K]
 * @param K Vector length (must be a multiple of 32)
 * @return The dot product as a float
 */
float dot_q5_1(const void *w_q5_1, const float *x, int K)
{
    float out;
    gemv_q5_1(&out, w_q5_1, x, 1, K);
    return out;
}
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
#define QK5_1
Definition: ckernel_quant.h:84
void gemm_q5_1_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
float dot_q5_1(const void *w_q5_1, const float *x, int K)
void gemv_q5_1(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.
void gemv_q5_1_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q5_1 weights (scalar reference)
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
void gemv_q5_1_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient.
void gemv_q5_1_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.
void gemm_q5_1(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q5_1 weights.
#define C(color)
Definition: show_config.c:39
uint8_t qs[32/2]
Definition: ckernel_quant.h:90
uint8_t qh[4]
Definition: ckernel_quant.h:89
ck_half m
Definition: ckernel_quant.h:88
ck_half d
Definition: ckernel_quant.h:87