gemm_kernels_q8_0.c
1 /**
2  * @file gemm_kernels_q8_0.c
3  * @brief GEMM/GEMV kernels with Q8_0 quantized weights
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Q8_0 Format:
15  * - 32 weights per block
16  * - 1 FP16 scale per block
17  * - 34 bytes per 32 weights = 8.5 bits/weight
18  * - Weights stored as signed 8-bit integers
19  *
20  * Operations:
21  * Forward: Y = W @ X (W is Q8_0, X and Y are FP32)
22  * Backward: dX = W^T @ dY (gradient w.r.t. input)
23  *
24  * Note: Q8_0 is often used for activation quantization or as an
25  * intermediate format. Higher precision than Q4_0/Q4_K.
26  */
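/*
 * Illustrative sketch: the Q8_0 block described above presumably maps to a
 * struct like the following (field names follow the llama.cpp convention;
 * the actual definition lives in ckernel_quant.h and may differ):
 *
 *   typedef struct {
 *       uint16_t d;       // FP16 scale shared by the block (2 bytes)
 *       int8_t   qs[32];  // 32 signed 8-bit quantized weights (32 bytes)
 *   } block_q8_0;         // 34 bytes per 32 weights = 8.5 bits/weight
 */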
27 
28 #include <stdint.h>
29 #include <stddef.h>
30 #include <string.h>
31 #include "ckernel_quant.h"
32 #include "ck_features.h"
33 
34 /* Include SIMD headers based on available extensions */
35 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
36 #include <immintrin.h>
37 #endif
38 
39 void quantize_row_q8_k(const float *x, void *vy, int k);
40 
41 /* ============================================================================
42  * Q8_0 Quantization
43  *
44  * Quantizes FP32 values to Q8_0 format (32 elements per block).
45  * Each block has:
46  * - 1 FP16 scale (computed as max abs value / 127)
47  * - 32 int8 quantized values
48  *
49  * This matches llama.cpp's quantize_row_q8_0.
50  * ============================================================================ */
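/*
 * Worked example: for a block whose largest magnitude is amax = 2.54, the
 * scale is d = 2.54 / 127 = 0.02 and id = 127 / 2.54 = 50. The value 1.0
 * quantizes to q = round(1.0 * 50) = 50 and dequantizes back to
 * d * q = 0.02 * 50 = 1.0; the value -2.54 maps to q = -127.
 */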
51 
52 /**
53  * @brief Quantize FP32 to Q8_0 format (scalar reference)
54  *
55  * @param x Input FP32 values
56  * @param vy Output Q8_0 blocks
57  * @param k Number of elements (must be multiple of 32)
58  */
59 void quantize_row_q8_0(const float *x, void *vy, int k)
60 {
61  block_q8_0 *y = (block_q8_0 *)vy;
62  const int nb = k / QK8_0; /* QK8_0 = 32 */
63 
64 #if defined(__AVX__)
65  const __m256 sign_bit = _mm256_set1_ps(-0.0f);
66  const __m256 v_half = _mm256_set1_ps(0.5f);
67  const __m256 v_min = _mm256_set1_ps(-127.0f);
68  const __m256 v_max = _mm256_set1_ps(127.0f);
69 
70  for (int i = 0; i < nb; i++) {
71  __m256 v0 = _mm256_loadu_ps(x + 0);
72  __m256 v1 = _mm256_loadu_ps(x + 8);
73  __m256 v2 = _mm256_loadu_ps(x + 16);
74  __m256 v3 = _mm256_loadu_ps(x + 24);
75  x += QK8_0;
76 
77  __m256 max_abs = _mm256_andnot_ps(sign_bit, v0);
78  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v1));
79  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v2));
80  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v3));
81 
82  __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max_abs, 1),
83  _mm256_castps256_ps128(max_abs));
84  max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
85  max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
86  const float max_scalar = _mm_cvtss_f32(max4);
87 
88  const float d = max_scalar / 127.0f;
89  const float id = max_scalar != 0.0f ? 127.0f / max_scalar : 0.0f;
90  y[i].d = CK_FP32_TO_FP16(d);
91 
92  const __m256 mul = _mm256_set1_ps(id);
93  v0 = _mm256_mul_ps(v0, mul);
94  v1 = _mm256_mul_ps(v1, mul);
95  v2 = _mm256_mul_ps(v2, mul);
96  v3 = _mm256_mul_ps(v3, mul);
97 
98  v0 = _mm256_min_ps(_mm256_max_ps(v0, v_min), v_max);
99  v1 = _mm256_min_ps(_mm256_max_ps(v1, v_min), v_max);
100  v2 = _mm256_min_ps(_mm256_max_ps(v2, v_min), v_max);
101  v3 = _mm256_min_ps(_mm256_max_ps(v3, v_min), v_max);
102 
103  /* Round half away from zero to match the scalar path */
104  v0 = _mm256_add_ps(v0, _mm256_or_ps(_mm256_and_ps(v0, sign_bit), v_half));
105  v1 = _mm256_add_ps(v1, _mm256_or_ps(_mm256_and_ps(v1, sign_bit), v_half));
106  v2 = _mm256_add_ps(v2, _mm256_or_ps(_mm256_and_ps(v2, sign_bit), v_half));
107  v3 = _mm256_add_ps(v3, _mm256_or_ps(_mm256_and_ps(v3, sign_bit), v_half));
108 
109  __m256i i0 = _mm256_cvttps_epi32(v0);
110  __m256i i1 = _mm256_cvttps_epi32(v1);
111  __m256i i2 = _mm256_cvttps_epi32(v2);
112  __m256i i3 = _mm256_cvttps_epi32(v3);
113 
114 #if defined(__AVX2__)
115  i0 = _mm256_packs_epi32(i0, i1);
116  i2 = _mm256_packs_epi32(i2, i3);
117  i0 = _mm256_packs_epi16(i0, i2);
118 
119  const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
120  i0 = _mm256_permutevar8x32_epi32(i0, perm);
121  _mm256_storeu_si256((__m256i *)y[i].qs, i0);
122 #else
123  __m128i ni0 = _mm256_castsi256_si128(i0);
124  __m128i ni1 = _mm256_extractf128_si256(i0, 1);
125  __m128i ni2 = _mm256_castsi256_si128(i1);
126  __m128i ni3 = _mm256_extractf128_si256(i1, 1);
127  __m128i ni4 = _mm256_castsi256_si128(i2);
128  __m128i ni5 = _mm256_extractf128_si256(i2, 1);
129  __m128i ni6 = _mm256_castsi256_si128(i3);
130  __m128i ni7 = _mm256_extractf128_si256(i3, 1);
131 
132  ni0 = _mm_packs_epi32(ni0, ni1);
133  ni2 = _mm_packs_epi32(ni2, ni3);
134  ni4 = _mm_packs_epi32(ni4, ni5);
135  ni6 = _mm_packs_epi32(ni6, ni7);
136 
137  ni0 = _mm_packs_epi16(ni0, ni2);
138  ni4 = _mm_packs_epi16(ni4, ni6);
139 
140  _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
141  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
142 #endif
143  }
144 #else
145  for (int i = 0; i < nb; i++) {
146  const float *xb = x + i * QK8_0;
147 
148  /* Find max absolute value in block */
149  float amax = 0.0f;
150  for (int j = 0; j < QK8_0; j++) {
151  float av = xb[j] >= 0 ? xb[j] : -xb[j];
152  if (av > amax) amax = av;
153  }
154 
155  /* Compute scale: d = max / 127 */
156  float d = amax / 127.0f;
157  float id = d != 0.0f ? 127.0f / amax : 0.0f;
158 
159  /* Store scale as FP16 */
160  y[i].d = CK_FP32_TO_FP16(d);
161 
162  /* Quantize values */
163  for (int j = 0; j < QK8_0; j++) {
164  float v = xb[j] * id;
165  /* Round to nearest int and clamp to [-127, 127] */
166  int q = (int)(v + (v >= 0 ? 0.5f : -0.5f));
167  if (q > 127) q = 127;
168  if (q < -127) q = -127;
169  y[i].qs[j] = (int8_t)q;
170  }
171  }
172 #endif
173 }
174 
175 /**
176  * @brief Batch quantize FP32 to Q8_0 format (row-major output)
177  *
178  * Quantizes multiple rows of FP32 data to Q8_0 format, placing each row's
179  * Q8_0 output at the correct byte offset for GEMM compatibility.
180  *
181  * Memory layout:
182  * Input: [num_rows, k] FP32, row-major (stride = k * sizeof(float))
183  * Output: [num_rows, q8_row_bytes] Q8_0, row-major (stride = q8_row_bytes)
184  *
185  * where q8_row_bytes = (k / 32) * sizeof(block_q8_0) = (k / 32) * 34
186  *
187  * @param x Input FP32 values [num_rows * k]
188  * @param vy Output Q8_0 blocks [num_rows * (k/32) blocks]
189  * @param num_rows Number of rows (batch size / tokens)
190  * @param k Elements per row (must be multiple of 32)
191  */
192 void quantize_batch_q8_0(const float *x, void *vy, int num_rows, int k)
193 {
194  const size_t row_bytes_in = (size_t)k * sizeof(float);
195  const size_t row_bytes_out = (size_t)(k / QK8_0) * sizeof(block_q8_0);
196 
197  uint8_t *out = (uint8_t *)vy;
198  const uint8_t *in = (const uint8_t *)x;
199 
 200  for (int row = 0; row < num_rows; ++row) {
 201  quantize_row_q8_0(
 202  (const float *)(in + row * row_bytes_in),
203  (void *)(out + row * row_bytes_out),
204  k
205  );
206  }
207 }
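/*
 * Usage sketch (illustrative; buffers would come from the bump allocator in
 * practice, per the kernel rules above). Quantizing 4 rows of 4096
 * activations needs (4096 / 32) * 34 = 4352 output bytes per row:
 *
 *   static float   acts[4 * 4096];
 *   static uint8_t q8[4 * (4096 / 32) * sizeof(block_q8_0)];
 *   quantize_batch_q8_0(acts, q8, 4, 4096);
 */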
208 
209 /**
210  * @brief Batch quantize FP32 to Q8_K format (row-major output)
211  *
212  * Same as quantize_batch_q8_0 but for Q8_K format (super-blocks).
213  *
214  * @param x Input FP32 values [num_rows * k]
215  * @param vy Output Q8_K blocks
216  * @param num_rows Number of rows (batch size / tokens)
217  * @param k Elements per row (must be multiple of 256)
218  */
219 void quantize_batch_q8_k(const float *x, void *vy, int num_rows, int k)
220 {
221  /* Q8_K: 256 elements per super-block, each block is larger */
222  const size_t row_bytes_in = (size_t)k * sizeof(float);
 223  /* A Q8_K super-block covers 256 elements: one scale, 256 int8 quants, and per-16-element block sums */
 224  /* Exact size: sizeof(block_q8_K) from ckernel_quant.h */
225  const size_t row_bytes_out = (size_t)(k / 256) * sizeof(block_q8_K);
226 
227  uint8_t *out = (uint8_t *)vy;
228  const uint8_t *in = (const uint8_t *)x;
229 
 230  for (int row = 0; row < num_rows; ++row) {
 231  quantize_row_q8_k(
 232  (const float *)(in + row * row_bytes_in),
233  (void *)(out + row * row_bytes_out),
234  k
235  );
236  }
237 }
238 
239 /* ============================================================================
240  * Forward Pass: GEMV y = W @ x
241  * ============================================================================ */
242 
243 /**
244  * @brief Matrix-vector multiply with Q8_0 weights (scalar reference)
245  *
246  * @param y Output vector [M]
247  * @param W Weight matrix in Q8_0 format [M x K]
248  * @param x Input vector [K]
249  * @param M Number of output rows
250  * @param K Number of columns (must be multiple of 32)
251  */
252 void gemv_q8_0_ref(float *y,
253  const void *W,
254  const float *x,
255  int M, int K)
256 {
257  const block_q8_0 *blocks = (const block_q8_0 *)W;
258  const int blocks_per_row = K / QK8_0;
259 
260  for (int row = 0; row < M; row++) {
261  float sum = 0.0f;
262 
263  for (int b = 0; b < blocks_per_row; b++) {
264  const block_q8_0 *block = &blocks[row * blocks_per_row + b];
265  const float d = CK_FP16_TO_FP32(block->d);
266  const float *xp = &x[b * QK8_0];
267 
268  for (int i = 0; i < QK8_0; i++) {
269  sum += d * (float)block->qs[i] * xp[i];
270  }
271  }
272 
273  y[row] = sum;
274  }
275 }
276 
277 #ifdef __AVX512F__
278 /**
279  * @brief Matrix-vector multiply with Q8_0 weights (AVX-512)
280  */
281 void gemv_q8_0_avx512(float *y,
282  const void *W,
283  const float *x,
284  int M, int K)
285 {
286  const block_q8_0 *blocks = (const block_q8_0 *)W;
287  const int blocks_per_row = K / QK8_0;
288 
289  for (int row = 0; row < M; row++) {
290  __m512 acc = _mm512_setzero_ps();
291 
292  for (int b = 0; b < blocks_per_row; b++) {
293  const block_q8_0 *block = &blocks[row * blocks_per_row + b];
294  const __m512 vscale = _mm512_set1_ps(CK_FP16_TO_FP32(block->d));
295  const float *xp = &x[b * QK8_0];
296 
297  /* Process 32 weights in two batches of 16 */
298  for (int chunk = 0; chunk < 2; chunk++) {
299  /* Load 16 x int8 weights */
300  __m128i q8 = _mm_loadu_si128((const __m128i *)&block->qs[chunk * 16]);
301 
302  /* Sign-extend to 32-bit */
303  __m512i q32 = _mm512_cvtepi8_epi32(q8);
304 
305  /* Convert to float and scale */
306  __m512 w = _mm512_mul_ps(_mm512_cvtepi32_ps(q32), vscale);
307 
308  /* Load input */
309  __m512 x_vec = _mm512_loadu_ps(&xp[chunk * 16]);
310 
311  /* FMA */
312  acc = _mm512_fmadd_ps(w, x_vec, acc);
313  }
314  }
315 
316  y[row] = _mm512_reduce_add_ps(acc);
317  }
318 }
319 #endif
320 
321 /* ============================================================================
322  * AVX2 Implementation (Haswell+, 256-bit integer operations)
323  *
324  * Q8_0 format: 32 signed int8 weights per block
325  * - d: FP16 scale
326  * - qs: 32 int8 weights
327  * - Dequant: w = d * q
328  *
329  * AVX2 provides _mm256_cvtepi8_epi32 for efficient 8-to-32 sign extension.
330  * Processes 8 weights at a time with full 256-bit FMA.
331  * ============================================================================ */
332 
333 #if defined(__AVX2__) && !defined(__AVX512F__)
334 
335 /* Helper: AVX2 horizontal sum of 8 floats */
336 static inline float hsum_avx2_q8(__m256 v) {
337  __m128 lo = _mm256_castps256_ps128(v);
338  __m128 hi = _mm256_extractf128_ps(v, 1);
339  lo = _mm_add_ps(lo, hi); /* 4 floats */
340  __m128 shuf = _mm_shuffle_ps(lo, lo, _MM_SHUFFLE(2, 3, 0, 1));
341  __m128 sums = _mm_add_ps(lo, shuf);
342  shuf = _mm_movehl_ps(shuf, sums);
343  sums = _mm_add_ss(sums, shuf);
344  return _mm_cvtss_f32(sums);
345 }
346 
347 /**
348  * @brief Matrix-vector multiply with Q8_0 weights (AVX2 optimized)
349  *
350  * Uses AVX2's _mm256_cvtepi8_epi32 for efficient sign extension.
351  * Processes 8 weights at a time with FMA.
352  */
353 void gemv_q8_0_avx2(float *y,
354  const void *W,
355  const float *x,
356  int M, int K)
357 {
358  const block_q8_0 *blocks = (const block_q8_0 *)W;
359  const int blocks_per_row = K / QK8_0; /* QK8_0 = 32 */
360 
361  for (int row = 0; row < M; row++) {
362  __m256 acc = _mm256_setzero_ps();
363 
364  for (int b = 0; b < blocks_per_row; b++) {
365  const block_q8_0 *block = &blocks[row * blocks_per_row + b];
366  const float d = CK_FP16_TO_FP32(block->d);
367  const __m256 vscale = _mm256_set1_ps(d);
368  const float *xp = &x[b * QK8_0];
369 
370  /* Process 32 weights in 4 groups of 8 using AVX2 */
371 
372  /* Group 0: weights 0-7 */
373  {
374  __m128i q8 = _mm_loadl_epi64((const __m128i *)&block->qs[0]);
375  __m256i q32 = _mm256_cvtepi8_epi32(q8);
376  __m256 wf = _mm256_mul_ps(_mm256_cvtepi32_ps(q32), vscale);
377  __m256 xv = _mm256_loadu_ps(&xp[0]);
378  acc = _mm256_fmadd_ps(wf, xv, acc);
379  }
380 
381  /* Group 1: weights 8-15 */
382  {
383  __m128i q8 = _mm_loadl_epi64((const __m128i *)&block->qs[8]);
384  __m256i q32 = _mm256_cvtepi8_epi32(q8);
385  __m256 wf = _mm256_mul_ps(_mm256_cvtepi32_ps(q32), vscale);
386  __m256 xv = _mm256_loadu_ps(&xp[8]);
387  acc = _mm256_fmadd_ps(wf, xv, acc);
388  }
389 
390  /* Group 2: weights 16-23 */
391  {
392  __m128i q8 = _mm_loadl_epi64((const __m128i *)&block->qs[16]);
393  __m256i q32 = _mm256_cvtepi8_epi32(q8);
394  __m256 wf = _mm256_mul_ps(_mm256_cvtepi32_ps(q32), vscale);
395  __m256 xv = _mm256_loadu_ps(&xp[16]);
396  acc = _mm256_fmadd_ps(wf, xv, acc);
397  }
398 
399  /* Group 3: weights 24-31 */
400  {
401  __m128i q8 = _mm_loadl_epi64((const __m128i *)&block->qs[24]);
402  __m256i q32 = _mm256_cvtepi8_epi32(q8);
403  __m256 wf = _mm256_mul_ps(_mm256_cvtepi32_ps(q32), vscale);
404  __m256 xv = _mm256_loadu_ps(&xp[24]);
405  acc = _mm256_fmadd_ps(wf, xv, acc);
406  }
407  }
408 
409  y[row] = hsum_avx2_q8(acc);
410  }
411 }
412 #endif /* __AVX2__ && !__AVX512F__ */
413 
414 /* ============================================================================
415  * AVX Implementation with True SIMD (256-bit float + 128-bit integer)
416  *
417  * Q8_0 format: 32 signed int8 weights per block
418  * - d: FP16 scale
419  * - qs: 32 int8 weights
420  * - Dequant: w = d * q
421  *
422  * This is much simpler than Q5_0 since weights are already in int8 format.
423  * We use SSE for integer-to-float conversion and AVX for accumulation.
424  * ============================================================================ */
425 
426 #if defined(__AVX__) && !defined(__AVX2__) && !defined(__AVX512F__)
427 
428 /* Helper: SSE horizontal sum of 4 floats */
429 static inline float hsum_sse_q8(__m128 v) {
430  __m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1));
431  __m128 sums = _mm_add_ps(v, shuf);
432  shuf = _mm_movehl_ps(shuf, sums);
433  sums = _mm_add_ss(sums, shuf);
434  return _mm_cvtss_f32(sums);
435 }
436 
437 /**
438  * @brief Matrix-vector multiply with Q8_0 weights (AVX + SSE optimized)
439  *
440  * Uses full SIMD: SSE for int8->float conversion, SSE/AVX for dot product.
441  * ~4-6x faster than scalar reference on Ivy Bridge.
442  */
443 void gemv_q8_0_avx(float *y,
444  const void *W,
445  const float *x,
446  int M, int K)
447 {
448  const block_q8_0 *blocks = (const block_q8_0 *)W;
449  const int blocks_per_row = K / QK8_0; /* QK8_0 = 32 */
450 
451  for (int row = 0; row < M; row++) {
452  /* Use 4 SSE accumulators for ILP */
453  __m128 acc0 = _mm_setzero_ps();
454  __m128 acc1 = _mm_setzero_ps();
455  __m128 acc2 = _mm_setzero_ps();
456  __m128 acc3 = _mm_setzero_ps();
457 
458  for (int b = 0; b < blocks_per_row; b++) {
459  const block_q8_0 *block = &blocks[row * blocks_per_row + b];
460  const float d = CK_FP16_TO_FP32(block->d);
461  const float *xp = &x[b * QK8_0];
462  const __m128 vscale = _mm_set1_ps(d);
463 
464  /* Load 32 int8 weights in 2 SSE loads of 16 bytes each */
465  __m128i q8_0 = _mm_loadu_si128((const __m128i *)&block->qs[0]);
466  __m128i q8_1 = _mm_loadu_si128((const __m128i *)&block->qs[16]);
467 
468  /* Process first 16 weights: convert int8 -> int16 -> int32 -> float */
469  /* Chunk 0: weights 0-3 */
470  {
471  __m128i q16 = _mm_cvtepi8_epi16(q8_0); /* 8 int8 -> 8 int16 */
472  __m128i q32 = _mm_cvtepi16_epi32(q16); /* 4 int16 -> 4 int32 */
473  __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
474  __m128 vx = _mm_loadu_ps(&xp[0]);
475  acc0 = _mm_add_ps(acc0, _mm_mul_ps(w, vx));
476  }
477 
478  /* Chunk 1: weights 4-7 */
479  {
480  __m128i q16 = _mm_cvtepi8_epi16(q8_0);
481  __m128i q32 = _mm_cvtepi16_epi32(_mm_srli_si128(q16, 8));
482  __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
483  __m128 vx = _mm_loadu_ps(&xp[4]);
484  acc1 = _mm_add_ps(acc1, _mm_mul_ps(w, vx));
485  }
486 
487  /* Chunk 2: weights 8-11 */
488  {
489  __m128i q8_shifted = _mm_srli_si128(q8_0, 8); /* shift right 8 bytes */
490  __m128i q16 = _mm_cvtepi8_epi16(q8_shifted);
491  __m128i q32 = _mm_cvtepi16_epi32(q16);
492  __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
493  __m128 vx = _mm_loadu_ps(&xp[8]);
494  acc2 = _mm_add_ps(acc2, _mm_mul_ps(w, vx));
495  }
496 
497  /* Chunk 3: weights 12-15 */
498  {
499  __m128i q8_shifted = _mm_srli_si128(q8_0, 8);
500  __m128i q16 = _mm_cvtepi8_epi16(q8_shifted);
501  __m128i q32 = _mm_cvtepi16_epi32(_mm_srli_si128(q16, 8));
502  __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
503  __m128 vx = _mm_loadu_ps(&xp[12]);
504  acc3 = _mm_add_ps(acc3, _mm_mul_ps(w, vx));
505  }
506 
507  /* Process second 16 weights (16-31) */
508  /* Chunk 4: weights 16-19 */
509  {
510  __m128i q16 = _mm_cvtepi8_epi16(q8_1);
511  __m128i q32 = _mm_cvtepi16_epi32(q16);
512  __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
513  __m128 vx = _mm_loadu_ps(&xp[16]);
514  acc0 = _mm_add_ps(acc0, _mm_mul_ps(w, vx));
515  }
516 
517  /* Chunk 5: weights 20-23 */
518  {
519  __m128i q16 = _mm_cvtepi8_epi16(q8_1);
520  __m128i q32 = _mm_cvtepi16_epi32(_mm_srli_si128(q16, 8));
521  __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
522  __m128 vx = _mm_loadu_ps(&xp[20]);
523  acc1 = _mm_add_ps(acc1, _mm_mul_ps(w, vx));
524  }
525 
526  /* Chunk 6: weights 24-27 */
527  {
528  __m128i q8_shifted = _mm_srli_si128(q8_1, 8);
529  __m128i q16 = _mm_cvtepi8_epi16(q8_shifted);
530  __m128i q32 = _mm_cvtepi16_epi32(q16);
531  __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
532  __m128 vx = _mm_loadu_ps(&xp[24]);
533  acc2 = _mm_add_ps(acc2, _mm_mul_ps(w, vx));
534  }
535 
536  /* Chunk 7: weights 28-31 */
537  {
538  __m128i q8_shifted = _mm_srli_si128(q8_1, 8);
539  __m128i q16 = _mm_cvtepi8_epi16(q8_shifted);
540  __m128i q32 = _mm_cvtepi16_epi32(_mm_srli_si128(q16, 8));
541  __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
542  __m128 vx = _mm_loadu_ps(&xp[28]);
543  acc3 = _mm_add_ps(acc3, _mm_mul_ps(w, vx));
544  }
545  }
546 
547  /* Combine accumulators and reduce */
548  __m128 sum01 = _mm_add_ps(acc0, acc1);
549  __m128 sum23 = _mm_add_ps(acc2, acc3);
550  __m128 sum = _mm_add_ps(sum01, sum23);
551 
552  y[row] = hsum_sse_q8(sum);
553  }
554 }
 555 #endif /* __AVX__ && !__AVX2__ && !__AVX512F__ */
556 
557 #if defined(__SSE4_1__)
558 #include <immintrin.h>
559 
560 /* Helper macro: extract 4 int8 weights at byte offset, convert to float, multiply with x */
561 #define SSE_Q8_BLOCK(q8_reg, offset, xp, d_val, acc) do { \
562  __m128 vx = _mm_loadu_ps(&(xp)[offset]); \
563  __m128i qw = _mm_cvtepi8_epi32(_mm_srli_si128(q8_reg, offset)); \
564  __m128 vw = _mm_cvtepi32_ps(qw); \
565  acc = _mm_add_ps(acc, _mm_mul_ps(_mm_mul_ps(vw, vx), _mm_set1_ps(d_val))); \
566 } while(0)
567 
568 void gemv_q8_0_sse(float *y,
569  const void *W,
570  const float *x,
571  int M, int K)
572 {
573  const block_q8_0 *blocks = (const block_q8_0 *)W;
574  const int blocks_per_row = K / QK8_0;
575 
576  for (int row = 0; row < M; row++) {
577  __m128 acc = _mm_setzero_ps();
578 
579  for (int b = 0; b < blocks_per_row; b++) {
580  const block_q8_0 *block = &blocks[row * blocks_per_row + b];
581  const float d_val = CK_FP16_TO_FP32(block->d);
582  const float *xp = &x[b * QK8_0];
583 
584  /* Load 32 weights (signed 8-bit) in two 16-byte chunks */
585  __m128i q8_0 = _mm_loadu_si128((const __m128i *)&block->qs[0]);
586  __m128i q8_1 = _mm_loadu_si128((const __m128i *)&block->qs[16]);
587 
588  /* Process first 16 weights (q8_0) - unrolled with compile-time constants */
589  SSE_Q8_BLOCK(q8_0, 0, xp, d_val, acc);
590  SSE_Q8_BLOCK(q8_0, 4, xp, d_val, acc);
591  SSE_Q8_BLOCK(q8_0, 8, xp, d_val, acc);
592  SSE_Q8_BLOCK(q8_0, 12, xp, d_val, acc);
593 
594  /* Process second 16 weights (q8_1) - offset xp by 16 */
595  const float *xp1 = xp + 16;
596  SSE_Q8_BLOCK(q8_1, 0, xp1, d_val, acc);
597  SSE_Q8_BLOCK(q8_1, 4, xp1, d_val, acc);
598  SSE_Q8_BLOCK(q8_1, 8, xp1, d_val, acc);
599  SSE_Q8_BLOCK(q8_1, 12, xp1, d_val, acc);
600  }
601 
602  /* Horizontal sum */
603  acc = _mm_add_ps(acc, _mm_shuffle_ps(acc, acc, _MM_SHUFFLE(1, 0, 3, 2)));
604  acc = _mm_add_ps(acc, _mm_shuffle_ps(acc, acc, _MM_SHUFFLE(0, 1, 0, 1)));
605  _mm_store_ss(&y[row], acc);
606  }
607 }
608 
609 #undef SSE_Q8_BLOCK
610 #endif
611 
612 /**
613  * @brief Auto-dispatch GEMV for Q8_0 weights based on CPU features
614  *
615  * Dispatch priority (best available):
616  * 1. AVX-512 (512-bit vectors) - Intel Skylake-X+
617  * 2. AVX2+FMA (256-bit vectors) - Intel Haswell+
618  * 3. AVX (256-bit vectors) - Intel Sandy Bridge+
619  * 4. SSE4.1 (128-bit vectors) - Intel Nehalem+
620  * 5. Reference (scalar) - Fallback
621  *
622  * Uses ck_features.h for standardized feature detection.
623  *
624  * @param y Output vector [M]
625  * @param W Weight matrix in Q8_0 format [M x K]
626  * @param x Input vector [K]
627  * @param M Number of output rows
628  * @param K Number of input columns (hidden dimension)
629  */
630 void gemv_q8_0(float *y,
631  const void *W,
632  const float *x,
633  int M, int K)
634 {
 635 /* Dispatch order: AVX512 > AVX2 > AVX > SSE > ref */
636 #if defined(__AVX512F__)
637  gemv_q8_0_avx512(y, W, x, M, K);
638 #elif defined(__AVX2__)
639  gemv_q8_0_avx2(y, W, x, M, K);
640 #elif defined(__AVX__)
641  gemv_q8_0_avx(y, W, x, M, K);
642 #elif defined(__SSE4_1__)
643  gemv_q8_0_sse(y, W, x, M, K);
644 #else
645  gemv_q8_0_ref(y, W, x, M, K);
646 #endif
647 }
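/*
 * Usage sketch (illustrative): quantize a tiny 2 x 32 weight matrix and run
 * the dispatcher once; in practice all buffers are carved from the bump
 * allocator and the weights are quantized ahead of time.
 *
 *   float w_f32[2 * 32] = {0}, x[32] = {0}, y[2];
 *   block_q8_0 w_q8[2];                      // 2 rows x (32 / 32) blocks
 *   quantize_row_q8_0(w_f32, w_q8, 2 * 32);  // rows are contiguous blocks
 *   gemv_q8_0(y, w_q8, x, 2, 32);            // picks AVX-512/AVX2/AVX/SSE/ref
 */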
648 
649 /* ============================================================================
650  * Forward Pass: GEMM Y = W @ X
651  * ============================================================================ */
652 
653 /**
 654  * @brief Matrix-matrix multiply with Q8_0 weights: applies gemv_q8_0 to each of the N input vectors (X stored as N contiguous length-K vectors, Y as N contiguous length-M vectors)
655  */
656 void gemm_q8_0(float *Y,
657  const void *W,
658  const float *X,
659  int M, int N, int K)
660 {
661  for (int n = 0; n < N; n++) {
662  gemv_q8_0(&Y[n * M], W, &X[n * K], M, K);
663  }
664 }
665 
666 /* ============================================================================
667  * GEMM NT: C = A @ B^T + bias (B stored as N rows of K elements)
668  * ============================================================================ */
669 
670 /**
671  * @brief Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias
672  *
673  * @param A Input matrix [M x K], row-major FP32
674  * @param B Weight matrix in Q8_0 format, [N x K] stored row-major
675  * @param bias Optional bias [N], NULL if not used
676  * @param C Output [M x N], row-major FP32
677  * @param M Batch size (number of tokens)
678  * @param N Output dimension (number of rows in B)
679  * @param K Input dimension
680  */
681 void gemm_nt_q8_0(const float *A,
682  const void *B,
683  const float *bias,
684  float *C,
685  int M, int N, int K)
686 {
687  /* Use GEMV dispatch which selects AVX/SSE/scalar based on CPU */
688  for (int m = 0; m < M; m++) {
689  gemv_q8_0(&C[m * N], B, &A[m * K], N, K);
690  if (bias) {
691  for (int n = 0; n < N; n++) C[m * N + n] += bias[n];
692  }
693  }
 694  return; /* The scalar reference path below is intentionally unreachable; kept for debugging and parity checks */
695 
696  const block_q8_0 *blocks = (const block_q8_0 *)B;
697  const int blocks_per_row = K / QK8_0;
698 
699  for (int m = 0; m < M; m++) {
700  const float *a_row = &A[m * K];
701 
702  for (int n = 0; n < N; n++) {
703  float sum = 0.0f;
704 
705  for (int b = 0; b < blocks_per_row; b++) {
706  const block_q8_0 *block = &blocks[n * blocks_per_row + b];
707  const float d = CK_FP16_TO_FP32(block->d);
708  const float *ap = &a_row[b * QK8_0];
709 
710  for (int i = 0; i < QK8_0; i++) {
711  sum += d * (float)block->qs[i] * ap[i];
712  }
713  }
714 
715  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
716  }
717  }
718 }
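/*
 * Usage sketch (illustrative): a bias-added linear layer over a batch of
 * 3 tokens, hidden size K = 32, output size N = 4.
 *
 *   float a[3 * 32] = {0}, bias[4] = {0}, c[3 * 4];
 *   float b_f32[4 * 32] = {0};
 *   block_q8_0 b_q8[4];                      // 4 rows x (32 / 32) blocks
 *   quantize_row_q8_0(b_f32, b_q8, 4 * 32);
 *   gemm_nt_q8_0(a, b_q8, bias, c, 3, 4, 32);
 */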
719 
720 /* ============================================================================
721  * Backward Pass: Gradient w.r.t. Input
722  * ============================================================================ */
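/*
 * Why dX = W^T @ dY: the forward pass computes y[m] = sum_k W[m][k] * x[k]
 * (with W[m][k] = d * qs dequantized on the fly), so by the chain rule
 * dL/dx[k] = sum_m dY[m] * W[m][k]. Each output gradient is therefore
 * scattered back along row m of the dequantized weights, which is exactly
 * the accumulation loop below.
 */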
723 
724 /**
725  * @brief Backward pass: compute input gradient (scalar reference)
726  *
727  * @param dX Output gradient w.r.t. input [K]
728  * @param W Weight matrix in Q8_0 format [M x K]
729  * @param dY Gradient w.r.t. output [M]
730  * @param M Number of output rows
731  * @param K Number of columns (input dimension)
732  */
733 void gemv_q8_0_backward_ref(float *dX,
734  const void *W,
735  const float *dY,
736  int M, int K)
737 {
738  const block_q8_0 *blocks = (const block_q8_0 *)W;
739  const int blocks_per_row = K / QK8_0;
740 
741  /* Zero output gradient */
742  memset(dX, 0, K * sizeof(float));
743 
744  /* Accumulate: dX += W^T @ dY */
745  for (int row = 0; row < M; row++) {
746  const float dy = dY[row];
747 
748  for (int b = 0; b < blocks_per_row; b++) {
749  const block_q8_0 *block = &blocks[row * blocks_per_row + b];
750  const float d = CK_FP16_TO_FP32(block->d);
751  float *dxp = &dX[b * QK8_0];
752 
753  for (int i = 0; i < QK8_0; i++) {
754  dxp[i] += d * (float)block->qs[i] * dy;
755  }
756  }
757  }
758 }
759 
760 #ifdef __AVX512F__
761 /**
762  * @brief Backward pass with AVX-512
763  */
764 void gemv_q8_0_backward_avx512(float *dX,
765  const void *W,
766  const float *dY,
767  int M, int K)
768 {
769  const block_q8_0 *blocks = (const block_q8_0 *)W;
770  const int blocks_per_row = K / QK8_0;
771 
772  /* Zero output */
773  memset(dX, 0, K * sizeof(float));
774 
775  for (int row = 0; row < M; row++) {
776  const __m512 vdy = _mm512_set1_ps(dY[row]);
777 
778  for (int b = 0; b < blocks_per_row; b++) {
779  const block_q8_0 *block = &blocks[row * blocks_per_row + b];
780  const __m512 vscale = _mm512_set1_ps(CK_FP16_TO_FP32(block->d));
781  float *dxp = &dX[b * QK8_0];
782 
783  /* Process 32 weights in two batches of 16 */
784  for (int chunk = 0; chunk < 2; chunk++) {
785  /* Load and dequantize weights */
786  __m128i q8 = _mm_loadu_si128((const __m128i *)&block->qs[chunk * 16]);
787  __m512i q32 = _mm512_cvtepi8_epi32(q8);
788  __m512 w = _mm512_mul_ps(_mm512_cvtepi32_ps(q32), vscale);
789 
790  /* Compute gradient */
791  __m512 grad = _mm512_mul_ps(w, vdy);
792 
793  /* Accumulate */
794  __m512 dx_cur = _mm512_loadu_ps(&dxp[chunk * 16]);
795  _mm512_storeu_ps(&dxp[chunk * 16], _mm512_add_ps(dx_cur, grad));
796  }
797  }
798  }
799 }
800 #endif
801 
802 /**
803  * @brief Auto-dispatch backward
804  */
805 void gemv_q8_0_backward(float *dX,
806  const void *W,
807  const float *dY,
808  int M, int K)
809 {
810 #ifdef __AVX512F__
811  gemv_q8_0_backward_avx512(dX, W, dY, M, K);
812 #else
813  gemv_q8_0_backward_ref(dX, W, dY, M, K);
814 #endif
815 }
816 
817 /**
818  * @brief Batched backward pass
819  */
820 void gemm_q8_0_backward(float *dX,
821  const void *W,
822  const float *dY,
823  int M, int N, int K)
824 {
825  for (int n = 0; n < N; n++) {
826  gemv_q8_0_backward(&dX[n * K], W, &dY[n * M], M, K);
827  }
828 }
829 
830 /* ============================================================================
831  * Dot Product Utility
832  * ============================================================================ */
833 
834 float dot_q8_0(const void *w_q8_0, const float *x, int K)
835 {
836  float result;
837  gemv_q8_0(&result, w_q8_0, x, 1, K);
838  return result;
839 }
840 
841 /* ============================================================================
842  * Quantized Dot Product: Q8_0 x Q8_0
843  *
844  * This matches llama.cpp's ggml_vec_dot_q8_0_q8_0 exactly.
845  * Both weights and input are in Q8_0 format, enabling pure integer dot products.
846  * Result: sum_blocks( (d_w * d_x) * sum_weights( w8 * x8 ) )
847  *
848  * Key difference from gemv_q8_0:
849  * - gemv_q8_0: Takes FP32 input, dequantizes weights to FP32, FP32 dot
850  * - vec_dot_q8_0_q8_0: Takes Q8_0 input, does integer dot, scales at end
851  *
852  * The quantized path is faster and matches llama.cpp for parity testing.
853  * ============================================================================ */
854 
855 /**
856  * @brief Quantized dot product: Q8_0 weights x Q8_0 input (scalar reference)
857  *
858  * @param n Number of elements (must be multiple of 32)
859  * @param s Output: scalar dot product result
860  * @param vx Q8_0 quantized weights
861  * @param vy Q8_0 quantized input
862  */
863 void vec_dot_q8_0_q8_0_ref(int n, float *s, const void *vx, const void *vy)
864 {
865  const int qk = QK8_0; /* 32 */
866  const int nb = n / qk;
867 
868  const block_q8_0 *x = (const block_q8_0 *)vx;
869  const block_q8_0 *y = (const block_q8_0 *)vy;
870 
871  float sumf = 0.0f;
872 
873  for (int ib = 0; ib < nb; ib++) {
874  int sumi = 0;
875 
876  for (int j = 0; j < qk; j++) {
877  sumi += x[ib].qs[j] * y[ib].qs[j];
878  }
879 
880  sumf += sumi * (CK_FP16_TO_FP32(x[ib].d) * CK_FP16_TO_FP32(y[ib].d));
881  }
882 
883  *s = sumf;
884 }
885 
886 #ifdef __AVX512F__
887 /**
888  * @brief Quantized dot product Q8_0 x Q8_0 (AVX-512)
889  */
890 void vec_dot_q8_0_q8_0_avx512(int n, float *s, const void *vx, const void *vy)
891 {
892  const int qk = QK8_0;
893  const int nb = n / qk;
894 
895  const block_q8_0 *x = (const block_q8_0 *)vx;
896  const block_q8_0 *y = (const block_q8_0 *)vy;
897 
898  float sumf = 0.0f;
899 
900  for (int ib = 0; ib < nb; ib++) {
901  const float d = CK_FP16_TO_FP32(x[ib].d) * CK_FP16_TO_FP32(y[ib].d);
902 
903  /* Load 32 int8 weights in two batches of 16 */
904  __m128i x8_lo = _mm_loadu_si128((const __m128i *)&x[ib].qs[0]);
905  __m128i x8_hi = _mm_loadu_si128((const __m128i *)&x[ib].qs[16]);
906  __m128i y8_lo = _mm_loadu_si128((const __m128i *)&y[ib].qs[0]);
907  __m128i y8_hi = _mm_loadu_si128((const __m128i *)&y[ib].qs[16]);
908 
909  /* Sign-extend to 32-bit */
910  __m512i x32_lo = _mm512_cvtepi8_epi32(x8_lo);
911  __m512i x32_hi = _mm512_cvtepi8_epi32(x8_hi);
912  __m512i y32_lo = _mm512_cvtepi8_epi32(y8_lo);
913  __m512i y32_hi = _mm512_cvtepi8_epi32(y8_hi);
914 
915  /* Integer multiply */
916  __m512i prod_lo = _mm512_mullo_epi32(x32_lo, y32_lo);
917  __m512i prod_hi = _mm512_mullo_epi32(x32_hi, y32_hi);
918 
919  /* Sum all products */
920  int sumi = _mm512_reduce_add_epi32(_mm512_add_epi32(prod_lo, prod_hi));
921 
922  /* Scale and accumulate - use scalar to avoid 16x broadcast bug */
923  sumf += d * (float)sumi;
924  }
925 
926  *s = sumf;
927 }
928 #endif
929 
930 #if defined(__AVX__) && !defined(__AVX512F__)
931 /**
932  * @brief Quantized dot product Q8_0 x Q8_0 (AVX + SSE)
933  */
934 void vec_dot_q8_0_q8_0_avx(int n, float *s, const void *vx, const void *vy)
935 {
936  const int qk = QK8_0;
937  const int nb = n / qk;
938 
939  const block_q8_0 *x = (const block_q8_0 *)vx;
940  const block_q8_0 *y = (const block_q8_0 *)vy;
941 
942  float sumf = 0.0f;
943 
944  for (int ib = 0; ib < nb; ib++) {
945  const float d = CK_FP16_TO_FP32(x[ib].d) * CK_FP16_TO_FP32(y[ib].d);
946 
947  int sumi = 0;
948 
949  /* Simple loop - compiler should vectorize */
950  for (int j = 0; j < qk; j++) {
951  sumi += x[ib].qs[j] * y[ib].qs[j];
952  }
953 
954  sumf += d * (float)sumi;
955  }
956 
957  *s = sumf;
958 }
959 #endif
960 
961 #if defined(__SSE4_1__) && !defined(__AVX__)
962 /**
963  * @brief Quantized dot product Q8_0 x Q8_0 (SSE4.1)
964  */
965 void vec_dot_q8_0_q8_0_sse(int n, float *s, const void *vx, const void *vy)
966 {
967  const int qk = QK8_0;
968  const int nb = n / qk;
969 
970  const block_q8_0 *x = (const block_q8_0 *)vx;
971  const block_q8_0 *y = (const block_q8_0 *)vy;
972 
973  float sumf = 0.0f;
974 
975  for (int ib = 0; ib < nb; ib++) {
976  const float d = CK_FP16_TO_FP32(x[ib].d) * CK_FP16_TO_FP32(y[ib].d);
977 
978  __m128i acc_lo = _mm_setzero_si128();
980 
981  /* Process 32 elements in 4 groups of 8 */
982  for (int j = 0; j < 32; j += 8) {
983  /* Load 8 int8 values from each */
984  __m128i x8 = _mm_loadl_epi64((const __m128i *)&x[ib].qs[j]);
985  __m128i y8 = _mm_loadl_epi64((const __m128i *)&y[ib].qs[j]);
986 
987  /* Sign-extend to 16-bit */
988  __m128i x16 = _mm_cvtepi8_epi16(x8);
989  __m128i y16 = _mm_cvtepi8_epi16(y8);
990 
991  /* Multiply and add horizontally: (a0*b0 + a1*b1, a2*b2 + a3*b3, ...) */
992  __m128i prod = _mm_madd_epi16(x16, y16);
993 
994  /* Accumulate */
995  acc_lo = _mm_add_epi32(acc_lo, prod);
996  }
997 
998  /* Horizontal sum */
999  acc_lo = _mm_add_epi32(acc_lo, _mm_shuffle_epi32(acc_lo, _MM_SHUFFLE(1, 0, 3, 2)));
1000  acc_lo = _mm_add_epi32(acc_lo, _mm_shuffle_epi32(acc_lo, _MM_SHUFFLE(0, 1, 0, 1)));
1001  int sumi = _mm_extract_epi32(acc_lo, 0);
1002 
1003  sumf += d * (float)sumi;
1004  }
1005 
1006  *s = sumf;
1007 }
1008 #endif
1009 
1010 /**
1011  * @brief Auto-dispatch quantized dot product Q8_0 x Q8_0
1012  */
1013 void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
1014 {
1015 #ifdef __AVX512F__
1016  vec_dot_q8_0_q8_0_avx512(n, s, vx, vy);
1017 #elif defined(__AVX__)
1018  vec_dot_q8_0_q8_0_avx(n, s, vx, vy);
1019 #elif defined(__SSE4_1__)
1020  vec_dot_q8_0_q8_0_sse(n, s, vx, vy);
1021 #else
1022  vec_dot_q8_0_q8_0_ref(n, s, vx, vy);
1023 #endif
1024 }
1025 
1026 /* ============================================================================
1027  * Quantized GEMV: y = W @ x where W is Q8_0 and x is Q8_0
1028  *
1029  * This is the quantized equivalent of gemv_q8_0, but takes pre-quantized
1030  * input in Q8_0 format. Used for parity testing with llama.cpp.
1031  * ============================================================================ */
1032 
1033 /**
1034  * @brief Matrix-vector multiply with Q8_0 weights and Q8_0 input
1035  *
1036  * @param y Output vector [M]
1037  * @param W Weight matrix in Q8_0 format [M x K]
1038  * @param x_q8 Input vector in Q8_0 format [K]
1039  * @param M Number of output rows
1040  * @param K Number of columns (must be multiple of 32)
1041  */
1042 void gemv_q8_0_q8_0(float *y,
1043  const void *W,
1044  const void *x_q8,
1045  int M, int K)
1046 {
1047  const block_q8_0 *w_blocks = (const block_q8_0 *)W;
1048  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1049  const int blocks_per_row = K / QK8_0;
1050 
1051  for (int row = 0; row < M; row++) {
1052  vec_dot_q8_0_q8_0(K, &y[row],
1053  &w_blocks[row * blocks_per_row],
1054  x_blocks);
1055  }
1056 }
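/*
 * Usage sketch (illustrative): the llama.cpp-parity path quantizes the FP32
 * activation to Q8_0 first, then runs the integer dot-product GEMV.
 *
 *   float w_f32[4 * 64] = {0}, x_f32[64] = {0}, y[4];
 *   block_q8_0 w_q8[4 * 2], x_q8[2];          // (64 / 32) = 2 blocks per row
 *   quantize_row_q8_0(w_f32, w_q8, 4 * 64);   // weights: 4 rows, contiguous
 *   quantize_row_q8_0(x_f32, x_q8, 64);       // activation: one token
 *   gemv_q8_0_q8_0(y, w_q8, x_q8, 4, 64);
 */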
1057 
1058 /* ============================================================================
1059  * PARALLEL VERSIONS (for thread pool orchestration)
1060  *
1061  * These receive ith (thread index) and nth (total threads) from the
1062  * thread pool. OpenMP / pthreads live in the orchestration layer, NOT here.
1063  * ============================================================================ */
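/*
 * Partitioning example (illustrative): with M = 100 rows and nth = 8 threads,
 * dr = ceil(100 / 8) = 13, so thread ith = 0 handles rows [0, 13), thread 1
 * rows [13, 26), ..., and thread 7 rows [91, 100). The orchestrator calls the
 * kernel once per thread with identical arguments plus its (ith, nth) pair.
 */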
1064 
1065 /**
1066  * @brief Parallel reference GEMV for Q8_0 x Q8_0
 1067  */
 1068 void gemv_q8_0_q8_0_parallel(float *y,
 1069  const void *W,
1070  const void *x_q8,
1071  int M, int K,
1072  int ith, int nth)
1073 {
1074  if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
1075  if (ith < 0 || nth <= 0 || ith >= nth) return;
1076 
1077  const int dr = (M + nth - 1) / nth;
1078  const int r0 = dr * ith;
1079  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1080 
1081  if (r0 >= M) return;
1082 
1083  const block_q8_0 *w_blocks = (const block_q8_0 *)W;
1084  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1085  const int blocks_per_row = K / QK8_0;
1086 
1087  for (int row = r0; row < r1; row++) {
1088  vec_dot_q8_0_q8_0(K, &y[row],
1089  &w_blocks[row * blocks_per_row],
1090  x_blocks);
1091  }
1092 }
1093 
1094 /**
1095  * @brief Parallel SIMD GEMV for Q8_0 x Q8_0 with prefetching
1096  *
1097  * Each thread processes rows [r0, r1) where r0 = ith * ceil(M/nth).
1098  * Prefetches upcoming weight rows to hide memory latency.
 1099  */
 1100 void gemv_q8_0_q8_0_parallel_simd(float *y,
 1101  const void *W,
1102  const void *x_q8,
1103  int M, int K,
1104  int ith, int nth)
1105 {
1106  if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
1107  if (ith < 0 || nth <= 0 || ith >= nth) return;
1108 
1109  const int dr = (M + nth - 1) / nth;
1110  const int r0 = dr * ith;
1111  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1112 
1113  if (r0 >= M) return;
1114 
1115  const block_q8_0 *w_blocks = (const block_q8_0 *)W;
1116  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1117  const int blocks_per_row = K / QK8_0;
1118 
1119 #if defined(__AVX__) || defined(__SSE4_1__)
1120  /* Prefetch first few rows */
1121  const int PREFETCH_ROWS = 4;
1122  for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
1123  const char *row_ptr = (const char *)(w_blocks + (r0 + p) * blocks_per_row);
1124  _mm_prefetch(row_ptr, _MM_HINT_T0);
1125  _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
1126  }
1127 
1128  for (int row = r0; row < r1; ++row) {
1129  /* Prefetch upcoming rows */
1130  if (row + PREFETCH_ROWS < r1) {
1131  const char *pf = (const char *)(w_blocks + (row + PREFETCH_ROWS) * blocks_per_row);
1132  _mm_prefetch(pf, _MM_HINT_T0);
1133  _mm_prefetch(pf + 64, _MM_HINT_T0);
1134  }
1135 
1136  vec_dot_q8_0_q8_0(K, &y[row],
1137  &w_blocks[row * blocks_per_row],
1138  x_blocks);
1139  }
1140 #else
1141  /* Fallback: no prefetching */
1142  for (int row = r0; row < r1; row++) {
1143  vec_dot_q8_0_q8_0(K, &y[row],
1144  &w_blocks[row * blocks_per_row],
1145  x_blocks);
1146  }
1147 #endif
1148 }
1149 
1150 /**
1151  * @brief Parallel SIMD GEMV for Q8_0 weights x FP32 input with prefetching
 1152  */
 1153 void gemv_q8_0_parallel_simd(float *y,
 1154  const void *W,
1155  const float *x,
1156  int M, int K,
1157  int ith, int nth)
1158 {
1159  if (!y || !W || !x || M <= 0 || K <= 0) return;
1160  if (ith < 0 || nth <= 0 || ith >= nth) return;
1161 
1162  const int dr = (M + nth - 1) / nth;
1163  const int r0 = dr * ith;
1164  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1165 
1166  if (r0 >= M) return;
1167 
1168  const block_q8_0 *blocks = (const block_q8_0 *)W;
1169  const int blocks_per_row = K / QK8_0;
1170 
1171 #if defined(__AVX__) || defined(__SSE4_1__)
1172  const int PREFETCH_ROWS = 4;
1173  for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
1174  const char *row_ptr = (const char *)(blocks + (r0 + p) * blocks_per_row);
1175  _mm_prefetch(row_ptr, _MM_HINT_T0);
1176  _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
1177  }
1178 
1179  for (int row = r0; row < r1; ++row) {
1180  if (row + PREFETCH_ROWS < r1) {
1181  const char *pf = (const char *)(blocks + (row + PREFETCH_ROWS) * blocks_per_row);
1182  _mm_prefetch(pf, _MM_HINT_T0);
1183  _mm_prefetch(pf + 64, _MM_HINT_T0);
1184  }
1185 
1186  /* Dispatch to best available SIMD for single row */
1187 #if defined(__AVX512F__)
1188  gemv_q8_0_avx512(&y[row],
1189  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1190  x, 1, K);
1191 #elif defined(__AVX2__)
1192  gemv_q8_0_avx2(&y[row],
1193  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1194  x, 1, K);
1195 #elif defined(__AVX__)
1196  gemv_q8_0_avx(&y[row],
1197  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1198  x, 1, K);
1199 #elif defined(__SSE4_1__)
1200  gemv_q8_0_sse(&y[row],
1201  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1202  x, 1, K);
1203 #else
1204  gemv_q8_0_ref(&y[row],
1205  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1206  x, 1, K);
1207 #endif
1208  }
1209 #else
1210  for (int row = r0; row < r1; row++) {
1211  gemv_q8_0_ref(&y[row],
1212  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1213  x, 1, K);
1214  }
1215 #endif
1216 }