← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q4k_q8k.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_q4k_q8k.c
3  * @brief Q4_K (weights) x Q8_K (activations) kernels for inference
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Implements decode-style matvec/matmul where weights are Q4_K and the
15  * activations are quantized on-the-fly to Q8_K. This is inference-only;
16  * no backward pass is provided here.
17  */
18 
19 #include <assert.h>
20 #include <math.h>
21 #include <string.h>
22 
23 #include "ckernel_quant.h"
24 
25 void gemv_q4_k_q8_k_avx2(float *y,
26  const void *W,
27  const void *x_q8,
28  int M, int K);
29 
30 void gemv_q4_k_q8_k_vnni(float *y,
31  const void *W,
32  const void *x_q8,
33  int M, int K);
34 
35 void gemv_q4_k_q8_k_avx(float *y,
36  const void *W,
37  const void *x_q8,
38  int M, int K);
39 
40 void gemv_q4_k_q8_k_sse(float *y,
41  const void *W,
42  const void *x_q8,
43  int M, int K);
44 
static inline int ck_nearest_int(float fval) {
    /* Round-to-nearest via the float mantissa trick (from llama.cpp):
     * adding 1.5 * 2^23 pushes the value into a range where the low
     * mantissa bits contain the rounded integer plus a fixed bias.
     * Fast, branch-free, and bit-deterministic across platforms. */
    const float shifted = fval + 12582912.f;
    int bits = 0;
    memcpy(&bits, &shifted, sizeof bits); /* type-pun without aliasing UB */
    return (bits & 0x007fffff) - 0x00400000;
}
52 
53 void quantize_row_q8_k_ref(const float *x, void *vy, int k) {
54  if (!x || !vy || k <= 0) {
55  return;
56  }
57  assert(k % QK_K == 0);
58  const int nb = k / QK_K;
59  block_q8_K *y = (block_q8_K *)vy;
60 
61  for (int i = 0; i < nb; ++i) {
62  float max = 0.0f;
63  float amax = 0.0f;
64  for (int j = 0; j < QK_K; ++j) {
65  float ax = fabsf(x[j]);
66  if (ax > amax) {
67  amax = ax;
68  max = x[j];
69  }
70  }
71  if (!amax) {
72  y[i].d = 0.0f;
73  memset(y[i].qs, 0, sizeof(y[i].qs));
74  memset(y[i].bsums, 0, sizeof(y[i].bsums));
75  x += QK_K;
76  continue;
77  }
78 
79  const float iscale = -127.0f / max;
80  for (int j = 0; j < QK_K; ++j) {
81  int v = ck_nearest_int(iscale * x[j]);
82  if (v > 127) {
83  v = 127;
84  }
85  if (v < -128) {
86  v = -128;
87  }
88  y[i].qs[j] = (int8_t)v;
89  }
90 
91  for (int j = 0; j < QK_K / 16; ++j) {
92  int sum = 0;
93  const int8_t *qs = &y[i].qs[j * 16];
94  for (int ii = 0; ii < 16; ++ii) {
95  sum += qs[ii];
96  }
97  y[i].bsums[j] = (int16_t)sum;
98  }
99 
100  y[i].d = 1.0f / iscale;
101  x += QK_K;
102  }
103 }
104 
105 void quantize_row_q8_k_sse(const float *x, void *vy, int k);
106 
/* Dispatch wrapper: picks the SSE4.1 quantizer when the compiler targets
 * SSE4.1, otherwise the portable reference implementation. Both fill the
 * same Q8_K layout (scale d, int8 quants, per-16 bsums), so callers see
 * identical results either way. */
void quantize_row_q8_k(const float *x, void *vy, int k) {
#if defined(__SSE4_1__)
    quantize_row_q8_k_sse(x, vy, k);
#else
    quantize_row_q8_k_ref(x, vy, k);
#endif
}
114 
/**
 * Reference dot product of one Q4_K weight row against a Q8_K vector.
 *
 * @param w  Q4_K blocks for the row (k/QK_K blocks)
 * @param x  Q8_K blocks for the activation vector (k/QK_K blocks)
 * @param k  logical element count (must be a multiple of QK_K; callers
 *           validate this — there is no check here)
 * @return   the float dot product
 *
 * Per super-block the result is
 *   sum_j d * sc[j] * sum(q4*q8)  -  dmin * m[j] * sum(q8)
 * where the second term undoes the unsigned-nibble offset using the
 * precomputed bsums from quantize_row_q8_k.
 */
