← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q4_0.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_q4_0.c
3  * @brief GEMM/GEMV kernels with Q4_0 quantized weights
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Q4_0 Format:
15  * - 32 weights per block
16  * - 1 FP16 scale per block
17  * - 18 bytes per 32 weights = 4.5 bits/weight
18  *
19  * Operations:
20  * Forward: Y = W @ X (W is Q4_0, X and Y are FP32)
21  * Backward: dX = W^T @ dY (gradient w.r.t. input)
22  *
23  * Note: Weight gradients are not computed for quantized weights.
24  * For fine-tuning, use LoRA adapters which maintain FP32 gradients separately.
25  */
26 
27 #include <stdint.h>
28 #include <stddef.h>
29 #include <string.h>
30 #include "ckernel_quant.h"
31 
32 #ifdef __AVX512F__
33 #include <immintrin.h>
34 #endif
35 
36 /* ============================================================================
37  * Forward Pass: GEMV y = W @ x
38  * ============================================================================ */
39 
40 /**
41  * @brief Matrix-vector multiply with Q4_0 weights (scalar reference)
42  *
43  * @param y Output vector [M]
44  * @param W Weight matrix in Q4_0 format [M x K]
45  * @param x Input vector [K]
46  * @param M Number of output rows
47  * @param K Number of columns (must be multiple of 32)
48  */
49 void gemv_q4_0_ref(float *y,
50  const void *W,
51  const float *x,
52  int M, int K)
53 {
54  const block_q4_0 *blocks = (const block_q4_0 *)W;
55  const int blocks_per_row = K / QK4_0;
56 
57  for (int row = 0; row < M; row++) {
58  float sum = 0.0f;
59 
60  for (int b = 0; b < blocks_per_row; b++) {
61  const block_q4_0 *block = &blocks[row * blocks_per_row + b];
62  const float d = CK_FP16_TO_FP32(block->d);
63  const float *xp = &x[b * QK4_0];
64 
65  for (int i = 0; i < QK4_0 / 2; i++) {
66  const uint8_t packed = block->qs[i];
67  const int8_t q0 = (packed & 0x0F) - 8;
68  const int8_t q1 = (packed >> 4) - 8;
69 
70  sum += d * (float)q0 * xp[2*i + 0];
71  sum += d * (float)q1 * xp[2*i + 1];
72  }
73  }
74 
75  y[row] = sum;
76  }
77 }
78 
79 #ifdef __AVX512F__
/**
 * @brief Matrix-vector multiply with Q4_0 weights (AVX-512)
 *
 * Vectorized variant of gemv_q4_0_ref(): each 18-byte block is handled
 * in one iteration. The 16 packed bytes are widened to 32-bit lanes;
 * low nibbles give the 16 even-indexed weights and high nibbles the 16
 * odd-indexed ones, each dequantized as (q - 8) * d before the FMA.
 */
void gemv_q4_0_avx512(float *y,
                      const void *W,
                      const float *x,
                      int M, int K)
{
    const block_q4_0 *blocks = (const block_q4_0 *)W;
    const int blocks_per_row = K / QK4_0;
    const __m512i offset = _mm512_set1_epi32(8);    /* Q4_0 zero point */
    const __m512i mask_lo = _mm512_set1_epi32(0x0F);

    for (int row = 0; row < M; row++) {
        __m512 acc = _mm512_setzero_ps();

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q4_0 *block = &blocks[row * blocks_per_row + b];
            const __m512 vscale = _mm512_set1_ps(CK_FP16_TO_FP32(block->d));
            const float *xp = &x[b * QK4_0];

            /* Load 16 bytes = 32 x 4-bit weights */
            __m128i packed = _mm_loadu_si128((const __m128i *)block->qs);
            /* Zero-extend each byte into its own 32-bit lane */
            __m512i bytes = _mm512_cvtepu8_epi32(packed);

            /* Extract and dequantize: nibble minus zero point, times scale */
            __m512i lo = _mm512_sub_epi32(_mm512_and_epi32(bytes, mask_lo), offset);
            __m512i hi = _mm512_sub_epi32(_mm512_srli_epi32(bytes, 4), offset);

            __m512 w_lo = _mm512_mul_ps(_mm512_cvtepi32_ps(lo), vscale);
            __m512 w_hi = _mm512_mul_ps(_mm512_cvtepi32_ps(hi), vscale);

            /* Deinterleave the input to match the scalar-ref layout: even
             * elements pair with low nibbles, odd with high nibbles.
             * NOTE(review): the set_ps gathers compile to scalar inserts
             * and are the likely hot spot; a permute-based deinterleave
             * would avoid them — confirm with a profile before changing. */
            __m512 x_even = _mm512_set_ps(
                xp[30], xp[28], xp[26], xp[24], xp[22], xp[20], xp[18], xp[16],
                xp[14], xp[12], xp[10], xp[8], xp[6], xp[4], xp[2], xp[0]);
            __m512 x_odd = _mm512_set_ps(
                xp[31], xp[29], xp[27], xp[25], xp[23], xp[21], xp[19], xp[17],
                xp[15], xp[13], xp[11], xp[9], xp[7], xp[5], xp[3], xp[1]);

            acc = _mm512_fmadd_ps(w_lo, x_even, acc);
            acc = _mm512_fmadd_ps(w_hi, x_odd, acc);
        }

        /* Horizontal sum of the 16 partial accumulators */
        y[row] = _mm512_reduce_add_ps(acc);
    }
}
127 #endif
128 
/**
 * @brief Auto-dispatch GEMV: uses the AVX-512 kernel when compiled with
 *        AVX-512F support, otherwise the scalar reference.
 */
