gemm_kernels_q4k.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_q4k.c
3  * @brief GEMM/GEMV kernels with Q4_K quantized weights
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Implements matrix multiplication where:
15  * - Activations (input): FP32
16  * - Weights: Q4_K (4.5 bits/weight, nested scales)
17  * - Output: FP32
18  *
19  * Key optimization: Fused dequantization - weights are dequantized in
20  * registers and immediately used in FMA, never written to memory.
21  *
22  * Operations:
23  * - gemv_q4_k: Matrix-vector multiply (batch=1, token generation)
24  * - gemm_q4_k: Matrix-matrix multiply (batch>1, prefill)
25  */
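/*
 * Calling sketch (illustrative only: shapes and buffer names are made up, and
 * all buffers are assumed to have been carved out of the engine's bump
 * allocator elsewhere, per kernel rule 1):
 *
 *   // Decode step (batch = 1): token-generation critical path.
 *   //   w_q4k : [M x K] Q4_K blocks, row-major
 *   //   x     : [K]     FP32 activations (K must be a multiple of 256)
 *   //   y     : [M]     FP32 output
 *   gemv_q4_k(y, w_q4k, x, M, K);
 *
 *   // Prefill (batch = N > 1): one contiguous K-column of X per token.
 *   gemm_q4_k(Y, w_q4k, X, M, N, K);
 */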
26 
27 #include <stdint.h>
28 #include <stddef.h>
29 #include <string.h>
30 #include "ckernel_quant.h"
31 
32 /* Include SIMD headers based on available extensions */
33 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
34 #include <immintrin.h>
35 #endif
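/*
 * For orientation, the kernels below assume the block_q4_K layout declared in
 * ckernel_quant.h: 256 weights per super-block in 144 bytes (4.5 bits/weight).
 * The field order shown here is a sketch of the usual llama.cpp packing; the
 * authoritative definition is the one in ckernel_quant.h.
 *
 *   typedef struct {
 *       ck_half d;            // FP16 super-scale applied to the 6-bit scales
 *       ck_half dmin;         // FP16 super-scale applied to the 6-bit mins
 *       uint8_t scales[12];   // 8 sub-block (scale, min) pairs, 6 bits each
 *       uint8_t qs[QK_K / 2]; // 256 4-bit quants, two per byte
 *   } block_q4_K;             // 2 + 2 + 12 + 128 = 144 bytes
 */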
36 
37 /* ============================================================================
38  * GEMV: y = W @ x (W is Q4_K, x and y are FP32)
39  *
40  * For token generation (batch=1), this is the critical path.
41  * Memory-bound: we're loading ~4GB of weights for a 7B model per token.
42  * ============================================================================ */
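/*
 * Back-of-the-envelope (illustrative numbers): a 7B-parameter model at
 * 4.5 bits/weight is about 7e9 * 4.5 / 8 ≈ 3.9 GB of Q4_K data, all of which
 * is streamed once per generated token. A machine sustaining, say, 60 GB/s of
 * memory bandwidth would therefore cap decode at roughly 15 tokens/s from
 * weight traffic alone, which is why the fused dequant+FMA below avoids any
 * extra passes over the weights.
 */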
43 
44 /**
45  * @brief Matrix-vector multiply with Q4_K weights (scalar reference)
46  *
47  * @param y Output vector [M]
48  * @param W Weight matrix in Q4_K format [M x K], stored row-major
49  * @param x Input vector [K]
50  * @param M Number of output rows
51  * @param K Number of columns (must be multiple of 256)
52  */
53 void gemv_q4_k_ref(float *y,
54  const void *W,
55  const float *x,
56  int M, int K)
57 {
58  const block_q4_K *blocks = (const block_q4_K *)W;
59  const int blocks_per_row = K / QK_K; /* QK_K = 256 */
60 
61  for (int row = 0; row < M; row++) {
62  float sum = 0.0f;
63 
64  for (int b = 0; b < blocks_per_row; b++) {
65  const block_q4_K *block = &blocks[row * blocks_per_row + b];
66  const float d = GGML_FP16_TO_FP32(block->d);
67  const float dmin = GGML_FP16_TO_FP32(block->dmin);
68 
69  /* Unpack sub-block scales */
70  uint8_t sc[8], m[8];
71  unpack_q4_k_scales(block->scales, sc, m);
72 
73  /* llama.cpp Q4_K layout: 4 iterations of 64 weights each
74  * Each iteration uses 32 bytes of qs and 2 scales:
75  * - First 32 weights (indices 0-31): low nibbles with scale[2*iter]
76  * - Next 32 weights (indices 32-63): high nibbles with scale[2*iter+1]
77  */
78  for (int iter = 0; iter < 4; iter++) {
79  const float d1 = d * (float)sc[2*iter];
80  const float m1 = dmin * (float)m[2*iter];
81  const float d2 = d * (float)sc[2*iter + 1];
82  const float m2 = dmin * (float)m[2*iter + 1];
83  const uint8_t *qs = &block->qs[iter * 32];
84  const float *xp = &x[b * QK_K + iter * 64];
85 
86  /* First 32 weights: low nibbles of qs[0..31] */
87  for (int l = 0; l < 32; l++) {
88  const int8_t q = (qs[l] & 0x0F);
89  sum += (d1 * (float)q - m1) * xp[l];
90  }
91  /* Next 32 weights: high nibbles of qs[0..31] */
92  for (int l = 0; l < 32; l++) {
93  const int8_t q = (qs[l] >> 4);
94  sum += (d2 * (float)q - m2) * xp[l + 32];
95  }
96  }
97  }
98 
99  y[row] = sum;
100  }
101 }
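/*
 * Worked example of the dequantization formula above (made-up values): with
 * d = 0.01, dmin = 0.02, sc[0] = 63 and m[0] = 10, a stored nibble q = 7 in
 * the first sub-block dequantizes to
 *
 *     w = (d * sc[0]) * q - (dmin * m[0]) = 0.63 * 7 - 0.2 = 4.21
 *
 * i.e. the per-sub-block scale and min are themselves quantized and are
 * rescaled by the FP16 super-block values d and dmin.
 */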
102 
103 #ifdef __AVX512F__
104 /**
105  * @brief Matrix-vector multiply with Q4_K weights (AVX-512 optimized)
106  *
107  * Fused dequant + FMA: weights dequantized in ZMM registers, never touch RAM.
108  */
109 void gemv_q4_k_avx512(float *y,
110  const void *W,
111  const float *x,
112  int M, int K)
113 {
114  const block_q4_K *blocks = (const block_q4_K *)W;
115  const int blocks_per_row = K / QK_K;
116 
117  for (int row = 0; row < M; row++) {
118  __m512 acc = _mm512_setzero_ps();
119 
120  for (int b = 0; b < blocks_per_row; b++) {
121  const block_q4_K *block = &blocks[row * blocks_per_row + b];
122  const float d = GGML_FP16_TO_FP32(block->d);
123  const float dmin = GGML_FP16_TO_FP32(block->dmin);
124 
125  uint8_t sc[8], m_arr[8];
126  unpack_q4_k_scales(block->scales, sc, m_arr);
127 
128  const __m512i mask_lo = _mm512_set1_epi32(0x0F);
129 
130  /* llama.cpp Q4_K layout: 4 iterations of 64 weights each
131  * Formula: w = d * q - m (NOT d * (q-8) + m)
132  */
133  for (int iter = 0; iter < 4; iter++) {
134  const float d1 = d * (float)sc[2*iter];
135  const float m1 = dmin * (float)m_arr[2*iter];
136  const float d2 = d * (float)sc[2*iter + 1];
137  const float m2 = dmin * (float)m_arr[2*iter + 1];
138 
139  const __m512 vscale1 = _mm512_set1_ps(d1);
140  const __m512 vmin1 = _mm512_set1_ps(m1);
141  const __m512 vscale2 = _mm512_set1_ps(d2);
142  const __m512 vmin2 = _mm512_set1_ps(m2);
143 
144  const uint8_t *qs = &block->qs[iter * 32];
145  const float *xp = &x[b * QK_K + iter * 64];
146 
147  /* Process first 32 weights (low nibbles) */
148  /* Load 16 bytes at a time */
149  for (int chunk = 0; chunk < 2; chunk++) {
150  __m128i packed = _mm_loadu_si128((const __m128i *)&qs[chunk * 16]);
151  __m512i bytes = _mm512_cvtepu8_epi32(packed);
152  __m512i lo = _mm512_and_epi32(bytes, mask_lo);
 153  /* w = d * q - m */
 154  __m512 w = _mm512_sub_ps(
 155  _mm512_mul_ps(_mm512_cvtepi32_ps(lo), vscale1), vmin1);
156  __m512 x_vec = _mm512_loadu_ps(&xp[chunk * 16]);
157  acc = _mm512_fmadd_ps(w, x_vec, acc);
158  }
159 
160  /* Process next 32 weights (high nibbles) */
161  for (int chunk = 0; chunk < 2; chunk++) {
162  __m128i packed = _mm_loadu_si128((const __m128i *)&qs[chunk * 16]);
163  __m512i bytes = _mm512_cvtepu8_epi32(packed);
164  __m512i hi = _mm512_srli_epi32(bytes, 4);
 165  /* w = d * q - m */
 166  __m512 w = _mm512_sub_ps(
 167  _mm512_mul_ps(_mm512_cvtepi32_ps(hi), vscale2), vmin2);
168  __m512 x_vec = _mm512_loadu_ps(&xp[32 + chunk * 16]);
169  acc = _mm512_fmadd_ps(w, x_vec, acc);
170  }
171  }
172  }
173 
174  /* Horizontal sum */
175  y[row] = _mm512_reduce_add_ps(acc);
176  }
177 }
178 #endif /* __AVX512F__ */
179 
180 /* ============================================================================
181  * AVX Implementation (256-bit, works on Sandy Bridge and later)
182  *
183  * This is critical for CPUs that have AVX but not AVX-512.
184  * Processes 8 floats per iteration using 256-bit registers.
185  * NOTE: Uses separate mul+add (no FMA) for Ivy Bridge compatibility.
186  * ============================================================================ */
187 
188 #if defined(__AVX__) && !defined(__AVX512F__)
189 /**
190  * @brief Matrix-vector multiply with Q4_K weights (AVX optimized)
191  *
192  * Processes 8 floats at a time. No FMA required (works on Ivy Bridge).
193  * About 4-8x faster than scalar reference.
194  */
195 void gemv_q4_k_avx(float *y,
196  const void *W,
197  const float *x,
198  int M, int K)
199 {
200  const block_q4_K *blocks = (const block_q4_K *)W;
201  const int blocks_per_row = K / QK_K; /* QK_K = 256 */
202 
203  for (int row = 0; row < M; row++) {
 204  /* Accumulators for instruction-level parallelism (only acc0/acc1 are live in this version; acc2/acc3 stay zero) */
205  __m256 acc0 = _mm256_setzero_ps();
206  __m256 acc1 = _mm256_setzero_ps();
207  __m256 acc2 = _mm256_setzero_ps();
208  __m256 acc3 = _mm256_setzero_ps();
209 
210  for (int b = 0; b < blocks_per_row; b++) {
211  const block_q4_K *block = &blocks[row * blocks_per_row + b];
212  const float d = GGML_FP16_TO_FP32(block->d);
213  const float dmin = GGML_FP16_TO_FP32(block->dmin);
214 
215  /* Unpack sub-block scales */
216  uint8_t sc[8], m_arr[8];
217  unpack_q4_k_scales(block->scales, sc, m_arr);
218 
219  /* Process 256 weights in 4 iterations of 64 weights each */
220  for (int iter = 0; iter < 4; iter++) {
221  const float d1 = d * (float)sc[2*iter];
222  const float m1 = dmin * (float)m_arr[2*iter];
223  const float d2 = d * (float)sc[2*iter + 1];
224  const float m2 = dmin * (float)m_arr[2*iter + 1];
225  const uint8_t *qs = &block->qs[iter * 32];
226  const float *xp = &x[b * QK_K + iter * 64];
227 
228  /* Broadcast scale and min values */
229  __m256 vd1 = _mm256_set1_ps(d1);
230  __m256 vm1 = _mm256_set1_ps(m1);
231  __m256 vd2 = _mm256_set1_ps(d2);
232  __m256 vm2 = _mm256_set1_ps(m2);
233 
234  /* Process first 32 weights (low nibbles) in 4 groups of 8 */
235  for (int g = 0; g < 4; g++) {
236  /* Dequantize 8 weights from low nibbles */
237  float dq[8];
238  for (int i = 0; i < 8; i++) {
239  dq[i] = d1 * (float)(qs[g*8 + i] & 0x0F) - m1;
240  }
241  __m256 vw = _mm256_loadu_ps(dq);
242  __m256 vx = _mm256_loadu_ps(&xp[g*8]);
243 
244  /* acc0 += vw * vx (using mul+add, no FMA needed) */
245  __m256 prod = _mm256_mul_ps(vw, vx);
246  acc0 = _mm256_add_ps(acc0, prod);
247  }
248 
249  /* Process next 32 weights (high nibbles) in 4 groups of 8 */
250  for (int g = 0; g < 4; g++) {
251  /* Dequantize 8 weights from high nibbles */
252  float dq[8];
253  for (int i = 0; i < 8; i++) {
254  dq[i] = d2 * (float)(qs[g*8 + i] >> 4) - m2;
255  }
256  __m256 vw = _mm256_loadu_ps(dq);
257  __m256 vx = _mm256_loadu_ps(&xp[32 + g*8]);
258 
259  __m256 prod = _mm256_mul_ps(vw, vx);
260  acc1 = _mm256_add_ps(acc1, prod);
261  }
262  }
263  }
264 
265  /* Combine accumulators */
266  __m256 sum01 = _mm256_add_ps(acc0, acc1);
267  __m256 sum23 = _mm256_add_ps(acc2, acc3);
268  __m256 sum = _mm256_add_ps(sum01, sum23);
269 
270  /* Horizontal sum of 8 floats */
271  __m128 hi = _mm256_extractf128_ps(sum, 1);
272  __m128 lo = _mm256_castps256_ps128(sum);
273  __m128 sum128 = _mm_add_ps(hi, lo);
274  sum128 = _mm_hadd_ps(sum128, sum128);
275  sum128 = _mm_hadd_ps(sum128, sum128);
276 
277  y[row] = _mm_cvtss_f32(sum128);
278  }
279 }
280 #endif /* __AVX__ && !__AVX512F__ */
281 
282 /**
283  * @brief Auto-dispatch GEMV based on available SIMD
284  */
285 void gemv_q4_k(float *y,
286  const void *W,
287  const float *x,
288  int M, int K)
289 {
290 #ifdef __AVX512F__
291  gemv_q4_k_avx512(y, W, x, M, K);
292 #elif defined(__AVX__)
293  gemv_q4_k_avx(y, W, x, M, K);
294 #else
295  gemv_q4_k_ref(y, W, x, M, K);
296 #endif
297 }
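/*
 * Parity-check sketch (hypothetical test helper, not part of the engine API):
 * compares the dispatched kernel against the scalar reference on
 * caller-provided buffers, in the spirit of `make llamacpp-parity-full`.
 *
 *   static float gemv_q4_k_max_abs_diff(const void *W, const float *x,
 *                                       float *y_ref, float *y_opt,
 *                                       int M, int K)
 *   {
 *       gemv_q4_k_ref(y_ref, W, x, M, K);
 *       gemv_q4_k(y_opt, W, x, M, K);
 *       float worst = 0.0f;
 *       for (int i = 0; i < M; i++) {
 *           float diff = y_ref[i] - y_opt[i];
 *           if (diff < 0.0f) diff = -diff;
 *           if (diff > worst) worst = diff;
 *       }
 *       return worst;   // expect only a small FP32 rounding difference
 *   }
 */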
298 
299 /* ============================================================================
300  * GEMM: Y = W @ X (W is Q4_K, X and Y are FP32)
301  *
302  * For prefill (batch > 1), we can amortize weight loading across batch.
303  * More compute-bound than GEMV.
304  * ============================================================================ */
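/*
 * Roughly why batching helps (illustrative arithmetic): one 144-byte Q4_K
 * block holds 256 weights, so a batch-1 pass performs 256 FMAs ≈ 512 flops
 * per block, about 3.6 flops per byte of weight traffic. With a batch of N
 * columns the same block is reused N times (≈ 3.6 * N flops/byte), so the
 * kernel shifts from memory-bound toward compute-bound as N grows; the exact
 * crossover depends on the machine.
 */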
305 
306 /**
307  * @brief Matrix-matrix multiply with Q4_K weights (scalar reference)
308  *
 309  * @param Y Output matrix [M x N] (column-major: column n, of length M, starts at Y + n*M)
 310  * @param W Weight matrix in Q4_K format [M x K]
 311  * @param X Input matrix [K x N] (column-major for cache efficiency: column n, of length K, starts at X + n*K)
312  * @param M Number of output rows
313  * @param N Batch size (number of columns)
314  * @param K Hidden dimension
315  */
316 void gemm_q4_k_ref(float *Y,
317  const void *W,
318  const float *X,
319  int M, int N, int K)
320 {
321  /* For each column in batch, use the dispatching gemv_q4_k
322  * which automatically selects AVX/AVX-512/scalar based on CPU */
323  for (int n = 0; n < N; n++) {
324  gemv_q4_k(&Y[n * M], W, &X[n * K], M, K);
325  }
326 }
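/*
 * Layout sketch for the reference GEMM (illustrative shapes): each batch
 * element owns one contiguous column, as described in the parameter docs.
 *
 *   // 4 prompt tokens through an [M=11008 x K=4096] Q4_K projection
 *   gemm_q4_k_ref(Y, W, X, 11008, 4, 4096);
 *   // Y[n*11008 + row] now holds output channel `row` for token n.
 */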
327 
328 #ifdef __AVX512F__
329 /**
330  * @brief Matrix-matrix multiply with Q4_K weights (AVX-512)
331  *
332  * Processes multiple batch elements to improve weight reuse.
 *
 * @warning This path still uses the w = d*(q-8) + m formula and an even/odd
 * sub-block ordering that do not match the llama.cpp Q4_K layout used by the
 * reference kernels; gemm_q4_k() therefore dispatches to gemm_q4_k_ref()
 * until this is fixed (see the TODO there).
333  */
334 void gemm_q4_k_avx512(float *Y,
335  const void *W,
336  const float *X,
337  int M, int N, int K)
338 {
339  const block_q4_K *blocks = (const block_q4_K *)W;
340  const int blocks_per_row = K / QK_K;
341 
342  /* Process 4 batch elements at a time for better register utilization */
343  const int N4 = N / 4 * 4;
344 
345  for (int row = 0; row < M; row++) {
346  /* Batch of 4 */
347  for (int n = 0; n < N4; n += 4) {
348  __m512 acc0 = _mm512_setzero_ps();
349  __m512 acc1 = _mm512_setzero_ps();
350  __m512 acc2 = _mm512_setzero_ps();
351  __m512 acc3 = _mm512_setzero_ps();
352 
353  for (int b = 0; b < blocks_per_row; b++) {
354  const block_q4_K *block = &blocks[row * blocks_per_row + b];
355  const float d = GGML_FP16_TO_FP32(block->d);
356  const float dmin = GGML_FP16_TO_FP32(block->dmin);
357 
358  uint8_t sc[8], m_arr[8];
359  unpack_q4_k_scales(block->scales, sc, m_arr);
360 
361  for (int sub = 0; sub < 8; sub++) {
362  const float scale = d * (float)sc[sub];
363  const float min_val = dmin * (float)m_arr[sub];
364  const __m512 vscale = _mm512_set1_ps(scale);
365  const __m512 vmin = _mm512_set1_ps(min_val);
366  const __m512i offset = _mm512_set1_epi32(8);
367  const __m512i mask_lo = _mm512_set1_epi32(0x0F);
368 
369  const uint8_t *qs = &block->qs[sub * 16];
370  const int x_offset = b * QK_K + sub * 32;
371 
372  /* Dequantize weights (same for all batch elements) */
373  __m128i packed = _mm_loadu_si128((const __m128i *)qs);
374  __m512i bytes = _mm512_cvtepu8_epi32(packed);
375 
376  __m512i lo = _mm512_sub_epi32(_mm512_and_epi32(bytes, mask_lo), offset);
377  __m512i hi = _mm512_sub_epi32(_mm512_srli_epi32(bytes, 4), offset);
378 
379  __m512 w_lo = _mm512_fmadd_ps(_mm512_cvtepi32_ps(lo), vscale, vmin);
380  __m512 w_hi = _mm512_fmadd_ps(_mm512_cvtepi32_ps(hi), vscale, vmin);
381 
382  /* Load inputs for 4 batch elements and accumulate */
383  /* (simplified - full impl would handle interleaving) */
384  for (int bn = 0; bn < 4; bn++) {
385  const float *xp = &X[(n + bn) * K + x_offset];
386 
387  __m512 x_even = _mm512_set_ps(
388  xp[30], xp[28], xp[26], xp[24], xp[22], xp[20], xp[18], xp[16],
389  xp[14], xp[12], xp[10], xp[8], xp[6], xp[4], xp[2], xp[0]);
390  __m512 x_odd = _mm512_set_ps(
391  xp[31], xp[29], xp[27], xp[25], xp[23], xp[21], xp[19], xp[17],
392  xp[15], xp[13], xp[11], xp[9], xp[7], xp[5], xp[3], xp[1]);
393 
394  __m512 *acc = (bn == 0) ? &acc0 : (bn == 1) ? &acc1 :
395  (bn == 2) ? &acc2 : &acc3;
396  *acc = _mm512_fmadd_ps(w_lo, x_even, *acc);
397  *acc = _mm512_fmadd_ps(w_hi, x_odd, *acc);
398  }
399  }
400  }
401 
402  Y[(n + 0) * M + row] = _mm512_reduce_add_ps(acc0);
403  Y[(n + 1) * M + row] = _mm512_reduce_add_ps(acc1);
404  Y[(n + 2) * M + row] = _mm512_reduce_add_ps(acc2);
405  Y[(n + 3) * M + row] = _mm512_reduce_add_ps(acc3);
406  }
407 
408  /* Remainder */
409  for (int n = N4; n < N; n++) {
410  __m512 acc = _mm512_setzero_ps();
411 
412  for (int b = 0; b < blocks_per_row; b++) {
413  const block_q4_K *block = &blocks[row * blocks_per_row + b];
414  const float d = GGML_FP16_TO_FP32(block->d);
415  const float dmin = GGML_FP16_TO_FP32(block->dmin);
416 
417  uint8_t sc[8], m_arr[8];
418  unpack_q4_k_scales(block->scales, sc, m_arr);
419 
420  for (int sub = 0; sub < 8; sub++) {
421  const float scale = d * (float)sc[sub];
422  const float min_val = dmin * (float)m_arr[sub];
423  const __m512 vscale = _mm512_set1_ps(scale);
424  const __m512 vmin = _mm512_set1_ps(min_val);
425  const __m512i offset = _mm512_set1_epi32(8);
426  const __m512i mask_lo = _mm512_set1_epi32(0x0F);
427 
428  const uint8_t *qs = &block->qs[sub * 16];
429  const float *xp = &X[n * K + b * QK_K + sub * 32];
430 
431  __m128i packed = _mm_loadu_si128((const __m128i *)qs);
432  __m512i bytes = _mm512_cvtepu8_epi32(packed);
433 
434  __m512i lo = _mm512_sub_epi32(_mm512_and_epi32(bytes, mask_lo), offset);
435  __m512i hi = _mm512_sub_epi32(_mm512_srli_epi32(bytes, 4), offset);
436 
437  __m512 w_lo = _mm512_fmadd_ps(_mm512_cvtepi32_ps(lo), vscale, vmin);
438  __m512 w_hi = _mm512_fmadd_ps(_mm512_cvtepi32_ps(hi), vscale, vmin);
439 
440  __m512 x_even = _mm512_set_ps(
441  xp[30], xp[28], xp[26], xp[24], xp[22], xp[20], xp[18], xp[16],
442  xp[14], xp[12], xp[10], xp[8], xp[6], xp[4], xp[2], xp[0]);
443  __m512 x_odd = _mm512_set_ps(
444  xp[31], xp[29], xp[27], xp[25], xp[23], xp[21], xp[19], xp[17],
445  xp[15], xp[13], xp[11], xp[9], xp[7], xp[5], xp[3], xp[1]);
446 
447  acc = _mm512_fmadd_ps(w_lo, x_even, acc);
448  acc = _mm512_fmadd_ps(w_hi, x_odd, acc);
449  }
450  }
451 
452  Y[n * M + row] = _mm512_reduce_add_ps(acc);
453  }
454  }
455 }
456 #endif /* __AVX512F__ */
457 
458 /**
459  * @brief Auto-dispatch GEMM based on available SIMD (currently always the scalar reference; see TODO below)
460  */
461 void gemm_q4_k(float *Y,
462  const void *W,
463  const float *X,
464  int M, int N, int K)
465 {
466  /* Use reference implementation for correctness
467  * TODO: Fix AVX-512 version to match llama.cpp layout */
468  gemm_q4_k_ref(Y, W, X, M, N, K);
469 }
470 
471 /* ============================================================================
472  * Dot Product: Single row dot product with Q4_K weights
473  * Used internally and for testing.
474  * ============================================================================ */
475 
476 /**
477  * @brief Compute dot product of Q4_K row with FP32 vector
478  *
479  * @param w_q4k Q4_K blocks for one row
480  * @param x FP32 input vector
481  * @param K Vector length (must be multiple of 256)
482  * @return Dot product result
483  */
484 float dot_q4_k(const void *w_q4k, const float *x, int K)
485 {
486  float result;
487  gemv_q4_k(&result, w_q4k, x, 1, K);
488  return result;
489 }
490 
491 /* ============================================================================
492  * Backward Pass: Gradient w.r.t. Input
493  *
494  * Given: dL/dY (gradient of loss w.r.t. output)
495  * Compute: dL/dX = W^T @ dL/dY
496  *
497  * For quantized weights, we dequantize on-the-fly during backprop.
498  * Weight gradients are NOT computed (weights are frozen).
499  * For fine-tuning, use LoRA adapters which maintain FP32 gradients separately.
500  * ============================================================================ */
501 
502 /**
503  * @brief Backward pass: compute input gradient (scalar reference)
504  *
505  * @param dX Output gradient w.r.t. input [K]
506  * @param W Weight matrix in Q4_K format [M x K]
507  * @param dY Gradient w.r.t. output [M]
508  * @param M Number of output rows
509  * @param K Number of columns (input dimension)
510  */
511 void gemv_q4_k_backward_ref(float *dX,
512  const void *W,
513  const float *dY,
514  int M, int K)
515 {
516  const block_q4_K *blocks = (const block_q4_K *)W;
517  const int blocks_per_row = K / QK_K;
518 
519  /* Zero output gradient */
520  memset(dX, 0, K * sizeof(float));
521 
522  /* Accumulate: dX += W^T @ dY
523  * Uses llama.cpp layout: 4 iterations of 64 weights each */
524  for (int row = 0; row < M; row++) {
525  const float dy = dY[row];
526 
527  for (int b = 0; b < blocks_per_row; b++) {
528  const block_q4_K *block = &blocks[row * blocks_per_row + b];
529  const float d = CK_FP16_TO_FP32(block->d);
530  const float dmin = CK_FP16_TO_FP32(block->dmin);
531 
532  uint8_t sc[8], m[8];
533  unpack_q4_k_scales(block->scales, sc, m);
534 
535  /* llama.cpp layout: 4 iterations of 64 weights each */
536  for (int iter = 0; iter < 4; iter++) {
537  const float d1 = d * (float)sc[2 * iter];
538  const float m1 = dmin * (float)m[2 * iter];
539  const float d2 = d * (float)sc[2 * iter + 1];
540  const float m2 = dmin * (float)m[2 * iter + 1];
541 
542  const uint8_t *qs = &block->qs[iter * 32];
543  float *dxp = &dX[b * QK_K + iter * 64];
544 
545  /* First 32 weights: low nibbles */
546  for (int l = 0; l < 32; l++) {
547  const int q = (qs[l] & 0x0F);
548  const float w = d1 * (float)q - m1;
549  dxp[l] += w * dy;
550  }
551 
552  /* Next 32 weights: high nibbles */
553  for (int l = 0; l < 32; l++) {
554  const int q = (qs[l] >> 4);
555  const float w = d2 * (float)q - m2;
556  dxp[32 + l] += w * dy;
557  }
558  }
559  }
560  }
561 }
562 
563 #ifdef __AVX512F__
564 /**
565  * @brief Backward pass with AVX-512
566  *
567  * Uses llama.cpp layout: 4 iterations of 64 weights each
568  */
569 void gemv_q4_k_backward_avx512(float *dX,
570  const void *W,
571  const float *dY,
572  int M, int K)
573 {
574  const block_q4_K *blocks = (const block_q4_K *)W;
575  const int blocks_per_row = K / QK_K;
576  const __m512i mask_lo = _mm512_set1_epi32(0x0F);
577 
578  /* Zero output */
579  memset(dX, 0, K * sizeof(float));
580 
581  for (int row = 0; row < M; row++) {
582  const __m512 vdy = _mm512_set1_ps(dY[row]);
583 
584  for (int b = 0; b < blocks_per_row; b++) {
585  const block_q4_K *block = &blocks[row * blocks_per_row + b];
586  const float d = CK_FP16_TO_FP32(block->d);
587  const float dmin = CK_FP16_TO_FP32(block->dmin);
588 
589  uint8_t sc[8], m_arr[8];
590  unpack_q4_k_scales(block->scales, sc, m_arr);
591 
592  /* llama.cpp layout: 4 iterations of 64 weights each */
593  for (int iter = 0; iter < 4; iter++) {
594  const float d1 = d * (float)sc[2 * iter];
595  const float m1 = dmin * (float)m_arr[2 * iter];
596  const float d2 = d * (float)sc[2 * iter + 1];
597  const float m2 = dmin * (float)m_arr[2 * iter + 1];
598 
599  const __m512 vd1 = _mm512_set1_ps(d1);
600  const __m512 vm1 = _mm512_set1_ps(m1);
601  const __m512 vd2 = _mm512_set1_ps(d2);
602  const __m512 vm2 = _mm512_set1_ps(m2);
603 
604  const uint8_t *qs = &block->qs[iter * 32];
605  float *dxp = &dX[b * QK_K + iter * 64];
606 
607  /* Process first 32 weights (low nibbles) */
608  for (int chunk = 0; chunk < 2; chunk++) {
609  __m128i packed = _mm_loadu_si128((const __m128i *)&qs[chunk * 16]);
610  __m512i bytes = _mm512_cvtepu8_epi32(packed);
611  __m512i lo = _mm512_and_epi32(bytes, mask_lo);
 612  /* w = d1 * q - m1 */
 613  __m512 w = _mm512_sub_ps(
 614  _mm512_mul_ps(_mm512_cvtepi32_ps(lo), vd1), vm1);
615  __m512 grad = _mm512_mul_ps(w, vdy);
616  __m512 existing = _mm512_loadu_ps(&dxp[chunk * 16]);
617  _mm512_storeu_ps(&dxp[chunk * 16], _mm512_add_ps(existing, grad));
618  }
619 
620  /* Process next 32 weights (high nibbles) */
621  for (int chunk = 0; chunk < 2; chunk++) {
622  __m128i packed = _mm_loadu_si128((const __m128i *)&qs[chunk * 16]);
623  __m512i bytes = _mm512_cvtepu8_epi32(packed);
624  __m512i hi = _mm512_srli_epi32(bytes, 4);
 625  /* w = d2 * q - m2 */
 626  __m512 w = _mm512_sub_ps(
 627  _mm512_mul_ps(_mm512_cvtepi32_ps(hi), vd2), vm2);
628  __m512 grad = _mm512_mul_ps(w, vdy);
629  __m512 existing = _mm512_loadu_ps(&dxp[32 + chunk * 16]);
630  _mm512_storeu_ps(&dxp[32 + chunk * 16], _mm512_add_ps(existing, grad));
631  }
632  }
633  }
634  }
635 }
636 #endif
637 
638 /**
639  * @brief Auto-dispatch backward
640  */
641 void gemv_q4_k_backward(float *dX,
642  const void *W,
643  const float *dY,
644  int M, int K)
645 {
646 #ifdef __AVX512F__
647  gemv_q4_k_backward_avx512(dX, W, dY, M, K);
648 #else
649  gemv_q4_k_backward_ref(dX, W, dY, M, K);
650 #endif
651 }
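/*
 * Gradient-check sketch (hypothetical test helper, not part of the engine
 * API): with a scalar loss L(x) = dY . (W @ x), the analytic gradient is
 * dL/dx = W^T @ dY, which is what gemv_q4_k_backward computes. Assuming dX
 * was already filled by gemv_q4_k_backward(dX, W, dY, M, K), a central finite
 * difference in one coordinate should agree with dX[k]:
 *
 *   static float backward_fd_error(const void *W, float *x, const float *dY,
 *                                  const float *dX, float *y_tmp,
 *                                  int M, int K, int k, float eps)
 *   {
 *       const float saved = x[k];
 *       float lp = 0.0f, lm = 0.0f;
 *       x[k] = saved + eps;
 *       gemv_q4_k(y_tmp, W, x, M, K);
 *       for (int i = 0; i < M; i++) lp += dY[i] * y_tmp[i];
 *       x[k] = saved - eps;
 *       gemv_q4_k(y_tmp, W, x, M, K);
 *       for (int i = 0; i < M; i++) lm += dY[i] * y_tmp[i];
 *       x[k] = saved;
 *       return (lp - lm) / (2.0f * eps) - dX[k];   // should be ~0
 *   }
 */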
652 
653 /**
654  * @brief Batched backward pass
655  */
656 void gemm_q4_k_backward(float *dX,
657  const void *W,
658  const float *dY,
659  int M, int N, int K)
660 {
661  for (int n = 0; n < N; n++) {
662  gemv_q4_k_backward(&dX[n * K], W, &dY[n * M], M, K);
663  }
664 }
665 
666 /* ============================================================================
667  * Engine-compatible wrapper: GEMM_NT with Q4_K weights
668  *
669  * The core q4_k kernels in this file use the convention:
670  * - W: [M_out x K] (quantized row-major)
671  * - X: [N_batch x K] (fp32)
672  * - Y: [N_batch x M_out] (fp32)
673  *
674  * The C-Kernel-Engine convention for NN weights uses:
675  * - A: [M_tokens x K] (fp32)
676  * - B: [N_out x K] (quantized row-major, transposed layout)
677  * - C: [M_tokens x N_out] (fp32)
678  *
679  * This wrapper swaps (M_out, N_batch) to match the engine layout and applies
680  * an optional bias.
681  * ============================================================================ */
682 
683 void gemm_nt_q4_k(const float *A,
684  const void *B,
685  const float *bias,
686  float *C,
687  int M, int N, int K)
688 {
689  if (!A || !B || !C) {
690  return;
691  }
692  if (M <= 0 || N <= 0 || K <= 0) {
693  return;
694  }
695 
696  /* gemm_q4_k produces Y as [batch x M_out]. Here:
697  * batch = M (tokens)
698  * M_out = N (output channels) */
699  gemm_q4_k(C, B, A, /*M_out=*/N, /*N_batch=*/M, K);
700 
701  if (!bias) {
702  return;
703  }
704 
705  for (int i = 0; i < M; ++i) {
706  float *row = C + (size_t)i * (size_t)N;
707  for (int j = 0; j < N; ++j) {
708  row[j] += bias[j];
709  }
710  }
711 }
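/*
 * Engine-side calling sketch (illustrative shapes): pushing a 2-token
 * activation tile through an 11008-output projection. On return,
 * C[i*N + j] = dot(A row i, B row j) + bias[j].
 *
 *   gemm_nt_q4_k(A, B_q4k, bias, C,
 *                2,       // M: tokens in the tile
 *                11008,   // N: output channels (Q4_K rows in B)
 *                4096);   // K: hidden dimension, multiple of 256
 */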