static float dot_q4_k_q8_k_ref(const block_q4_K *w,
                               const block_q8_K *x,
                               int k)
{
    const int nb = k / QK_K;
    float sumf = 0.0f;

    for (int i = 0; i < nb; ++i) {
        /* Decode the packed 6-bit sub-block scales/mins into bytes. */
        uint8_t sc[8], m_val[8];
        unpack_q4_k_scales(w[i].scales, sc, m_val);

        /* Combined scales: weight fp16 scale/min times activation scale. */
        const float d = CK_FP16_TO_FP32(w[i].d) * x[i].d;
        const float dmin = CK_FP16_TO_FP32(w[i].dmin) * x[i].d;

        /* Q4_K layout: process 64 elements at a time
         * - Low nibbles of qs[0..31]  → elements 0..31  → uses sc[0], m[0]
         * - High nibbles of qs[0..31] → elements 32..63 → uses sc[1], m[1]
         * - Low nibbles of qs[32..63] → elements 64..95 → uses sc[2], m[2]
         * - etc.
         */
        int is = 0;
        int q_offset = 0;

        for (int j = 0; j < QK_K; j += 64) {
            const uint8_t *qs = &w[i].qs[q_offset];
            const int8_t *q8_lo = &x[i].qs[j];      /* Elements j to j+31 */
            const int8_t *q8_hi = &x[i].qs[j + 32]; /* Elements j+32 to j+63 */

            /* Sum for low nibbles (elements j to j+31) */
            int32_t sum_q4q8_lo = 0;
            for (int l = 0; l < 32; ++l) {
                int q4_val = qs[l] & 0x0F;
                sum_q4q8_lo += q4_val * q8_lo[l];
            }

            /* Sum for high nibbles (elements j+32 to j+63) */
            int32_t sum_q4q8_hi = 0;
            for (int l = 0; l < 32; ++l) {
                int q4_val = qs[l] >> 4;
                sum_q4q8_hi += q4_val * q8_hi[l];
            }

            /* bsums: each bsum covers 16 elements, so a 32-element
             * half-chunk needs two consecutive entries. */
            int32_t bsum_lo = (int32_t)x[i].bsums[j / 16] +
                              (int32_t)x[i].bsums[j / 16 + 1];
            int32_t bsum_hi = (int32_t)x[i].bsums[(j + 32) / 16] +
                              (int32_t)x[i].bsums[(j + 32) / 16 + 1];

            /* Accumulate: d * sc * sum(q4*q8) - dmin * m * sum(q8) */
            sumf += d * (float)sc[is] * (float)sum_q4q8_lo;
            sumf -= dmin * (float)m_val[is] * (float)bsum_lo;
            sumf += d * (float)sc[is + 1] * (float)sum_q4q8_hi;
            sumf -= dmin * (float)m_val[is + 1] * (float)bsum_hi;

            q_offset += 32;
            is += 2;
        }
    }

    return sumf;
}
176 
177 void gemv_q4_k_q8_k_ref(float *y,
178  const void *W,
179  const void *x_q8,
180  int M, int K)
181 {
182  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
183  return;
184  }
185 
186  const block_q4_K *blocks = (const block_q4_K *)W;
187  const block_q8_K *x = (const block_q8_K *)x_q8;
188  const int blocks_per_row = K / QK_K;
189 
190  for (int row = 0; row < M; ++row) {
191  const block_q4_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
192  y[row] = dot_q4_k_q8_k_ref(w_row, x, K);
193  }
194 }
195 
196 /* ============================================================================
197  * PARALLEL VERSIONS (for parallel orchestration)
198  *
199  * These receive ith (thread index) and nth (total threads) from orchestration.
200  * OpenMP lives in orchestration layer, NOT here.
201  *
202  * Naming: *_parallel = receives ith/nth, processes only its portion
203  * *_ref/_avx = single-threaded, processes all rows
204  * ============================================================================ */
205 
207  const void *W,
208  const void *x_q8,
209  int M, int K,
210  int ith, int nth)
211 {
212  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
213  return;
214  }
215  if (ith < 0 || nth <= 0 || ith >= nth) {
216  return;
217  }
218 
219  /* Compute row range for this thread */
220  const int dr = (M + nth - 1) / nth;
221  const int r0 = dr * ith;
222  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
223 
224  if (r0 >= M) {
225  return; /* This thread has no work */
226  }
227 
228  const block_q4_K *blocks = (const block_q4_K *)W;
229  const block_q8_K *x = (const block_q8_K *)x_q8;
230  const int blocks_per_row = K / QK_K;
231 
232  /* Only process rows [r0, r1) */
233  for (int row = r0; row < r1; ++row) {
234  const block_q4_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
235  y[row] = dot_q4_k_q8_k_ref(w_row, x, K);
236  }
237 }
238 
/* Top-level matvec dispatcher. Selects the best SIMD kernel the compiler
 * target supports, in strict preference order: VNNI > AVX2 > AVX > SSE4.1
 * > scalar reference. All variants share the same contract as
 * gemv_q4_k_q8_k_ref. */