void gemv_q4_0(float *y,
               const void *W,
               const float *x,
               int M, int K)
{
#if defined(__AVX512F__)
    gemv_q4_0_avx512(y, W, x, M, K);
#else
    gemv_q4_0_ref(y, W, x, M, K);
#endif
}
143 
144 /* ============================================================================
145  * Forward Pass: GEMM Y = W @ X
146  * ============================================================================ */
147 
/**
 * @brief Matrix-matrix multiply with Q4_0 weights: Y = W @ X
 *
 * Treats X as N independent length-K columns and Y as the matching N
 * length-M output columns; each column is one GEMV over the shared W.
 */
void gemm_q4_0(float *Y,
               const void *W,
               const float *X,
               int M, int N, int K)
{
    for (int col = 0; col < N; col++) {
        float *y_col = Y + col * M;
        const float *x_col = X + col * K;
        gemv_q4_0(y_col, W, x_col, M, K);
    }
}
160 
161 /* ============================================================================
162  * GEMM NT: C = A @ B^T + bias (B stored as N rows of K elements)
163  * ============================================================================ */
164 
165 /**
166  * @brief Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias
167  *
168  * @param A Input matrix [M x K], row-major FP32
169  * @param B Weight matrix in Q4_0 format, [N x K] stored row-major
170  * @param bias Optional bias [N], NULL if not used
171  * @param C Output [M x N], row-major FP32
172  * @param M Batch size (number of tokens)
173  * @param N Output dimension (number of rows in B)
174  * @param K Input dimension
175  */
176 void gemm_nt_q4_0(const float *A,
177  const void *B,
178  const float *bias,
179  float *C,
180  int M, int N, int K)
181 {
182  const block_q4_0 *blocks = (const block_q4_0 *)B;
183  const int blocks_per_row = K / QK4_0;
184 
185  for (int m = 0; m < M; m++) {
186  const float *a_row = &A[m * K];
187 
188  for (int n = 0; n < N; n++) {
189  float sum = 0.0f;
190 
191  for (int b = 0; b < blocks_per_row; b++) {
192  const block_q4_0 *block = &blocks[n * blocks_per_row + b];
193  const float d = CK_FP16_TO_FP32(block->d);
194  const float *ap = &a_row[b * QK4_0];
195 
196  for (int i = 0; i < QK4_0 / 2; i++) {
197  const uint8_t packed = block->qs[i];
198  const int q0 = (packed & 0x0F) - 8;
199  const int q1 = (packed >> 4) - 8;
200 
201  sum += d * (float)q0 * ap[2 * i + 0];
202  sum += d * (float)q1 * ap[2 * i + 1];
203  }
204  }
205 
206  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
207  }
208  }
209 }
210 
211 /* ============================================================================
212  * Backward Pass: Gradient w.r.t. Input
213  *
214  * Given: dL/dY (gradient of loss w.r.t. output)
215  * Compute: dL/dX = W^T @ dL/dY
216  *
217  * For quantized weights, we dequantize on-the-fly during backprop.
218  * Weight gradients are NOT computed (weights are frozen).
219  * ============================================================================ */
220 
221 /**
222  * @brief Backward pass: compute input gradient
223  *
224  * @param dX Output gradient w.r.t. input [K]
225  * @param W Weight matrix in Q4_0 format [M x K]
226  * @param dY Gradient w.r.t. output [M]
227  * @param M Number of output rows
228  * @param K Number of columns (input dimension)
229  */
230 void gemv_q4_0_backward_ref(float *dX,
231  const void *W,
232  const float *dY,
233  int M, int K)
234 {
235  const block_q4_0 *blocks = (const block_q4_0 *)W;
236  const int blocks_per_row = K / QK4_0;
237 
238  /* Zero output gradient */
239  memset(dX, 0, K * sizeof(float));
240 
241  /* Accumulate: dX += W^T @ dY */
242  for (int row = 0; row < M; row++) {
243  const float dy = dY[row];
244 
245  for (int b = 0; b < blocks_per_row; b++) {
246  const block_q4_0 *block = &blocks[row * blocks_per_row + b];
247  const float d = CK_FP16_TO_FP32(block->d);
248  float *dxp = &dX[b * QK4_0];
249 
250  for (int i = 0; i < QK4_0 / 2; i++) {
251  const uint8_t packed = block->qs[i];
252  const int8_t q0 = (packed & 0x0F) - 8;
253  const int8_t q1 = (packed >> 4) - 8;
254 
255  dxp[2*i + 0] += d * (float)q0 * dy;
256  dxp[2*i + 1] += d * (float)q1 * dy;
257  }
258  }
259  }
260 }
261 
262 #ifdef __AVX512F__
/**
 * @brief Backward pass with AVX-512
 *
 * Same contract as gemv_q4_0_backward_ref(): dX = W^T @ dY with dX
 * zeroed first. Weights are dequantized vector-wise; the gradient
 * write-back is scalar because low/high-nibble lanes interleave into
 * even/odd dX positions.
 */
