← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q4_1.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_q4_1.c
3  * @brief GEMM/GEMV kernels with Q4_1 quantized weights
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Q4_1 Format:
15  * - 32 weights per block
16  * - 1 FP16 scale (d) per block
17  * - 1 FP16 minimum (m) per block
18  * - 20 bytes per 32 weights = 5.0 bits/weight
19  *
20  * Dequantization: w = d * q + m
21  * where q is the 4-bit unsigned value (0-15)
22  *
23  * Operations:
24  * Forward: Y = W @ X (W is Q4_1, X and Y are FP32)
25  * Backward: dX = W^T @ dY (gradient w.r.t. input)
26  */
27 
28 #include <stdint.h>
29 #include <stddef.h>
30 #include <string.h>
31 #include "ckernel_quant.h"
32 
33 #ifdef __AVX512F__
34 #include <immintrin.h>
35 #endif
36 
37 /* ============================================================================
38  * Forward Pass: GEMV y = W @ x
39  * ============================================================================ */
40 
41 /**
42  * @brief Matrix-vector multiply with Q4_1 weights (scalar reference)
43  *
44  * @param y Output vector [M]
45  * @param W Weight matrix in Q4_1 format [M x K]
46  * @param x Input vector [K]
47  * @param M Number of output rows
48  * @param K Number of columns (must be multiple of 32)
49  */
50 void gemv_q4_1_ref(float *y,
51  const void *W,
52  const float *x,
53  int M, int K)
54 {
55  const block_q4_1 *blocks = (const block_q4_1 *)W;
56  const int blocks_per_row = K / QK4_1;
57 
58  for (int row = 0; row < M; row++) {
59  float sum = 0.0f;
60 
61  for (int b = 0; b < blocks_per_row; b++) {
62  const block_q4_1 *block = &blocks[row * blocks_per_row + b];
63  const float d = CK_FP16_TO_FP32(block->d);
64  const float m = CK_FP16_TO_FP32(block->m);
65  const float *xp = &x[b * QK4_1];
66 
67  for (int i = 0; i < QK4_1 / 2; i++) {
68  const uint8_t packed = block->qs[i];
69  const int q0 = (packed & 0x0F);
70  const int q1 = (packed >> 4);
71 
72  /* Dequantize: w = d * q + m */
73  const float w0 = d * (float)q0 + m;
74  const float w1 = d * (float)q1 + m;
75 
76  sum += w0 * xp[2*i + 0];
77  sum += w1 * xp[2*i + 1];
78  }
79  }
80 
81  y[row] = sum;
82  }
83 }
84 
#ifdef __AVX512F__
/**
 * @brief Matrix-vector multiply with Q4_1 weights (AVX-512)
 *
 * Processes one 32-weight block per inner iteration. The 16 packed
 * bytes are zero-extended to 16 x 32-bit lanes; lane i then holds byte
 * i, whose low nibble is weight 2i (even index) and whose high nibble
 * is weight 2i+1 (odd index) — the same layout the scalar reference
 * unpacks. Dequantization (w = d * q + m) is a single FMA per half.
 *
 * @param y Output vector [M]
 * @param W Weight matrix in Q4_1 format [M x K]
 * @param x Input vector [K]
 * @param M Number of output rows
 * @param K Number of columns (must be a multiple of 32)
 */