void gemv_q4_k_q8_k(float *y,
                    const void *W,
                    const void *x_q8,
                    int M, int K)
{
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
    /* VNNI: Best for decode (single token) - INT8 dot product acceleration */
    gemv_q4_k_q8_k_vnni(y, W, x_q8, M, K);
#elif defined(__AVX2__)
    gemv_q4_k_q8_k_avx2(y, W, x_q8, M, K);
#elif defined(__AVX__)
    /* AVX version uses maddubs_epi16 (more efficient than SSE) */
    gemv_q4_k_q8_k_avx(y, W, x_q8, M, K);
#elif defined(__SSE4_1__)
    gemv_q4_k_q8_k_sse(y, W, x_q8, M, K);
#else
    gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
#endif
}
258 
259 void gemm_q4_k_q8_k_ref(float *Y,
260  const void *W,
261  const void *X_q8,
262  int M, int N, int K)
263 {
264  if (!Y || !W || !X_q8 || M <= 0 || N <= 0 || K <= 0) {
265  return;
266  }
267 
268  const block_q8_K *X = (const block_q8_K *)X_q8;
269  const int blocks_per_vec = K / QK_K;
270 
271  for (int n = 0; n < N; ++n) {
272  const block_q8_K *x_row = X + (size_t)n * (size_t)blocks_per_vec;
273  gemv_q4_k_q8_k_ref(&Y[n * M], W, x_row, M, K);
274  }
275 }
276 
277 void gemm_q4_k_q8_k(float *Y,
278  const void *W,
279  const void *X_q8,
280  int M, int N, int K)
281 {
282  if (!Y || !W || !X_q8 || M <= 0 || N <= 0 || K <= 0) {
283  return;
284  }
285 
286  const block_q8_K *X = (const block_q8_K *)X_q8;
287  const int blocks_per_vec = K / QK_K;
288 
289  for (int n = 0; n < N; ++n) {
290  const block_q8_K *x_row = X + (size_t)n * (size_t)blocks_per_vec;
291  gemv_q4_k_q8_k(&Y[n * M], W, x_row, M, K);
292  }
293 }
294 
/**
 * C = A * B^T + bias, in the "NT" orientation used by linear layers:
 * A is the Q8_K activation batch (M vectors of length K), B holds the
 * Q4_K weight rows (N rows of length K), and C is M x N row-major.
 *
 * @param A_q8  M * (K/QK_K) block_q8_K activation vectors
 * @param B     N * (K/QK_K) block_q4_K weight rows
 * @param bias  optional length-N bias, added to every output row; may be NULL
 * @param C     output, M x N row-major
 * @param M     batch size
 * @param N     output features
 * @param K     reduction length (multiple of QK_K)
 */
void gemm_nt_q4_k_q8_k(const void *A_q8,
                       const void *B,
                       const float *bias,
                       float *C,
                       int M, int N, int K)
{
    if (A_q8 == NULL || B == NULL || C == NULL) {
        return;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    /* Reuse the batched matvec with swapped roles: the weight matrix B
     * supplies N output rows, and the M activation vectors form the batch. */
    gemm_q4_k_q8_k(C, B, A_q8, /*M_out=*/N, /*N_batch=*/M, K);

    if (bias == NULL) {
        return;
    }

    /* Broadcast-add the bias across every output row. */
    for (int r = 0; r < M; ++r) {
        float *c_row = C + (size_t)r * (size_t)N;
        for (int f = 0; f < N; ++f) {
            c_row[f] += bias[f];
        }
    }
}
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
static void unpack_q4_k_scales(const uint8_t *scales, uint8_t *sc, uint8_t *m)
Unpack Q4_K sub-block scales and mins.
#define QK_K
void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_vnni(float *y, const void *W, const void *x_q8, int M, int K)
void quantize_row_q8_k(const float *x, void *vy, int k)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void gemv_q4_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
void gemm_q4_k_q8_k_ref(float *Y, const void *W, const void *X_q8, int M, int N, int K)
static int ck_nearest_int(float fval)
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemm_q4_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)
void quantize_row_q8_k_sse(const float *x, void *vy, int k)
static float dot_q4_k_q8_k_ref(const block_q4_K *w, const block_q8_K *x, int k)
void quantize_row_q8_k_ref(const float *x, void *vy, int k)
void gemv_q4_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)
#define C(color)
Definition: show_config.c:39
uint8_t qs[256/2]
int8_t qs[256]
int16_t bsums[256/16]