void gemv_q4_0_backward_avx512(float *dX,
                               const void *W,
                               const float *dY,
                               int M, int K)
{
    const block_q4_0 *blocks = (const block_q4_0 *)W;
    const int blocks_per_row = K / QK4_0;
    const __m512i offset = _mm512_set1_epi32(8);    /* Q4_0 zero point */
    const __m512i mask_lo = _mm512_set1_epi32(0x0F);

    /* Zero output */
    memset(dX, 0, K * sizeof(float));

    for (int row = 0; row < M; row++) {
        /* Broadcast this row's output gradient across all lanes */
        const __m512 vdy = _mm512_set1_ps(dY[row]);

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q4_0 *block = &blocks[row * blocks_per_row + b];
            const __m512 vscale = _mm512_set1_ps(CK_FP16_TO_FP32(block->d));
            float *dxp = &dX[b * QK4_0];

            /* Dequantize weights: widen 16 bytes to 32-bit lanes, split
             * nibbles, subtract zero point, scale by d */
            __m128i packed = _mm_loadu_si128((const __m128i *)block->qs);
            __m512i bytes = _mm512_cvtepu8_epi32(packed);

            __m512i lo = _mm512_sub_epi32(_mm512_and_epi32(bytes, mask_lo), offset);
            __m512i hi = _mm512_sub_epi32(_mm512_srli_epi32(bytes, 4), offset);

            __m512 w_lo = _mm512_mul_ps(_mm512_cvtepi32_ps(lo), vscale);
            __m512 w_hi = _mm512_mul_ps(_mm512_cvtepi32_ps(hi), vscale);

            /* Compute gradients: dX += W * dY */
            __m512 grad_lo = _mm512_mul_ps(w_lo, vdy);
            __m512 grad_hi = _mm512_mul_ps(w_hi, vdy);

            /* Scatter to interleaved positions (simplified - actual impl needs gather/scatter) */
            float grad_lo_arr[16], grad_hi_arr[16];
            _mm512_storeu_ps(grad_lo_arr, grad_lo);
            _mm512_storeu_ps(grad_hi_arr, grad_hi);

            /* Lane i feeds dX positions 2*i (low nibble) and 2*i+1 (high),
             * mirroring the scalar reference layout */
            for (int i = 0; i < 16; i++) {
                dxp[2*i + 0] += grad_lo_arr[i];
                dxp[2*i + 1] += grad_hi_arr[i];
            }
        }
    }
}
313 #endif
314 
/**
 * @brief Auto-dispatch backward: AVX-512 kernel when available at
 *        compile time, scalar reference otherwise.
 */
void gemv_q4_0_backward(float *dX,
                        const void *W,
                        const float *dY,
                        int M, int K)
{
#if defined(__AVX512F__)
    gemv_q4_0_backward_avx512(dX, W, dY, M, K);
#else
    gemv_q4_0_backward_ref(dX, W, dY, M, K);
#endif
}
329 
/**
 * @brief Batched backward pass
 *
 * Runs gemv_q4_0_backward() independently for each of the N gradient
 * columns: dX column n (length K) comes from dY column n (length M).
 */
void gemm_q4_0_backward(float *dX,
                        const void *W,
                        const float *dY,
                        int M, int N, int K)
{
    for (int col = 0; col < N; col++) {
        gemv_q4_0_backward(dX + col * K, W, dY + col * M, M, K);
    }
}
342 
343 /* ============================================================================
344  * Dot Product Utility
345  * ============================================================================ */
346 
/**
 * @brief Dot product of one Q4_0 row with an FP32 vector.
 *
 * Implemented as a one-row GEMV so it shares the dispatched kernel.
 *
 * @param w_q4_0 Q4_0-encoded row of K weights
 * @param x      FP32 input vector [K]
 * @param K      Length (must be a multiple of QK4_0)
 * @return       Dequantized dot product
 */
float dot_q4_0(const void *w_q4_0, const float *x, int K)
{
    float out;
    gemv_q4_0(&out, w_q4_0, x, 1, K);
    return out;
}
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
#define QK4_0
Definition: ckernel_quant.h:35
void gemm_nt_q4_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void gemv_q4_0_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.
void gemm_q4_0_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
float dot_q4_0(const void *w_q4_0, const float *x, int K)
void gemm_q4_0(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q4_0 weights.
void gemv_q4_0_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient.
void gemv_q4_0_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_0 weights (scalar reference)
void gemv_q4_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.
#define C(color)
Definition: show_config.c:39
ck_half d
Definition: ckernel_quant.h:38
uint8_t qs[32/2]
Definition: ckernel_quant.h:39