void gemv_q4_1_avx512(float *y,
                      const void *W,
                      const float *x,
                      int M, int K)
{
    const block_q4_1 *blocks = (const block_q4_1 *)W;
    const int blocks_per_row = K / QK4_1;
    /* Mask to isolate the low nibble of each 32-bit lane. */
    const __m512i mask_lo = _mm512_set1_epi32(0x0F);

    for (int row = 0; row < M; row++) {
        /* 16-lane accumulator; reduced to a scalar once per row. */
        __m512 acc = _mm512_setzero_ps();

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q4_1 *block = &blocks[row * blocks_per_row + b];
            /* Broadcast per-block scale d and minimum m to all lanes. */
            const __m512 vscale = _mm512_set1_ps(CK_FP16_TO_FP32(block->d));
            const __m512 vmin = _mm512_set1_ps(CK_FP16_TO_FP32(block->m));
            const float *xp = &x[b * QK4_1];

            /* Load 16 bytes = 32 x 4-bit weights */
            __m128i packed = _mm_loadu_si128((const __m128i *)block->qs);
            __m512i bytes = _mm512_cvtepu8_epi32(packed);

            /* Extract nibbles: lo -> even-indexed weights, hi -> odd. */
            __m512i lo = _mm512_and_epi32(bytes, mask_lo);
            __m512i hi = _mm512_srli_epi32(bytes, 4);

            /* Dequantize: w = d * q + m */
            __m512 w_lo = _mm512_fmadd_ps(_mm512_cvtepi32_ps(lo), vscale, vmin);
            __m512 w_hi = _mm512_fmadd_ps(_mm512_cvtepi32_ps(hi), vscale, vmin);

            /* De-interleave x to match the even/odd weight split.
             * _mm512_set_ps takes arguments highest-lane-first, so the
             * lists below read back-to-front: lane 0 gets xp[0]/xp[1].
             * NOTE(review): gather-by-set_ps is correct but slow;
             * a permute-based deinterleave would be faster — left as-is
             * to keep llama.cpp parity behavior byte-identical. */
            __m512 x_even = _mm512_set_ps(
                xp[30], xp[28], xp[26], xp[24], xp[22], xp[20], xp[18], xp[16],
                xp[14], xp[12], xp[10], xp[8], xp[6], xp[4], xp[2], xp[0]);
            __m512 x_odd = _mm512_set_ps(
                xp[31], xp[29], xp[27], xp[25], xp[23], xp[21], xp[19], xp[17],
                xp[15], xp[13], xp[11], xp[9], xp[7], xp[5], xp[3], xp[1]);

            acc = _mm512_fmadd_ps(w_lo, x_even, acc);
            acc = _mm512_fmadd_ps(w_hi, x_odd, acc);
        }

        /* Horizontal sum of all 16 lanes. */
        y[row] = _mm512_reduce_add_ps(acc);
    }
}
#endif
135 
/**
 * @brief Auto-dispatch GEMV: y = W @ x with Q4_1 weights.
 *
 * Selects the AVX-512 kernel when the translation unit is compiled
 * with AVX-512F support; otherwise falls back to the scalar reference.
 *
 * @param y Output vector [M]
 * @param W Weight matrix in Q4_1 format [M x K]
 * @param x Input vector [K]
 * @param M Number of output rows
 * @param K Number of columns (must be a multiple of 32)
 */
void gemv_q4_1(float *y,
               const void *W,
               const float *x,
               int M, int K)
{
#ifdef __AVX512F__
    gemv_q4_1_avx512(y, W, x, M, K);
#else
    gemv_q4_1_ref(y, W, x, M, K);
#endif
}
150 
151 /* ============================================================================
152  * Forward Pass: GEMM Y = W @ X
153  * ============================================================================ */
154 
/**
 * @brief Matrix-matrix multiply with Q4_1 weights: Y = W @ X.
 *
 * Implemented as N independent GEMVs, one per column of X / Y.
 *
 * @param Y Output [N x M], each column stored contiguously
 * @param W Weight matrix in Q4_1 format [M x K]
 * @param X Input [N x K], each column stored contiguously
 * @param M Output dimension
 * @param N Batch size
 * @param K Input dimension (must be a multiple of 32)
 */
void gemm_q4_1(float *Y,
               const void *W,
               const float *X,
               int M, int N, int K)
{
    const float *x_col = X;
    float *y_col = Y;

    for (int col = 0; col < N; col++, x_col += K, y_col += M) {
        gemv_q4_1(y_col, W, x_col, M, K);
    }
}
167 
168 /* ============================================================================
169  * Backward Pass: Gradient w.r.t. Input
170  * ============================================================================ */
171 
172 /**
173  * @brief Backward pass: compute input gradient
174  *
175  * @param dX Output gradient w.r.t. input [K]
176  * @param W Weight matrix in Q4_1 format [M x K]
177  * @param dY Gradient w.r.t. output [M]
178  * @param M Number of output rows
179  * @param K Number of columns (input dimension)
180  */
181 void gemv_q4_1_backward_ref(float *dX,
182  const void *W,
183  const float *dY,
184  int M, int K)
185 {
186  const block_q4_1 *blocks = (const block_q4_1 *)W;
187  const int blocks_per_row = K / QK4_1;
188 
189  /* Zero output gradient */
190  memset(dX, 0, K * sizeof(float));
191 
192  /* Accumulate: dX += W^T @ dY */
193  for (int row = 0; row < M; row++) {
194  const float dy = dY[row];
195 
196  for (int b = 0; b < blocks_per_row; b++) {
197  const block_q4_1 *block = &blocks[row * blocks_per_row + b];
198  const float d = CK_FP16_TO_FP32(block->d);
199  const float m = CK_FP16_TO_FP32(block->m);
200  float *dxp = &dX[b * QK4_1];
201 
202  for (int i = 0; i < QK4_1 / 2; i++) {
203  const uint8_t packed = block->qs[i];
204  const int q0 = (packed & 0x0F);
205  const int q1 = (packed >> 4);
206 
207  const float w0 = d * (float)q0 + m;
208  const float w1 = d * (float)q1 + m;
209 
210  dxp[2*i + 0] += w0 * dy;
211  dxp[2*i + 1] += w1 * dy;
212  }
213  }
214  }
215 }
216 
/**
 * @brief Auto-dispatch backward GEMV: dX = W^T @ dY.
 *
 * Only the scalar reference exists today; this hook keeps the call
 * site stable if a vectorized variant is added later.
 *
 * @param dX Output gradient w.r.t. input [K]
 * @param W Weight matrix in Q4_1 format [M x K]
 * @param dY Gradient w.r.t. output [M]
 * @param M Number of output rows
 * @param K Input dimension (must be a multiple of 32)
 */
void gemv_q4_1_backward(float *dX,
                        const void *W,
                        const float *dY,
                        int M, int K)
{
    gemv_q4_1_backward_ref(dX, W, dY, M, K);
}
227 
/**
 * @brief Batched backward pass: one input-gradient GEMV per sample.
 *
 * @param dX Output gradients [N x K], each row stored contiguously
 * @param W Weight matrix in Q4_1 format [M x K]
 * @param dY Output-side gradients [N x M], each row contiguous
 * @param M Output dimension
 * @param N Batch size
 * @param K Input dimension (must be a multiple of 32)
 */
void gemm_q4_1_backward(float *dX,
                        const void *W,
                        const float *dY,
                        int M, int N, int K)
{
    float *dx_row = dX;
    const float *dy_row = dY;

    for (int s = 0; s < N; s++, dx_row += K, dy_row += M) {
        gemv_q4_1_backward(dx_row, W, dy_row, M, K);
    }
}
240 
241 /* ============================================================================
242  * GEMM NT (Non-Transpose A, Transpose B) - C = A @ B^T
243  * ============================================================================ */
244 
245 /**
246  * @brief GEMM with transposed Q4_1 weights: C = A @ B^T
247  *
248  * @param A Input activations [M x K], row-major FP32
249  * @param B Weight matrix in Q4_1 format [N x K], row-major quantized
250  * @param bias Optional bias [N], NULL if not used
251  * @param C Output [M x N], row-major FP32
252  * @param M Batch size (number of tokens)
253  * @param N Output dimension
254  * @param K Input dimension
255  */
256 void gemm_nt_q4_1(const float *A,
257  const void *B,
258  const float *bias,
259  float *C,
260  int M, int N, int K)
261 {
262  const block_q4_1 *blocks = (const block_q4_1 *)B;
263  const int blocks_per_row = K / QK4_1;
264 
265  for (int m = 0; m < M; m++) {
266  const float *a_row = &A[m * K];
267 
268  for (int n = 0; n < N; n++) {
269  float sum = 0.0f;
270 
271  for (int b = 0; b < blocks_per_row; b++) {
272  const block_q4_1 *block = &blocks[n * blocks_per_row + b];
273  const float d = CK_FP16_TO_FP32(block->d);
274  const float min = CK_FP16_TO_FP32(block->m);
275  const float *ap = &a_row[b * QK4_1];
276 
277  for (int i = 0; i < QK4_1 / 2; i++) {
278  const uint8_t packed = block->qs[i];
279  const int q0 = (packed & 0x0F);
280  const int q1 = (packed >> 4);
281 
282  const float w0 = d * (float)q0 + min;
283  const float w1 = d * (float)q1 + min;
284 
285  sum += w0 * ap[2 * i + 0];
286  sum += w1 * ap[2 * i + 1];
287  }
288  }
289 
290  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
291  }
292  }
293 }
294 
295 /* ============================================================================
296  * Dot Product Utility
297  * ============================================================================ */
298 
/**
 * @brief Dot product of one Q4_1 weight row against an FP32 vector.
 *
 * @param w_q4_1 Quantized weights [K] in Q4_1 blocks
 * @param x FP32 input vector [K]
 * @param K Vector length (must be a multiple of 32)
 * @return The scalar dot product.
 */
float dot_q4_1(const void *w_q4_1, const float *x, int K)
{
    /* A dot product is just a 1-row GEMV. */
    float out;
    gemv_q4_1(&out, w_q4_1, x, 1, K);
    return out;
}
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
#define QK4_1
Definition: ckernel_quant.h:50
void gemm_q4_1_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
void gemv_q4_1_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient.
void gemv_q4_1_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.
void gemv_q4_1_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_1 weights (scalar reference)
void gemm_nt_q4_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q4_1 weights: C = A @ B^T.
void gemv_q4_1(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.
void gemm_q4_1(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q4_1 weights.
float dot_q4_1(const void *w_q4_1, const float *x, int K)
#define C(color)
Definition: show_config.c:39
ck_half m
Definition: ckernel_quant.h:54
ck_half d
Definition: ckernel_quant.h:53
uint8_t qs[32/2]
Definition: ckernel_quant.h:55