← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q5_0.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_q5_0.c
3  * @brief GEMM/GEMV kernels with Q5_0 quantized weights
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Q5_0 Format:
15  * - 32 weights per block
16  * - 1 FP16 scale per block
17  * - Low 4-bits stored like Q4_0 (16 bytes)
18  * - High 1-bit packed separately (4 bytes)
19  * - 22 bytes per 32 weights = 5.5 bits/weight
20  *
21  * Dequantization: w = scale * (q5 - 16)
22  * where q5 = low4bit | (highbit << 4), giving values 0-31, then subtract 16 for signed -16 to +15
23  *
24  * Operations:
25  * Forward: Y = W @ X (W is Q5_0, X and Y are FP32)
26  * Backward: dX = W^T @ dY (gradient w.r.t. input)
27  */
28 
29 #include <stdint.h>
30 #include <stddef.h>
31 #include <string.h>
32 #include <stdio.h>
33 #include "ckernel_quant.h"
34 #include "ck_features.h"
35 
36 /* Include SIMD headers based on available extensions */
37 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
38 #include <immintrin.h>
39 #endif
40 
41 /* Forward declarations for dequant functions (defined in dequant_kernels.c) */
42 void dequant_q5_0_block(const block_q5_0 *block, float *output);
43 void dequant_q5_0_row(const void *src, float *dst, size_t n_elements);
44 
45 void gemm_nt_q5_0_sse_v2(const float *A,
46  const void *B,
47  const float *bias,
48  float *C,
49  int M, int N, int K);
50 
51 /* ============================================================================
52  * Forward Pass: GEMV y = W @ x
53  * ============================================================================ */
54 
55 /**
56  * @brief Matrix-vector multiply with Q5_0 weights (scalar reference)
57  *
58  * @param y Output vector [M]
59  * @param W Weight matrix in Q5_0 format [M x K]
60  * @param x Input vector [K]
61  * @param M Number of output rows
62  * @param K Number of columns (must be multiple of 32)
63  */
64 void gemv_q5_0_ref(float *y,
65  const void *W,
66  const float *x,
67  int M, int K)
68 {
69  const block_q5_0 *blocks = (const block_q5_0 *)W;
70  const int blocks_per_row = K / QK5_0;
71 
72  for (int row = 0; row < M; row++) {
73  float sum = 0.0f;
74 
75  for (int b = 0; b < blocks_per_row; b++) {
76  const block_q5_0 *block = &blocks[row * blocks_per_row + b];
77  const float d = CK_FP16_TO_FP32(block->d);
78  const float *xp = &x[b * QK5_0];
79 
80  /* Get high bits as 32-bit integer */
81  uint32_t qh;
82  memcpy(&qh, block->qh, sizeof(qh));
83 
84  /* llama.cpp Q5_0 layout:
85  * - Weight j uses: low nibble of qs[j], high bit from qh bit j
86  * - Weight j+16 uses: high nibble of qs[j], high bit from qh bit (j+12)
87  * Note: j+12 not j+16 for the high bit of the second weight!
88  */
89  for (int j = 0; j < QK5_0 / 2; j++) {
90  const uint8_t packed = block->qs[j];
91 
92  /* Extract nibbles */
93  const int lo = (packed & 0x0F);
94  const int hi = (packed >> 4);
95 
96  /* Extract high bits - matches llama.cpp exactly */
97  const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
98  const int xh_1 = ((qh >> (j + 12))) & 0x10;
99 
100  /* Combine to 5-bit signed value */
101  const int q0 = (lo | xh_0) - 16;
102  const int q1 = (hi | xh_1) - 16;
103 
104  /* Weights at indices j and j+16 */
105  sum += d * (float)q0 * xp[j];
106  sum += d * (float)q1 * xp[j + 16];
107  }
108  }
109 
110  y[row] = sum;
111  }
112 }
113 
114 #ifdef __AVX512F__
115 /**
116  * @brief Matrix-vector multiply with Q5_0 weights (AVX-512)
117  */
118 void gemv_q5_0_avx512(float *y,
119  const void *W,
120  const float *x,
121  int M, int K)
122 {
123  const block_q5_0 *blocks = (const block_q5_0 *)W;
124  const int blocks_per_row = K / QK5_0;
125  const __m512i offset = _mm512_set1_epi32(16);
126  const __m512i mask_lo = _mm512_set1_epi32(0x0F);
127  const __m512i one = _mm512_set1_epi32(1);
128 
129  for (int row = 0; row < M; row++) {
130  __m512 acc = _mm512_setzero_ps();
131 
132  for (int b = 0; b < blocks_per_row; b++) {
133  const block_q5_0 *block = &blocks[row * blocks_per_row + b];
134  const __m512 vscale = _mm512_set1_ps(CK_FP16_TO_FP32(block->d));
135  const float *xp = &x[b * QK5_0];
136 
137  /* Load high bits */
138  uint32_t qh;
139  memcpy(&qh, block->qh, sizeof(qh));
140 
141  /* Load 16 bytes = 32 x 4-bit low weights */
142  __m128i packed = _mm_loadu_si128((const __m128i *)block->qs);
143  __m512i bytes = _mm512_cvtepu8_epi32(packed);
144 
145  /* Extract low nibbles */
146  __m512i lo = _mm512_and_epi32(bytes, mask_lo);
147  __m512i hi_shift = _mm512_srli_epi32(bytes, 4);
148 
149  /* llama.cpp Q5_0 layout:
150  * - Weights 0-15: high bits from qh bits 0-15
151  * - Weights 16-31: high bits from qh bits 12-27 (j+12 where j=0..15)
152  */
153  /* Build high bit contribution for first 16 weights (indices 0-15) */
154  __m512i qh_lo = _mm512_set_epi32(
155  ((qh >> 15) & 1) << 4, ((qh >> 14) & 1) << 4,
156  ((qh >> 13) & 1) << 4, ((qh >> 12) & 1) << 4,
157  ((qh >> 11) & 1) << 4, ((qh >> 10) & 1) << 4,
158  ((qh >> 9) & 1) << 4, ((qh >> 8) & 1) << 4,
159  ((qh >> 7) & 1) << 4, ((qh >> 6) & 1) << 4,
160  ((qh >> 5) & 1) << 4, ((qh >> 4) & 1) << 4,
161  ((qh >> 3) & 1) << 4, ((qh >> 2) & 1) << 4,
162  ((qh >> 1) & 1) << 4, ((qh >> 0) & 1) << 4
163  );
164 
165  /* Build high bit contribution for second 16 weights (indices 16-31)
166  * Scalar code: xh_1 = ((qh >> (j + 12))) & 0x10 extracts bit (j+16)
167  * (since 0x10 = bit 4, position is j+12+4 = j+16)
168  * So weights 16-31 use qh bits 16-31 */
169  __m512i qh_hi = _mm512_set_epi32(
170  ((qh >> 31) & 1) << 4, ((qh >> 30) & 1) << 4,
171  ((qh >> 29) & 1) << 4, ((qh >> 28) & 1) << 4,
172  ((qh >> 27) & 1) << 4, ((qh >> 26) & 1) << 4,
173  ((qh >> 25) & 1) << 4, ((qh >> 24) & 1) << 4,
174  ((qh >> 23) & 1) << 4, ((qh >> 22) & 1) << 4,
175  ((qh >> 21) & 1) << 4, ((qh >> 20) & 1) << 4,
176  ((qh >> 19) & 1) << 4, ((qh >> 18) & 1) << 4,
177  ((qh >> 17) & 1) << 4, ((qh >> 16) & 1) << 4
178  );
179 
180  /* Combine low + high bits and subtract offset */
181  __m512i q_lo = _mm512_sub_epi32(_mm512_or_epi32(lo, qh_lo), offset);
182  __m512i q_hi = _mm512_sub_epi32(_mm512_or_epi32(hi_shift, qh_hi), offset);
183 
184  /* Dequantize */
185  __m512 w_lo = _mm512_mul_ps(_mm512_cvtepi32_ps(q_lo), vscale);
186  __m512 w_hi = _mm512_mul_ps(_mm512_cvtepi32_ps(q_hi), vscale);
187 
188  /* Load sequential input: x[0-15] and x[16-31] */
189  __m512 x_first = _mm512_loadu_ps(&xp[0]); /* x[0..15] */
190  __m512 x_second = _mm512_loadu_ps(&xp[16]); /* x[16..31] */
191 
192  acc = _mm512_fmadd_ps(w_lo, x_first, acc);
193  acc = _mm512_fmadd_ps(w_hi, x_second, acc);
194  }
195 
196  y[row] = _mm512_reduce_add_ps(acc);
197  }
198 }
199 #endif
200 
201 /* ============================================================================
202  * AVX2 Implementation (Haswell+, 256-bit integer operations)
203  *
204  * AVX2 provides true 256-bit integer operations that AVX lacks.
205  * This implementation uses:
206  * - _mm256_cvtepi8_epi32: Sign-extend 8 bytes to 8 int32s (AVX2)
207  * - _mm256_fmadd_ps: Fused multiply-add (FMA3, available with AVX2)
208  * - 256-bit integer shuffles and masks
209  *
210  * Processing: 32 weights per block, 8 at a time with AVX2 registers
211  * ============================================================================ */
212 
213 #if defined(__AVX2__) && !defined(__AVX512F__)
214 
215 /* Helper: AVX2 horizontal sum of 8 floats */
/* Horizontal sum of a 256-bit float vector.
 * Reduction order is (low128 + high128) elementwise, then
 * ((l0+l1) + (l2+l3)) — identical rounding to the classic SSE pattern. */
static inline float hsum_avx2(__m256 v) {
    __m128 quad = _mm_add_ps(_mm256_castps256_ps128(v),
                             _mm256_extractf128_ps(v, 1));
    __m128 swapped = _mm_shuffle_ps(quad, quad, _MM_SHUFFLE(2, 3, 0, 1));
    __m128 pairs = _mm_add_ps(quad, swapped);
    __m128 upper = _mm_movehl_ps(swapped, pairs);
    return _mm_cvtss_f32(_mm_add_ss(pairs, upper));
}
226 
227 /**
228  * @brief Matrix-vector multiply with Q5_0 weights (AVX2 optimized)
229  *
230  * Uses AVX2 256-bit integer operations for efficient dequantization.
231  * Processes 8 weights at a time with full 256-bit registers.
232  */
void gemv_q5_0_avx2(float *y,
 const void *W,
 const float *x,
 int M, int K)
{
 const block_q5_0 *blocks = (const block_q5_0 *)W;
 const int blocks_per_row = K / QK5_0; /* QK5_0 = 32 */

 for (int row = 0; row < M; row++) {
 /* One 8-lane FP32 accumulator per output row; reduced at the end */
 __m256 acc = _mm256_setzero_ps();

 for (int b = 0; b < blocks_per_row; b++) {
 const block_q5_0 *block = &blocks[row * blocks_per_row + b];
 const float d = CK_FP16_TO_FP32(block->d);
 const __m256 vscale = _mm256_set1_ps(d);
 const float *xp = &x[b * QK5_0];

 /* Get high bits as 32-bit integer (memcpy avoids aliasing UB) */
 uint32_t qh;
 memcpy(&qh, block->qh, sizeof(qh));

 /* Q5_0 layout: 32 weights per block
 * - Weights 0-15: low nibbles of qs[0-15], high bit from qh[0-15]
 * - Weights 16-31: high nibbles of qs[0-15], high bit from qh[16-31]
 *
 * Process in 4 groups of 8 for AVX2:
 * each group stages the 8 high bits through a small scalar buffer,
 * combines them with the nibbles in SSE registers, sign-extends to
 * int32, dequantizes, and FMA-accumulates against the input.
 */

 /* Group 0: weights 0-7 (low nibbles of qs[0-7], high bits qh[0-7]) */
 {
 __m128i qs8 = _mm_loadl_epi64((const __m128i *)block->qs);
 __m128i lo = _mm_and_si128(qs8, _mm_set1_epi8(0x0F));

 /* Build high bits for weights 0-7 (each entry is 0x00 or 0x10) */
 int8_t hb[8];
 for (int i = 0; i < 8; i++) {
 hb[i] = ((qh >> i) << 4) & 0x10;
 }
 __m128i hi = _mm_loadl_epi64((const __m128i *)hb);

 /* Combine and subtract offset to get signed values (-16..+15) */
 __m128i q5 = _mm_or_si128(lo, hi);
 __m128i offset = _mm_set1_epi8(16);
 __m128i q5_signed = _mm_sub_epi8(q5, offset);

 /* Sign-extend to 32-bit and convert to float (AVX2) */
 __m256i q32 = _mm256_cvtepi8_epi32(q5_signed);
 __m256 wf = _mm256_cvtepi32_ps(q32);
 wf = _mm256_mul_ps(wf, vscale);

 /* Load input and accumulate */
 __m256 xv = _mm256_loadu_ps(&xp[0]);
 acc = _mm256_fmadd_ps(wf, xv, acc);
 }

 /* Group 1: weights 8-15 (low nibbles of qs[8-15], high bits qh[8-15]) */
 {
 __m128i qs8 = _mm_loadl_epi64((const __m128i *)(block->qs + 8));
 __m128i lo = _mm_and_si128(qs8, _mm_set1_epi8(0x0F));

 int8_t hb[8];
 for (int i = 0; i < 8; i++) {
 hb[i] = ((qh >> (8 + i)) << 4) & 0x10;
 }
 __m128i hi = _mm_loadl_epi64((const __m128i *)hb);

 __m128i q5 = _mm_or_si128(lo, hi);
 __m128i offset = _mm_set1_epi8(16);
 __m128i q5_signed = _mm_sub_epi8(q5, offset);

 __m256i q32 = _mm256_cvtepi8_epi32(q5_signed);
 __m256 wf = _mm256_cvtepi32_ps(q32);
 wf = _mm256_mul_ps(wf, vscale);

 __m256 xv = _mm256_loadu_ps(&xp[8]);
 acc = _mm256_fmadd_ps(wf, xv, acc);
 }

 /* Group 2: weights 16-23 (high nibbles of qs[0-7], high bits qh[16-23]) */
 {
 __m128i qs8 = _mm_loadl_epi64((const __m128i *)block->qs);
 /* 16-bit shift drags bits across byte lanes, hence the 0x0F mask */
 __m128i hi_nib = _mm_and_si128(_mm_srli_epi16(qs8, 4), _mm_set1_epi8(0x0F));

 /* High bits for weights 16-23 come from qh bits 16-23 */
 int8_t hb[8];
 for (int i = 0; i < 8; i++) {
 hb[i] = ((qh >> (16 + i)) & 1) << 4;
 }
 __m128i hi = _mm_loadl_epi64((const __m128i *)hb);

 __m128i q5 = _mm_or_si128(hi_nib, hi);
 __m128i offset = _mm_set1_epi8(16);
 __m128i q5_signed = _mm_sub_epi8(q5, offset);

 __m256i q32 = _mm256_cvtepi8_epi32(q5_signed);
 __m256 wf = _mm256_cvtepi32_ps(q32);
 wf = _mm256_mul_ps(wf, vscale);

 __m256 xv = _mm256_loadu_ps(&xp[16]);
 acc = _mm256_fmadd_ps(wf, xv, acc);
 }

 /* Group 3: weights 24-31 (high nibbles of qs[8-15], high bits qh[24-31]) */
 {
 __m128i qs8 = _mm_loadl_epi64((const __m128i *)(block->qs + 8));
 __m128i hi_nib = _mm_and_si128(_mm_srli_epi16(qs8, 4), _mm_set1_epi8(0x0F));

 int8_t hb[8];
 for (int i = 0; i < 8; i++) {
 hb[i] = ((qh >> (24 + i)) & 1) << 4;
 }
 __m128i hi = _mm_loadl_epi64((const __m128i *)hb);

 __m128i q5 = _mm_or_si128(hi_nib, hi);
 __m128i offset = _mm_set1_epi8(16);
 __m128i q5_signed = _mm_sub_epi8(q5, offset);

 __m256i q32 = _mm256_cvtepi8_epi32(q5_signed);
 __m256 wf = _mm256_cvtepi32_ps(q32);
 wf = _mm256_mul_ps(wf, vscale);

 __m256 xv = _mm256_loadu_ps(&xp[24]);
 acc = _mm256_fmadd_ps(wf, xv, acc);
 }
 }

 y[row] = hsum_avx2(acc);
 }
}
362 #endif /* __AVX2__ && !__AVX512F__ */
363 
364 /* ============================================================================
365  * AVX Implementation with True SIMD Dequantization
366  *
367  * Q5_0 format: 32 weights per block
368  * - d: FP16 scale
369  * - qh: 4 bytes (32 high bits, one per weight)
370  * - qs: 16 bytes (low 4 bits, packed as pairs)
371  * - Dequant: w = d * ((lo | (highbit << 4)) - 16)
372  *
373  * This uses SSE for integer unpacking (Ivy Bridge doesn't have AVX2 for
374  * 256-bit integer ops) and AVX for float accumulation.
375  *
376  * Key optimization: Instead of scalar dequant, we use SIMD to:
377  * 1. Extract nibbles to bytes using SSE shuffle/shift
378  * 2. Combine with high bits using SSE or/and
379  * 3. Convert to float and scale
380  * ============================================================================ */
381 
382 #if defined(__AVX__) && !defined(__AVX2__) && !defined(__AVX512F__)
383 
384 /* Helper: Extract low nibbles from 16 packed bytes to 16 bytes */
/* Mask each of the 16 bytes down to its low 4 bits. */
static inline __m128i extract_low_nibbles(__m128i packed) {
    const __m128i low_mask = _mm_set1_epi8(0x0F);
    return _mm_and_si128(packed, low_mask);
}
388 
389 /* Helper: Extract high nibbles from 16 packed bytes to 16 bytes */
/* Shift each byte's high nibble down to bits 0-3. The 16-bit shift drags
 * bits across byte lanes, so the result is masked back to 4 bits. */
static inline __m128i extract_high_nibbles(__m128i packed) {
    const __m128i low_mask = _mm_set1_epi8(0x0F);
    __m128i shifted = _mm_srli_epi16(packed, 4);
    return _mm_and_si128(shifted, low_mask);
}
393 
394 /* Helper: SSE horizontal sum of 4 floats */
/* Horizontal sum of 4 floats: ((v0+v1) + (v2+v3)) reduction order. */
static inline float hsum_sse(__m128 v) {
    __m128 swapped = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1));
    __m128 pairs = _mm_add_ps(v, swapped);
    __m128 upper = _mm_movehl_ps(swapped, pairs);
    return _mm_cvtss_f32(_mm_add_ss(pairs, upper));
}
402 
403 /* Helper: SSE dot product of 8 int8 values with 8 float values */
/* Dot product of 8 signed 8-bit weights (low half of q8_lo) with 8 FP32
 * inputs, scaled by the block scale.
 *
 * Sign-extension uses SSE2 unpack + compare instead of the pmovsx
 * (SSE4.1) sequence — bit-identical results, and the horizontal reduction
 * is inlined with the same ((p0+p1)+(p2+p3)) order as hsum_sse. */
static inline float dot_int8_float8_sse(__m128i q8_lo, const float *x, float scale) {
    /* int8 -> int16 for the low 8 lanes (sign mask via signed compare) */
    __m128i neg8 = _mm_cmpgt_epi8(_mm_setzero_si128(), q8_lo);
    __m128i w16 = _mm_unpacklo_epi8(q8_lo, neg8);

    /* int16 -> two int32 quads */
    __m128i neg16 = _mm_srai_epi16(w16, 15);
    __m128i w32a = _mm_unpacklo_epi16(w16, neg16);
    __m128i w32b = _mm_unpackhi_epi16(w16, neg16);

    /* Dequantize: convert to float and apply the block scale */
    __m128 vscale = _mm_set1_ps(scale);
    __m128 wa = _mm_mul_ps(_mm_cvtepi32_ps(w32a), vscale);
    __m128 wb = _mm_mul_ps(_mm_cvtepi32_ps(w32b), vscale);

    /* Multiply with the input and reduce */
    __m128 sum = _mm_add_ps(_mm_mul_ps(wa, _mm_loadu_ps(x)),
                            _mm_mul_ps(wb, _mm_loadu_ps(x + 4)));
    __m128 swapped = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1));
    __m128 pairs = _mm_add_ps(sum, swapped);
    __m128 upper = _mm_movehl_ps(swapped, pairs);
    return _mm_cvtss_f32(_mm_add_ss(pairs, upper));
}
430 
431 /**
432  * @brief Matrix-vector multiply with Q5_0 weights (AVX + SSE optimized)
433  *
434  * Uses SSE for integer dequantization, AVX for float accumulation.
435  * ~3-5x faster than scalar reference on Ivy Bridge.
436  */
437 void gemv_q5_0_avx(float *y,
438  const void *W,
439  const float *x,
440  int M, int K)
441 {
442  const block_q5_0 *blocks = (const block_q5_0 *)W;
443  const int blocks_per_row = K / QK5_0; /* QK5_0 = 32 */
444 
445  const __m128i mask_0f = _mm_set1_epi8(0x0F);
446  const __m128i mask_10 = _mm_set1_epi8(0x10);
447 
448  for (int row = 0; row < M; row++) {
449  float sum = 0.0f;
450 
451  for (int b = 0; b < blocks_per_row; b++) {
452  const block_q5_0 *block = &blocks[row * blocks_per_row + b];
453  const float d = CK_FP16_TO_FP32(block->d);
454  const float *xp = &x[b * QK5_0];
455 
456  /* Load 16 packed bytes (32 nibbles) */
457  __m128i qs = _mm_loadu_si128((const __m128i *)block->qs);
458 
459  /* Extract low and high nibbles */
460  __m128i lo_nibbles = _mm_and_si128(qs, mask_0f); /* 16 low nibbles */
461  __m128i hi_nibbles = _mm_and_si128(_mm_srli_epi16(qs, 4), mask_0f); /* 16 high nibbles */
462 
463  /* Get high bits as 32-bit integer */
464  uint32_t qh;
465  memcpy(&qh, block->qh, sizeof(qh));
466 
467  /* Q5_0 layout: weight j uses qs[j/2] nibble (low if j<16, high if j>=16)
468  * plus high bit from qh:
469  * - weights 0-15: low nibbles of qs[0-15], high bit at qh[0-15]
470  * - weights 16-31: high nibbles of qs[0-15], high bit at qh[12-27]
471  *
472  * For efficiency, we process 32 weights in 4 groups of 8:
473  */
474 
475  /* Group 0: weights 0-7 (low nibbles of qs[0-7], high bits from qh[0-7]) */
476  {
477  uint8_t w8[8];
478  for (int i = 0; i < 8; i++) {
479  int lo = block->qs[i] & 0x0F;
480  int hb = ((qh >> i) << 4) & 0x10;
481  w8[i] = (lo | hb) - 16; /* Signed -16 to +15 */
482  }
483  __m128i q8 = _mm_loadl_epi64((const __m128i *)w8);
484  sum += dot_int8_float8_sse(q8, &xp[0], d);
485  }
486 
487  /* Group 1: weights 8-15 (low nibbles of qs[8-15], high bits from qh[8-15]) */
488  {
489  uint8_t w8[8];
490  for (int i = 0; i < 8; i++) {
491  int lo = block->qs[8 + i] & 0x0F;
492  int hb = ((qh >> (8 + i)) << 4) & 0x10;
493  w8[i] = (lo | hb) - 16;
494  }
495  __m128i q8 = _mm_loadl_epi64((const __m128i *)w8);
496  sum += dot_int8_float8_sse(q8, &xp[8], d);
497  }
498 
499  /* Group 2: weights 16-23 (high nibbles of qs[0-7], high bits from qh[12-19]) */
500  {
501  uint8_t w8[8];
502  for (int i = 0; i < 8; i++) {
503  int hi = block->qs[i] >> 4;
504  int hb = (qh >> (12 + i)) & 0x10;
505  w8[i] = (hi | hb) - 16;
506  }
507  __m128i q8 = _mm_loadl_epi64((const __m128i *)w8);
508  sum += dot_int8_float8_sse(q8, &xp[16], d);
509  }
510 
511  /* Group 3: weights 24-31 (high nibbles of qs[8-15], high bits from qh[20-27]) */
512  {
513  uint8_t w8[8];
514  for (int i = 0; i < 8; i++) {
515  int hi = block->qs[8 + i] >> 4;
516  int hb = (qh >> (20 + i)) & 0x10;
517  w8[i] = (hi | hb) - 16;
518  }
519  __m128i q8 = _mm_loadl_epi64((const __m128i *)w8);
520  sum += dot_int8_float8_sse(q8, &xp[24], d);
521  }
522  }
523 
524  y[row] = sum;
525  }
526 }
527 #endif /* __AVX__ && !__AVX512F__ */
528 
529 /**
530  * @brief Auto-dispatch GEMV for Q5_0 weights based on CPU features
531  *
532  * Dispatch priority (best available):
533  * 1. AVX-512 (512-bit vectors) - Intel Skylake-X+
534  * 2. AVX2+FMA (256-bit vectors) - Intel Haswell+
535  * 3. AVX (256-bit vectors) - Intel Sandy Bridge+
536  * 4. SSE4.1 (128-bit vectors) - Intel Nehalem+
537  * 5. Reference (scalar) - Fallback
538  *
539  * Uses ck_features.h for standardized feature detection.
540  *
541  * @param y Output vector [M]
542  * @param W Weight matrix in Q5_0 format [M x K]
543  * @param x Input vector [K]
544  * @param M Number of output rows
545  * @param K Number of input columns (hidden dimension)
546  */
void gemv_q5_0(float *y,
               const void *W,
               const float *x,
               int M, int K)
{
    /* Fix: gemv_q5_0_sse_v2 is defined in a sibling translation unit but
     * had no visible prototype in this file (only gemm_nt_q5_0_sse_v2 is
     * declared at the top), so the __SSE4_1__ branch relied on an implicit
     * declaration — a constraint violation since C99 and a hard error in
     * C23. Declare it locally with the shared GEMV signature. */
    extern void gemv_q5_0_sse_v2(float *y, const void *W, const float *x,
                                 int M, int K);

/* Dispatch order: AVX512 > AVX2 > AVX > SSE > ref */
#if defined(__AVX512F__)
    gemv_q5_0_avx512(y, W, x, M, K);
#elif defined(__AVX2__)
    gemv_q5_0_avx2(y, W, x, M, K);
#elif defined(__AVX__)
    gemv_q5_0_avx(y, W, x, M, K);
#elif defined(__SSE4_1__)
    gemv_q5_0_sse_v2(y, W, x, M, K);
#else
    gemv_q5_0_ref(y, W, x, M, K);
#endif
}
565 
566 /* ============================================================================
567  * PARALLEL VERSIONS (for parallel orchestration)
568  *
569  * These receive ith (thread index) and nth (total threads) from orchestration.
570  * OpenMP lives in orchestration layer, NOT here.
571  * ============================================================================ */
572 
573 /**
574  * @brief Parallel reference GEMV for Q5_0 × FP32
575  */
576 void gemv_q5_0_parallel(float *y,
577  const void *W,
578  const float *x,
579  int M, int K,
580  int ith, int nth)
581 {
582  if (!y || !W || !x || M <= 0 || K <= 0) return;
583  if (ith < 0 || nth <= 0 || ith >= nth) return;
584 
585  const int dr = (M + nth - 1) / nth;
586  const int r0 = dr * ith;
587  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
588 
589  if (r0 >= M) return;
590 
591  const block_q5_0 *blocks = (const block_q5_0 *)W;
592  const int blocks_per_row = K / QK5_0;
593 
594  for (int row = r0; row < r1; row++) {
595  float sum = 0.0f;
596  for (int b = 0; b < blocks_per_row; b++) {
597  const block_q5_0 *block = &blocks[row * blocks_per_row + b];
598  const float d = CK_FP16_TO_FP32(block->d);
599  const float *xp = &x[b * QK5_0];
600 
601  uint32_t qh;
602  memcpy(&qh, block->qh, sizeof(qh));
603 
604  for (int j = 0; j < QK5_0 / 2; j++) {
605  const uint8_t packed = block->qs[j];
606  const int lo = (packed & 0x0F);
607  const int hi = (packed >> 4);
608  const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
609  const int xh_1 = ((qh >> (j + 12))) & 0x10;
610  const int w0 = (lo | xh_0) - 16;
611  const int w1 = (hi | xh_1) - 16;
612  sum += d * (w0 * xp[j] + w1 * xp[j + QK5_0/2]);
613  }
614  }
615  y[row] = sum;
616  }
617 }
618 
619 /**
620  * @brief Parallel SIMD GEMV for Q5_0 × FP32 with prefetching
621  */
623  const void *W,
624  const float *x,
625  int M, int K,
626  int ith, int nth)
627 {
628  if (!y || !W || !x || M <= 0 || K <= 0) return;
629  if (ith < 0 || nth <= 0 || ith >= nth) return;
630 
631  const int dr = (M + nth - 1) / nth;
632  const int r0 = dr * ith;
633  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
634 
635  if (r0 >= M) return;
636 
637  const block_q5_0 *blocks = (const block_q5_0 *)W;
638  const int blocks_per_row = K / QK5_0;
639 
640 #if defined(__AVX__) || defined(__SSE4_1__)
641  /* Prefetch first few rows */
642  const int PREFETCH_ROWS = 4;
643  for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
644  const char *row_ptr = (const char *)(blocks + (r0 + p) * blocks_per_row);
645  _mm_prefetch(row_ptr, _MM_HINT_T0);
646  _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
647  }
648 
649  for (int row = r0; row < r1; ++row) {
650  /* Prefetch rows ahead */
651  if (row + PREFETCH_ROWS < r1) {
652  const char *prefetch_ptr = (const char *)(blocks + (row + PREFETCH_ROWS) * blocks_per_row);
653  _mm_prefetch(prefetch_ptr, _MM_HINT_T0);
654  _mm_prefetch(prefetch_ptr + 64, _MM_HINT_T0);
655  }
656 
657  /* Use SIMD dot product for this row */
658 #if defined(__AVX512F__)
659  /* Call single-row AVX512 implementation */
660  gemv_q5_0_avx512(&y[row], (const char *)blocks + row * blocks_per_row * sizeof(block_q5_0), x, 1, K);
661 #elif defined(__AVX2__)
662  gemv_q5_0_avx2(&y[row], (const char *)blocks + row * blocks_per_row * sizeof(block_q5_0), x, 1, K);
663 #elif defined(__AVX__)
664  gemv_q5_0_avx(&y[row], (const char *)blocks + row * blocks_per_row * sizeof(block_q5_0), x, 1, K);
665 #else
666  gemv_q5_0_sse_v2(&y[row], (const char *)blocks + row * blocks_per_row * sizeof(block_q5_0), x, 1, K);
667 #endif
668  }
669 #else
670  /* Fallback to reference parallel */
671  gemv_q5_0_parallel(y, W, x, M, K, ith, nth);
672 #endif
673 }
674 
675 /* ============================================================================
676  * Forward Pass: GEMM Y = W @ X
677  * ============================================================================ */
678 
/**
 * @brief Matrix-matrix multiply with Q5_0 weights
 *
 * Treats X as N input vectors of length K and produces N output vectors
 * of length M, delegating each one to the dispatched GEMV.
 *
 * @param Y Output [N x M], each slice Y[n*M..] is one GEMV result
 * @param W Weight matrix in Q5_0 format [M x K]
 * @param X Input [N x K]
 * @param M Output rows per GEMV
 * @param N Number of input/output vectors (batch)
 * @param K Input dimension (must be a multiple of QK5_0)
 */
void gemm_q5_0(float *Y,
               const void *W,
               const float *X,
               int M, int N, int K)
{
    for (int col = 0; col < N; col++) {
        float *y_out = Y + (size_t)col * (size_t)M;
        const float *x_in = X + (size_t)col * (size_t)K;
        gemv_q5_0(y_out, W, x_in, M, K);
    }
}
691 
692 /* ============================================================================
693  * Backward Pass: Gradient w.r.t. Input
694  * ============================================================================ */
695 
696 /**
697  * @brief Backward pass: compute input gradient
698  *
699  * @param dX Output gradient w.r.t. input [K]
700  * @param W Weight matrix in Q5_0 format [M x K]
701  * @param dY Gradient w.r.t. output [M]
702  * @param M Number of output rows
703  * @param K Number of columns (input dimension)
704  */
705 void gemv_q5_0_backward_ref(float *dX,
706  const void *W,
707  const float *dY,
708  int M, int K)
709 {
710  const block_q5_0 *blocks = (const block_q5_0 *)W;
711  const int blocks_per_row = K / QK5_0;
712 
713  /* Zero output gradient */
714  memset(dX, 0, K * sizeof(float));
715 
716  /* Accumulate: dX += W^T @ dY */
717  for (int row = 0; row < M; row++) {
718  const float dy = dY[row];
719 
720  for (int b = 0; b < blocks_per_row; b++) {
721  const block_q5_0 *block = &blocks[row * blocks_per_row + b];
722  const float d = CK_FP16_TO_FP32(block->d);
723  float *dxp = &dX[b * QK5_0];
724 
725  /* Get high bits */
726  uint32_t qh;
727  memcpy(&qh, block->qh, sizeof(qh));
728 
729  /* llama.cpp Q5_0 layout - note j+12 for second weight high bit */
730  for (int j = 0; j < QK5_0 / 2; j++) {
731  const uint8_t packed = block->qs[j];
732 
733  /* Extract and reconstruct 5-bit values */
734  const int lo = (packed & 0x0F);
735  const int hi = (packed >> 4);
736  const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
737  const int xh_1 = ((qh >> (j + 12))) & 0x10;
738  const int q0 = (lo | xh_0) - 16;
739  const int q1 = (hi | xh_1) - 16;
740 
741  dxp[j] += d * (float)q0 * dy;
742  dxp[j + 16] += d * (float)q1 * dy;
743  }
744  }
745  }
746 }
747 
/**
 * @brief Auto-dispatch backward
 *
 * Currently forwards straight to the scalar reference implementation —
 * no SIMD backward kernel exists yet in this file. Kept as a separate
 * entry point so callers are unaffected when one is added.
 *
 * @param dX Output gradient w.r.t. input [K]
 * @param W  Weight matrix in Q5_0 format [M x K]
 * @param dY Gradient w.r.t. output [M]
 * @param M  Number of output rows
 * @param K  Number of input columns
 */
void gemv_q5_0_backward(float *dX,
 const void *W,
 const float *dY,
 int M, int K)
{
 gemv_q5_0_backward_ref(dX, W, dY, M, K);
}
758 
/**
 * @brief Batched backward pass
 *
 * Runs one input-gradient GEMV per batch entry: slice n of dX [K] is
 * computed from slice n of dY [M].
 *
 * @param dX Output gradients [N x K]
 * @param W Weight matrix in Q5_0 format [M x K]
 * @param dY Output-side gradients [N x M]
 * @param M Number of output rows
 * @param N Batch size
 * @param K Input dimension
 */
void gemm_q5_0_backward(float *dX,
                        const void *W,
                        const float *dY,
                        int M, int N, int K)
{
    for (int item = 0; item < N; item++) {
        float *dx_slice = dX + (size_t)item * (size_t)K;
        const float *dy_slice = dY + (size_t)item * (size_t)M;
        gemv_q5_0_backward(dx_slice, W, dy_slice, M, K);
    }
}
771 
772 /* ============================================================================
773  * GEMM NT (Non-Transpose A, Transpose B) - C = A @ B^T
774  * For inference: A is activations [M x K], B is weights [N x K]
775  * ============================================================================ */
776 
777 /**
778  * @brief GEMM with transposed Q5_0 weights: C = A @ B^T
779  *
780  * @param A Input activations [M x K], row-major FP32
781  * @param B Weight matrix in Q5_0 format [N x K], row-major quantized
782  * @param bias Optional bias [N], NULL if not used
783  * @param C Output [M x N], row-major FP32
784  * @param M Batch size (number of tokens)
785  * @param N Output dimension (number of rows in B)
786  * @param K Input dimension
787  */
788 void gemm_nt_q5_0_ref(const float *A,
789  const void *B,
790  const float *bias,
791  float *C,
792  int M, int N, int K)
793 {
794  const block_q5_0 *blocks = (const block_q5_0 *)B;
795  const int blocks_per_row = K / QK5_0;
796 
797  for (int m = 0; m < M; m++) {
798  const float *a_row = &A[m * K];
799 
800  for (int n = 0; n < N; n++) {
801  float sum = 0.0f;
802 
803  for (int b = 0; b < blocks_per_row; b++) {
804  const block_q5_0 *block = &blocks[n * blocks_per_row + b];
805  const float d = CK_FP16_TO_FP32(block->d);
806  const float *ap = &a_row[b * QK5_0];
807 
808  uint32_t qh;
809  memcpy(&qh, block->qh, sizeof(qh));
810 
811  /* llama.cpp Q5_0 layout - note j+12 for second weight high bit */
812  for (int j = 0; j < QK5_0 / 2; j++) {
813  const uint8_t packed = block->qs[j];
814  const int lo = (packed & 0x0F);
815  const int hi = (packed >> 4);
816  const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
817  const int xh_1 = ((qh >> (j + 12))) & 0x10;
818  const int q0 = (lo | xh_0) - 16;
819  const int q1 = (hi | xh_1) - 16;
820 
821  sum += d * (float)q0 * ap[j];
822  sum += d * (float)q1 * ap[j + 16];
823  }
824  }
825 
826  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
827  }
828  }
829 }
830 
/**
 * @brief Dispatching C = A @ B^T with Q5_0 weights.
 *
 * Decode (M == 1) runs a single GEMV; prefill (M > 1) reuses gemm_q5_0,
 * whose [batch x M_out] output with batch = M tokens and M_out = N
 * channels is exactly the row-major [M x N] layout required for C.
 * Bias, when present, is added after the matmul.
 */
void gemm_nt_q5_0(const float *A,
                  const void *B,
                  const float *bias,
                  float *C,
                  int M, int N, int K)
{
    if (M == 1) {
        /* Single-token decode: one GEMV writes the output row directly */
        gemv_q5_0(C, B, A, N, K);
        if (bias) {
            for (int n = 0; n < N; n++) {
                C[n] += bias[n];
            }
        }
        return;
    }

    /* Prefill: batch of M tokens through the dispatched GEMM */
    gemm_q5_0(C, B, A, /*M_out=*/N, /*N_batch=*/M, K);

    if (bias) {
        for (int m = 0; m < M; m++) {
            float *out_row = C + (size_t)m * (size_t)N;
            for (int n = 0; n < N; n++) {
                out_row[n] += bias[n];
            }
        }
    }
}
865 
866 /* ============================================================================
867  * Dot Product Utility
868  * ============================================================================ */
869 
/**
 * @brief Dot product of one Q5_0 row with an FP32 vector.
 *
 * Thin wrapper over the dispatched GEMV with a single output row.
 */
float dot_q5_0(const void *w_q5_0, const float *x, int K)
{
    float out;
    gemv_q5_0(&out, w_q5_0, x, /*M=*/1, K);
    return out;
}
876 
877 /* ============================================================================
878  * Quantized Dot Product: Q5_0 x Q8_0
879  *
880  * This matches llama.cpp's ggml_vec_dot_q5_0_q8_0 exactly.
881  * Input is pre-quantized to Q8_0 format, enabling integer dot products.
882  * Result: sum_blocks( (d_w * d_x) * sum_weights( w5 * x8 ) )
883  *
884  * Key difference from gemv_q5_0:
885  * - gemv_q5_0: Takes FP32 input, dequantizes weights to FP32, FP32 dot
886  * - vec_dot_q5_0_q8_0: Takes Q8_0 input, does integer dot, scales at end
887  *
888  * The quantized path is faster and matches llama.cpp for parity testing.
889  * ============================================================================ */
890 
891 /**
892  * @brief Quantized dot product: Q5_0 weights x Q8_0 input (scalar reference)
893  *
894  * @param n Number of elements (must be multiple of 32)
895  * @param s Output: scalar dot product result
896  * @param vx Q5_0 quantized weights
897  * @param vy Q8_0 quantized input
898  */
899 void vec_dot_q5_0_q8_0_ref(int n, float *s, const void *vx, const void *vy)
900 {
901  const int qk = QK5_0; /* 32 */
902  const int nb = n / qk;
903 
904  const block_q5_0 *x = (const block_q5_0 *)vx;
905  const block_q8_0 *y = (const block_q8_0 *)vy;
906 
907  float sumf = 0.0f;
908 
909  for (int ib = 0; ib < nb; ib++) {
910  /* Load high bits for this block */
911  uint32_t qh;
912  memcpy(&qh, x[ib].qh, sizeof(qh));
913 
914  int sumi0 = 0;
915  int sumi1 = 0;
916 
917  for (int j = 0; j < qk / 2; j++) {
918  /* Extract high bits - matches llama.cpp exactly */
919  const uint8_t xh_0 = ((qh & (1u << (j + 0))) >> (j + 0)) << 4;
920  const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
921 
922  /* Reconstruct 5-bit signed values (-16 to +15) */
923  const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
924  const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
925 
926  /* Integer dot product with Q8_0 values */
927  sumi0 += x0 * y[ib].qs[j];
928  sumi1 += x1 * y[ib].qs[j + qk / 2];
929  }
930 
931  int sumi = sumi0 + sumi1;
932  sumf += (CK_FP16_TO_FP32(x[ib].d) * CK_FP16_TO_FP32(y[ib].d)) * sumi;
933  }
934 
935  *s = sumf;
936 }
937 
938 #ifdef __AVX512F__
/**
 * @brief Quantized dot product Q5_0 x Q8_0 (AVX-512)
 *
 * Strategy: widen each half-block of 16 weights to 32-bit lanes, rebuild
 * the 5-bit values (low nibble OR high bit << 4, minus 16), multiply by
 * sign-extended Q8_0 lanes, and reduce all 32 products with
 * _mm512_reduce_add_epi32. The 32 high bits are expanded with scalar
 * shifts inside _mm512_set_epi32: bits 0-15 pair with the low nibbles,
 * bits 16-31 with the high nibbles, mirroring the scalar reference.
 *
 * @param n  Number of elements (must be a multiple of 32)
 * @param s  Output: scalar dot product result
 * @param vx Q5_0 quantized weights
 * @param vy Q8_0 quantized input
 */
void vec_dot_q5_0_q8_0_avx512(int n, float *s, const void *vx, const void *vy)
{
    const int qk = QK5_0; /* 32 weights per block */
    const int nb = n / qk;

    const block_q5_0 *x = (const block_q5_0 *)vx;
    const block_q8_0 *y = (const block_q8_0 *)vy;

    float sumf = 0.0f;

    for (int ib = 0; ib < nb; ib++) {
        /* Combined per-block scale: weight scale * activation scale */
        const float d = CK_FP16_TO_FP32(x[ib].d) * CK_FP16_TO_FP32(y[ib].d);

        /* Load the 32 packed high bits */
        uint32_t qh;
        memcpy(&qh, x[ib].qh, sizeof(qh));

        /* Load 16 packed bytes (32 nibbles) */
        __m128i qs = _mm_loadu_si128((const __m128i *)x[ib].qs);

        /* Process first 16 weights (low nibbles, high bits 0-15) */
        __m512i lo_nibbles = _mm512_cvtepu8_epi32(qs);
        lo_nibbles = _mm512_and_epi32(lo_nibbles, _mm512_set1_epi32(0x0F));

        /* Build high bit contribution for first 16 weights.
         * _mm512_set_epi32 takes its arguments highest-lane first, so the
         * list runs bit 15 down to bit 0. */
        __m512i qh_lo = _mm512_set_epi32(
            ((qh >> 15) & 1) << 4, ((qh >> 14) & 1) << 4,
            ((qh >> 13) & 1) << 4, ((qh >> 12) & 1) << 4,
            ((qh >> 11) & 1) << 4, ((qh >> 10) & 1) << 4,
            ((qh >> 9) & 1) << 4, ((qh >> 8) & 1) << 4,
            ((qh >> 7) & 1) << 4, ((qh >> 6) & 1) << 4,
            ((qh >> 5) & 1) << 4, ((qh >> 4) & 1) << 4,
            ((qh >> 3) & 1) << 4, ((qh >> 2) & 1) << 4,
            ((qh >> 1) & 1) << 4, ((qh >> 0) & 1) << 4
        );

        /* Combine and subtract 16 to get signed values in -16..+15 */
        __m512i q5_lo = _mm512_sub_epi32(_mm512_or_epi32(lo_nibbles, qh_lo),
                                         _mm512_set1_epi32(16));

        /* Load and sign-extend Q8_0 values for first 16 */
        __m128i y8_lo = _mm_loadu_si128((const __m128i *)&y[ib].qs[0]);
        __m512i y32_lo = _mm512_cvtepi8_epi32(y8_lo);

        /* Integer multiply and accumulate */
        __m512i prod_lo = _mm512_mullo_epi32(q5_lo, y32_lo);

        /* Process second 16 weights (high nibbles, high bits 16-31) */
        __m512i hi_nibbles = _mm512_cvtepu8_epi32(qs);
        hi_nibbles = _mm512_srli_epi32(hi_nibbles, 4);

        /* Build high bit contribution for second 16 weights (bits 16-31) */
        __m512i qh_hi = _mm512_set_epi32(
            ((qh >> 31) & 1) << 4, ((qh >> 30) & 1) << 4,
            ((qh >> 29) & 1) << 4, ((qh >> 28) & 1) << 4,
            ((qh >> 27) & 1) << 4, ((qh >> 26) & 1) << 4,
            ((qh >> 25) & 1) << 4, ((qh >> 24) & 1) << 4,
            ((qh >> 23) & 1) << 4, ((qh >> 22) & 1) << 4,
            ((qh >> 21) & 1) << 4, ((qh >> 20) & 1) << 4,
            ((qh >> 19) & 1) << 4, ((qh >> 18) & 1) << 4,
            ((qh >> 17) & 1) << 4, ((qh >> 16) & 1) << 4
        );

        __m512i q5_hi = _mm512_sub_epi32(_mm512_or_epi32(hi_nibbles, qh_hi),
                                         _mm512_set1_epi32(16));

        /* Load and sign-extend Q8_0 values for second 16 */
        __m128i y8_hi = _mm_loadu_si128((const __m128i *)&y[ib].qs[16]);
        __m512i y32_hi = _mm512_cvtepi8_epi32(y8_hi);

        __m512i prod_hi = _mm512_mullo_epi32(q5_hi, y32_hi);

        /* Horizontal-reduce all 32 products to a single int */
        int sumi = _mm512_reduce_add_epi32(_mm512_add_epi32(prod_lo, prod_hi));

        /* Scale and accumulate - use scalar to avoid 16x broadcast bug */
        sumf += d * (float)sumi;
    }

    *s = sumf;
}
1023 #endif
1024 
1025 #if defined(__SSSE3__)
/**
 * @brief Spread 32 bits to 32 bytes { 0x00, 0xFF }
 * Adapted from llama.cpp bytes_from_bits_32 (AVX path)
 *
 * Step 1: _mm_shuffle_epi8 replicates source byte j/8 into byte lane j,
 *         so each lane holds the source byte that contains its bit.
 * Step 2: OR with bit_mask. In little-endian byte order the constant
 *         0x7fbfdfeff7fbfdfe is { fe fd fb f7 ef df bf 7f }: byte k
 *         clears only bit k%8, so after the OR a lane is 0xFF exactly
 *         when its selected bit was set in the source.
 * Step 3: compare against all-ones -> 0xFF where the bit was set, 0x00
 *         where it was not.
 *
 * @param out_lo Receives masks for bits 0-15
 * @param out_hi Receives masks for bits 16-31
 * @param qh     Pointer to the 4 packed high-bit bytes of a Q5_0 block
 */
static inline void bytes_from_bits_32_sse(__m128i *out_lo, __m128i *out_hi, const uint8_t *qh)
{
    uint32_t x32;
    memcpy(&x32, qh, sizeof(uint32_t));

    /* Shuffle masks: replicate byte j/8 of x32 to each position */
    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101LL, 0x0000000000000000LL);
    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303LL, 0x0202020202020202LL);

    __m128i bytes_lo = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
    __m128i bytes_hi = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);

    /* Byte k of this constant clears only bit k%8 (see header comment);
     * OR-ing saturates every other bit so only the probed bit matters. */
    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfeLL);

    bytes_lo = _mm_or_si128(bytes_lo, bit_mask);
    bytes_hi = _mm_or_si128(bytes_hi, bit_mask);

    /* Compare with all 1s: 0xFF if bit was set, 0x00 if not */
    *out_lo = _mm_cmpeq_epi8(bytes_lo, _mm_set1_epi64x(-1LL));
    *out_hi = _mm_cmpeq_epi8(bytes_hi, _mm_set1_epi64x(-1LL));
}
1057 
/**
 * @brief Multiply signed int8 vectors and sum adjacent pairs to int32
 * Adapted from llama.cpp mul_sum_i8_pairs
 *
 * _mm_maddubs_epi16 requires an unsigned first operand, so the signed
 * product is rewritten as abs(x) * sign_apply(y, x) == x * y:
 *   - _mm_sign_epi8(x, x) yields |x|
 *   - _mm_sign_epi8(y, x) negates y where x < 0 and zeroes it where
 *     x == 0 (harmless: the matching |x| lane is 0 there too).
 * For the Q5_0/Q8_0 inputs used here |x| <= 16, so a pair sum is at most
 * 2 * 16 * 128 = 4096 and maddubs' int16 saturation can never trigger.
 */
static inline __m128i mul_sum_i8_pairs_sse(const __m128i x, const __m128i y)
{
    const __m128i ax = _mm_sign_epi8(x, x); /* abs(x) */
    const __m128i sy = _mm_sign_epi8(y, x); /* y * sign(x) */
    const __m128i dot = _mm_maddubs_epi16(ax, sy); /* unsigned*signed pairs -> int16 */
    return _mm_madd_epi16(dot, _mm_set1_epi16(1)); /* sum pairs -> int32 */
}
1071 
/**
 * @brief Vectorized dot product Q5_0 x Q8_0 using SSSE3
 *
 * Based on llama.cpp ggml_vec_dot_q5_0_q8_0 AVX implementation.
 * Key insight: use shuffle-based bit spreading and the sign trick.
 *
 * Weight byte reconstruction: nibble | (high_bit ? 0x00 : 0xF0)
 * - High bit SET:     byte = nibble (0-15, non-negative as signed int8)
 * - High bit NOT SET: byte = nibble | 0xF0 = nibble - 16 as signed int8
 * This is identical to the reference (nibble | highbit<<4) - 16 when the
 * result is interpreted as a signed 8-bit value.
 *
 * The sign trick (mul_sum_i8_pairs_sse) handles signed*signed
 * multiplication with the unsigned*signed maddubs instruction.
 */
void vec_dot_q5_0_q8_0_sse(int n, float *s, const void *vx, const void *vy)
{
    const int qk = QK5_0; /* 32 */
    const int nb = n / qk;

    const block_q5_0 *x = (const block_q5_0 *)vx;
    const block_q8_0 *y = (const block_q8_0 *)vy;

    float sumf = 0.0f;

    const __m128i mask_0f = _mm_set1_epi8(0x0F);
    const __m128i mask_f0 = _mm_set1_epi8((char)0xF0);

    for (int ib = 0; ib < nb; ib++) {
        /* Combined per-block scale */
        const float d = CK_FP16_TO_FP32(x[ib].d) * CK_FP16_TO_FP32(y[ib].d);

        /* Load 16 bytes of packed nibbles */
        __m128i qs = _mm_loadu_si128((const __m128i *)x[ib].qs);

        /* Extract nibbles: lo for indices 0-15, hi for indices 16-31 */
        __m128i bx_lo = _mm_and_si128(qs, mask_0f);
        __m128i bx_hi = _mm_and_si128(_mm_srli_epi16(qs, 4), mask_0f);

        /* Spread 32 high bits to 32 bytes (0xFF=set, 0x00=not set) */
        __m128i bxhi_lo, bxhi_hi;
        bytes_from_bits_32_sse(&bxhi_lo, &bxhi_hi, x[ib].qh);

        /* Apply encoding: (~bxhi) & 0xF0
         * When bit SET:     bxhi=0xFF, result = 0x00 (value stays 0-15)
         * When bit NOT SET: bxhi=0x00, result = 0xF0 (value becomes negative) */
        bxhi_lo = _mm_andnot_si128(bxhi_lo, mask_f0);
        bxhi_hi = _mm_andnot_si128(bxhi_hi, mask_f0);

        /* Combine: nibble | high_bit_contribution */
        bx_lo = _mm_or_si128(bx_lo, bxhi_lo);
        bx_hi = _mm_or_si128(bx_hi, bxhi_hi);

        /* Load Q8_0 values (32 signed int8) */
        __m128i by_lo = _mm_loadu_si128((const __m128i *)y[ib].qs);
        __m128i by_hi = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));

        /* Multiply using sign trick and sum to int32 */
        __m128i p_lo = mul_sum_i8_pairs_sse(bx_lo, by_lo);
        __m128i p_hi = mul_sum_i8_pairs_sse(bx_hi, by_hi);

        /* Sum the two halves */
        __m128i sum = _mm_add_epi32(p_lo, p_hi);

        /* Horizontal sum of 4 int32 values (avoiding hadd for better perf) */
        __m128i hi64 = _mm_unpackhi_epi64(sum, sum);
        __m128i sum64 = _mm_add_epi32(hi64, sum);
        __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
        int32_t sumi = _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));

        /* Scale and accumulate */
        sumf += d * (float)sumi;
    }

    *s = sumf;
}
1144 #endif
1145 
1146 #if defined(__AVX__) && !defined(__AVX512F__)
1147 
1148 /* Combine two __m128i into __m256i (AVX without AVX2) */
1149 #define MM256_SET_M128I(hi, lo) _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1)
1150 
/**
 * @brief Spread 32 bits to 32 bytes using AVX
 * Returns __m256i with 0xFF where bit was set, 0x00 where not
 *
 * Same shuffle/OR/compare technique as bytes_from_bits_32_sse, but the
 * two 16-byte halves are combined into one 256-bit register.
 *
 * @param qh Pointer to the 4 packed high-bit bytes of a Q5_0 block
 */
static inline __m256i bytes_from_bits_32_avx(const uint8_t *qh)
{
    uint32_t x32;
    memcpy(&x32, qh, sizeof(uint32_t));

    /* Replicate byte j/8 of x32 into byte lane j */
    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101LL, 0x0000000000000000LL);
    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303LL, 0x0202020202020202LL);

    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);

    /* Byte k clears only bit k%8; after OR, lane == 0xFF iff its bit was set */
    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfeLL);

    bytesl = _mm_or_si128(bytesl, bit_mask);
    bytesh = _mm_or_si128(bytesh, bit_mask);

    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1LL));
    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1LL));

    return MM256_SET_M128I(bytesh, bytesl);
}
1176 
1177 /**
1178  * @brief Unpack 32 4-bit nibbles to 32 bytes using AVX
1179  */
1180 static inline __m256i bytes_from_nibbles_32_avx(const uint8_t *qs)
1181 {
1182  __m128i tmpl = _mm_loadu_si128((const __m128i *)qs);
1183  __m128i tmph = _mm_srli_epi16(tmpl, 4);
1184  const __m128i lowMask = _mm_set1_epi8(0x0F);
1185  tmpl = _mm_and_si128(lowMask, tmpl);
1186  tmph = _mm_and_si128(lowMask, tmph);
1187  return MM256_SET_M128I(tmph, tmpl);
1188 }
1189 
/**
 * @brief Multiply signed int8 pairs and return as float vector (AVX)
 * Uses 128-bit ops internally (no AVX2 integer ops available) but
 * returns a 256-bit float result.
 *
 * Same abs/sign trick as mul_sum_i8_pairs_sse, applied independently to
 * the low and high 128-bit halves, then the eight int32 pair-sums are
 * converted to floats for accumulation in a __m256.
 */
static inline __m256 mul_sum_i8_pairs_float_avx(const __m256i x, const __m256i y)
{
    const __m128i xl = _mm256_castsi256_si128(x);
    const __m128i xh = _mm256_extractf128_si256(x, 1);
    const __m128i yl = _mm256_castsi256_si128(y);
    const __m128i yh = _mm256_extractf128_si256(y, 1);

    /* Get absolute values of x vectors */
    const __m128i axl = _mm_sign_epi8(xl, xl);
    const __m128i axh = _mm_sign_epi8(xh, xh);
    /* Sign the values of the y vectors */
    const __m128i syl = _mm_sign_epi8(yl, xl);
    const __m128i syh = _mm_sign_epi8(yh, xh);

    /* Perform multiplication and create 16-bit values */
    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
    const __m128i doth = _mm_maddubs_epi16(axh, syh);

    /* Sum pairs to int32 */
    const __m128i ones = _mm_set1_epi16(1);
    const __m128i summed_pairsl = _mm_madd_epi16(ones, dotl);
    const __m128i summed_pairsh = _mm_madd_epi16(ones, doth);

    /* Convert to float */
    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
    return _mm256_cvtepi32_ps(summed_pairs);
}
1221 
/**
 * @brief Horizontal sum of 8 floats (AVX)
 *
 * Folds the 256-bit register in half three times:
 * 8 -> 4 lanes, 4 -> 2 lanes, 2 -> 1 lane.
 */
static inline float hsum_float_8_avx(const __m256 x)
{
    const __m128 lo_half = _mm256_castps256_ps128(x);
    const __m128 hi_half = _mm256_extractf128_ps(x, 1);

    __m128 acc = _mm_add_ps(lo_half, hi_half);          /* 8 -> 4 */
    acc = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));      /* 4 -> 2 */
    acc = _mm_add_ss(acc, _mm_movehdup_ps(acc));         /* 2 -> 1 */

    return _mm_cvtss_f32(acc);
}
1233 
/**
 * @brief Quantized dot product Q5_0 x Q8_0 (AVX) - Optimized with 2x unroll
 *
 * Based on llama.cpp ggml_vec_dot_q5_0_q8_0 AVX implementation.
 * Uses 256-bit accumulation and processes 32 values per block.
 *
 * Optimizations:
 * - 2x loop unrolling to reduce loop overhead
 * - Prefetching next blocks to hide memory latency
 * - Interleaved operations for better instruction-level parallelism
 * - Two independent accumulators (acc0/acc1) to break the FP add
 *   dependency chain; they are merged once at the end
 */
void vec_dot_q5_0_q8_0_avx(int n, float *s, const void *vx, const void *vy)
{
    const int qk = QK5_0; /* 32 */
    const int nb = n / qk;

    const block_q5_0 *x = (const block_q5_0 *)vx;
    const block_q8_0 *y = (const block_q8_0 *)vy;

    __m256 acc0 = _mm256_setzero_ps();
    __m256 acc1 = _mm256_setzero_ps();
    /* 0xF0 pattern used to turn "high bit clear" into "nibble - 16" */
    const __m128i mask = _mm_set1_epi8((char)0xF0);

    /* Process 2 blocks per iteration */
    int ib = 0;
    for (; ib + 1 < nb; ib += 2) {
        /* Prefetch next blocks (2 cache lines ahead) */
        _mm_prefetch((const char *)&x[ib + 4], _MM_HINT_T0);
        _mm_prefetch((const char *)&y[ib + 4], _MM_HINT_T0);

        /* === Block 0 === */
        const __m256 d0 = _mm256_set1_ps(CK_FP16_TO_FP32(x[ib].d) * CK_FP16_TO_FP32(y[ib].d));

        /* Unpack nibbles to 32 bytes */
        __m256i bx0 = bytes_from_nibbles_32_avx(x[ib].qs);

        /* Spread high bits */
        const __m256i bxhi0 = bytes_from_bits_32_avx(x[ib].qh);
        __m128i bxhil0 = _mm256_castsi256_si128(bxhi0);
        __m128i bxhih0 = _mm256_extractf128_si256(bxhi0, 1);

        /* === Block 1 === (start while block 0 is in flight) */
        const __m256 d1 = _mm256_set1_ps(CK_FP16_TO_FP32(x[ib+1].d) * CK_FP16_TO_FP32(y[ib+1].d));

        __m256i bx1 = bytes_from_nibbles_32_avx(x[ib+1].qs);
        const __m256i bxhi1 = bytes_from_bits_32_avx(x[ib+1].qh);
        __m128i bxhil1 = _mm256_castsi256_si128(bxhi1);
        __m128i bxhih1 = _mm256_extractf128_si256(bxhi1, 1);

        /* === Finish Block 0 === */
        /* (~bxhi) & 0xF0: bit clear -> 0xF0 (value goes negative as int8) */
        bxhil0 = _mm_andnot_si128(bxhil0, mask);
        bxhih0 = _mm_andnot_si128(bxhih0, mask);

        __m128i bxl0 = _mm256_castsi256_si128(bx0);
        __m128i bxh0 = _mm256_extractf128_si256(bx0, 1);
        bxl0 = _mm_or_si128(bxl0, bxhil0);
        bxh0 = _mm_or_si128(bxh0, bxhih0);
        bx0 = MM256_SET_M128I(bxh0, bxl0);

        const __m256i by0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
        const __m256 q0 = mul_sum_i8_pairs_float_avx(bx0, by0);
        acc0 = _mm256_add_ps(_mm256_mul_ps(d0, q0), acc0);

        /* === Finish Block 1 === */
        bxhil1 = _mm_andnot_si128(bxhil1, mask);
        bxhih1 = _mm_andnot_si128(bxhih1, mask);

        __m128i bxl1 = _mm256_castsi256_si128(bx1);
        __m128i bxh1 = _mm256_extractf128_si256(bx1, 1);
        bxl1 = _mm_or_si128(bxl1, bxhil1);
        bxh1 = _mm_or_si128(bxh1, bxhih1);
        bx1 = MM256_SET_M128I(bxh1, bxl1);

        const __m256i by1 = _mm256_loadu_si256((const __m256i *)y[ib+1].qs);
        const __m256 q1 = mul_sum_i8_pairs_float_avx(bx1, by1);
        acc1 = _mm256_add_ps(_mm256_mul_ps(d1, q1), acc1);
    }

    /* Handle remaining block if nb is odd */
    for (; ib < nb; ib++) {
        const __m256 d = _mm256_set1_ps(CK_FP16_TO_FP32(x[ib].d) * CK_FP16_TO_FP32(y[ib].d));

        __m256i bx_0 = bytes_from_nibbles_32_avx(x[ib].qs);
        const __m256i bxhi = bytes_from_bits_32_avx(x[ib].qh);
        __m128i bxhil = _mm256_castsi256_si128(bxhi);
        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);

        bxhil = _mm_andnot_si128(bxhil, mask);
        bxhih = _mm_andnot_si128(bxhih, mask);

        __m128i bxl = _mm256_castsi256_si128(bx_0);
        __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
        bxl = _mm_or_si128(bxl, bxhil);
        bxh = _mm_or_si128(bxh, bxhih);
        bx_0 = MM256_SET_M128I(bxh, bxl);

        const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
        const __m256 q = mul_sum_i8_pairs_float_avx(bx_0, by_0);
        acc0 = _mm256_add_ps(_mm256_mul_ps(d, q), acc0);
    }

    /* Combine accumulators and sum */
    acc0 = _mm256_add_ps(acc0, acc1);
    *s = hsum_float_8_avx(acc0);
}
1339 
1340 /* ============================================================================
1341  * Unrolled GEMM: C[M,N] = A_q8[M,K] @ B_q5[N,K]^T + bias
1342  *
1343  * Key optimizations over the naive vec_dot-per-element approach:
1344  * 1. N-dimension 2x unroll: loads activation block ONCE, reuses for 2 weight rows
1345  * 2. Inlined Q5_0 unpack: eliminates function call overhead per block
1346  * 3. Block-level prefetching: hides DRAM latency for weight streaming
1347  * 4. F16C scale conversion: uses hardware vcvtsh2ss via CK_FP16_TO_FP32 macro
1348  * 5. Persistent __m256 accumulators: no per-block horizontal sum
1349  *
1350  * Expected speedup: 1.5-2x over naive dispatch loop on AVX (Ivy Bridge+)
1351  * ============================================================================ */
1352 
1353 /**
1354  * @brief Batch GEMM with Q5_0 weights x Q8_0 activations — AVX unrolled
1355  *
1356  * @param A_q8 Input activations in Q8_0 format [M rows of K/32 blocks each]
1357  * @param B_q5 Weights in Q5_0 format [N rows of K/32 blocks each]
1358  * @param bias Optional bias vector [N], NULL if not used
1359  * @param C Output matrix [M x N], row-major FP32
1360  * @param M Batch size (number of tokens)
1361  * @param N Output dimension (number of output features)
1362  * @param K Input dimension (must be multiple of 32)
1363  */
1365  const void *A_q8,
1366  const void *B_q5,
1367  const float *bias,
1368  float *C,
1369  int M, int N, int K)
1370 {
1371  const int nb = K / QK5_0;
1372  const block_q8_0 *a_blocks = (const block_q8_0 *)A_q8;
1373  const block_q5_0 *b_blocks = (const block_q5_0 *)B_q5;
1374  const __m128i mask = _mm_set1_epi8((char)0xF0);
1375 
1376  for (int m = 0; m < M; m++) {
1377  const block_q8_0 *a_row = a_blocks + (size_t)m * nb;
1378 
1379  /* ---- 2x N-unrolled path: process 2 weight rows per iteration ---- */
1380  int n = 0;
1381  for (; n + 1 < N; n += 2) {
1382  const block_q5_0 *w0 = b_blocks + (size_t)(n + 0) * nb;
1383  const block_q5_0 *w1 = b_blocks + (size_t)(n + 1) * nb;
1384 
1385  __m256 acc_n0 = _mm256_setzero_ps();
1386  __m256 acc_n1 = _mm256_setzero_ps();
1387 
1388  for (int ib = 0; ib < nb; ib++) {
1389  /* Prefetch next weight blocks (1 cache line = ~2.9 Q5_0 blocks) */
1390  if (ib + 2 < nb) {
1391  _mm_prefetch((const char *)&w0[ib + 2], _MM_HINT_T0);
1392  _mm_prefetch((const char *)&w1[ib + 2], _MM_HINT_T0);
1393  }
1394 
1395  /* Load activation Q8_0 block ONCE — reused for both weight rows */
1396  const __m256i by = _mm256_loadu_si256((const __m256i *)a_row[ib].qs);
1397  const float da = CK_FP16_TO_FP32(a_row[ib].d);
1398 
1399  /* === Weight row 0: inline Q5_0 unpack + dot product === */
1400  {
1401  const float d = da * CK_FP16_TO_FP32(w0[ib].d);
1402 
1403  /* Unpack low 4-bit nibbles → 32 bytes */
1404  __m256i bx = bytes_from_nibbles_32_avx(w0[ib].qs);
1405 
1406  /* Spread 32 high bits → 32 byte mask (0xFF or 0x00) */
1407  const __m256i bxhi = bytes_from_bits_32_avx(w0[ib].qh);
1408  __m128i bxhil = _mm256_castsi256_si128(bxhi);
1409  __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
1410 
1411  /* Combine: nibble |= (highbit ? 0x10 : 0x00) */
1412  bxhil = _mm_andnot_si128(bxhil, mask);
1413  bxhih = _mm_andnot_si128(bxhih, mask);
1414  __m128i bxl = _mm256_castsi256_si128(bx);
1415  __m128i bxh = _mm256_extractf128_si256(bx, 1);
1416  bxl = _mm_or_si128(bxl, bxhil);
1417  bxh = _mm_or_si128(bxh, bxhih);
1418  bx = MM256_SET_M128I(bxh, bxl);
1419 
1420  /* Signed int8 dot product → 8 float partial sums */
1421  const __m256 q = mul_sum_i8_pairs_float_avx(bx, by);
1422  acc_n0 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), q), acc_n0);
1423  }
1424 
1425  /* === Weight row 1: same activation data, different weights === */
1426  {
1427  const float d = da * CK_FP16_TO_FP32(w1[ib].d);
1428 
1429  __m256i bx = bytes_from_nibbles_32_avx(w1[ib].qs);
1430  const __m256i bxhi = bytes_from_bits_32_avx(w1[ib].qh);
1431  __m128i bxhil = _mm256_castsi256_si128(bxhi);
1432  __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
1433 
1434  bxhil = _mm_andnot_si128(bxhil, mask);
1435  bxhih = _mm_andnot_si128(bxhih, mask);
1436  __m128i bxl = _mm256_castsi256_si128(bx);
1437  __m128i bxh = _mm256_extractf128_si256(bx, 1);
1438  bxl = _mm_or_si128(bxl, bxhil);
1439  bxh = _mm_or_si128(bxh, bxhih);
1440  bx = MM256_SET_M128I(bxh, bxl);
1441 
1442  const __m256 q = mul_sum_i8_pairs_float_avx(bx, by);
1443  acc_n1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), q), acc_n1);
1444  }
1445  }
1446 
1447  /* Horizontal reduce and store with optional bias */
1448  float s0 = hsum_float_8_avx(acc_n0);
1449  float s1 = hsum_float_8_avx(acc_n1);
1450  if (bias) { s0 += bias[n]; s1 += bias[n + 1]; }
1451  C[(size_t)m * N + n] = s0;
1452  C[(size_t)m * N + n + 1] = s1;
1453  }
1454 
1455  /* ---- Cleanup: remaining odd column ---- */
1456  for (; n < N; n++) {
1457  const block_q5_0 *w = b_blocks + (size_t)n * nb;
1458  __m256 acc = _mm256_setzero_ps();
1459 
1460  for (int ib = 0; ib < nb; ib++) {
1461  const __m256i by = _mm256_loadu_si256((const __m256i *)a_row[ib].qs);
1462  const float d = CK_FP16_TO_FP32(a_row[ib].d) * CK_FP16_TO_FP32(w[ib].d);
1463 
1464  __m256i bx = bytes_from_nibbles_32_avx(w[ib].qs);
1465  const __m256i bxhi = bytes_from_bits_32_avx(w[ib].qh);
1466  __m128i bxhil = _mm256_castsi256_si128(bxhi);
1467  __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
1468  bxhil = _mm_andnot_si128(bxhil, mask);
1469  bxhih = _mm_andnot_si128(bxhih, mask);
1470  __m128i bxl = _mm256_castsi256_si128(bx);
1471  __m128i bxh = _mm256_extractf128_si256(bx, 1);
1472  bxl = _mm_or_si128(bxl, bxhil);
1473  bxh = _mm_or_si128(bxh, bxhih);
1474  bx = MM256_SET_M128I(bxh, bxl);
1475 
1476  const __m256 q = mul_sum_i8_pairs_float_avx(bx, by);
1477  acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), q), acc);
1478  }
1479 
1480  float s = hsum_float_8_avx(acc);
1481  if (bias) s += bias[n];
1482  C[(size_t)m * N + n] = s;
1483  }
1484  }
1485 }
1486 
1487 #endif
1488 
/**
 * @brief Auto-dispatch quantized dot product Q5_0 x Q8_0
 *
 * Selection happens at COMPILE time from the build's ISA macros
 * (no runtime CPUID check here). Dispatch priority:
 * 1. AVX512 (best performance on modern Intel/AMD)
 * 2. AVX (256-bit float ops; the AVX kernel is compiled only when
 *    __AVX512F__ is absent, matching this #elif chain)
 * 3. SSSE3 (128-bit fallback)
 * 4. Reference scalar (last resort)
 */
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
{
#if defined(__AVX512F__)
    vec_dot_q5_0_q8_0_avx512(n, s, vx, vy);
#elif defined(__AVX__)
    /* AVX for 256-bit float ops (works on Ivy Bridge and newer) */
    vec_dot_q5_0_q8_0_avx(n, s, vx, vy);
#elif defined(__SSSE3__)
    /* SSSE3 - most efficient on older CPUs */
    vec_dot_q5_0_q8_0_sse(n, s, vx, vy);
#else
    vec_dot_q5_0_q8_0_ref(n, s, vx, vy);
#endif
}
1512 
1513 /* ============================================================================
1514  * Quantized GEMV: y = W @ x where W is Q5_0 and x is Q8_0
1515  *
1516  * This is the quantized equivalent of gemv_q5_0, but takes pre-quantized
1517  * input in Q8_0 format. Used for parity testing with llama.cpp.
1518  * ============================================================================ */
1519 
1520 /**
1521  * @brief Matrix-vector multiply with Q5_0 weights and Q8_0 input
1522  *
1523  * @param y Output vector [M]
1524  * @param W Weight matrix in Q5_0 format [M x K]
1525  * @param x_q8 Input vector in Q8_0 format [K]
1526  * @param M Number of output rows
1527  * @param K Number of columns (must be multiple of 32)
1528  */
1529 void gemv_q5_0_q8_0(float *y,
1530  const void *W,
1531  const void *x_q8,
1532  int M, int K)
1533 {
1534  const block_q5_0 *w_blocks = (const block_q5_0 *)W;
1535  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1536  const int blocks_per_row = K / QK5_0;
1537 
1538  for (int row = 0; row < M; row++) {
1539  vec_dot_q5_0_q8_0(K, &y[row],
1540  &w_blocks[row * blocks_per_row],
1541  x_blocks);
1542  }
1543 }
1544 
1545 /**
1546  * @brief Parallel SIMD GEMV for Q5_0 x Q8_0 with prefetching
1547  *
1548  * Each thread processes rows [r0, r1) where r0 = ith * ceil(M/nth).
1549  * Uses vec_dot_q5_0_q8_0 dispatch (auto-selects AVX512/AVX/SSE/scalar).
1550  */
1552  const void *W,
1553  const void *x_q8,
1554  int M, int K,
1555  int ith, int nth)
1556 {
1557  if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
1558  if (ith < 0 || nth <= 0 || ith >= nth) return;
1559 
1560  const int dr = (M + nth - 1) / nth;
1561  const int r0 = dr * ith;
1562  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1563 
1564  if (r0 >= M) return;
1565 
1566  const block_q5_0 *w_blocks = (const block_q5_0 *)W;
1567  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1568  const int blocks_per_row = K / QK5_0;
1569 
1570 #if defined(__AVX__) || defined(__SSE4_1__)
1571  const int PREFETCH_ROWS = 4;
1572  for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
1573  const char *row_ptr = (const char *)(w_blocks + (r0 + p) * blocks_per_row);
1574  _mm_prefetch(row_ptr, _MM_HINT_T0);
1575  _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
1576  }
1577 
1578  for (int row = r0; row < r1; ++row) {
1579  if (row + PREFETCH_ROWS < r1) {
1580  const char *pf = (const char *)(w_blocks + (row + PREFETCH_ROWS) * blocks_per_row);
1581  _mm_prefetch(pf, _MM_HINT_T0);
1582  _mm_prefetch(pf + 64, _MM_HINT_T0);
1583  }
1584 
1585  vec_dot_q5_0_q8_0(K, &y[row],
1586  &w_blocks[row * blocks_per_row],
1587  x_blocks);
1588  }
1589 #else
1590  for (int row = r0; row < r1; row++) {
1591  vec_dot_q5_0_q8_0(K, &y[row],
1592  &w_blocks[row * blocks_per_row],
1593  x_blocks);
1594  }
1595 #endif
1596 }
1597 
1598 /**
1599  * @brief Batch GEMM with Q5_0 weights and Q8_0 activations for prefill
1600  *
1601  * Computes C = A @ B^T + bias where:
1602  * A: [M x K] Q8_0 quantized activations (M tokens, K features)
1603  * B: [N x K] Q5_0 quantized weights (N outputs, K features)
1604  * C: [M x N] FP32 output
1605  *
1606  * This is the INT8 batch kernel for prefill, using pre-quantized activations
1607  * to avoid FP32->Q8_0 conversion overhead per operation.
1608  *
1609  * @param A_q8 Input activations in Q8_0 format [M rows of K/32 blocks each]
1610  * @param B_q5 Weights in Q5_0 format [N rows of K/32 blocks each]
1611  * @param bias Optional bias vector [N], NULL if not used
1612  * @param C Output matrix [M x N], row-major FP32
1613  * @param M Batch size (number of tokens)
1614  * @param N Output dimension (number of output features)
1615  * @param K Input dimension (must be multiple of 32)
1616  */
1618  const void *A_q8,
1619  const void *B_q5,
1620  const float *bias,
1621  float *C,
1622  int M,
1623  int N,
1624  int K)
1625 {
1626  const block_q5_0 *weights = (const block_q5_0 *)B_q5;
1627  const block_q8_0 *inputs = (const block_q8_0 *)A_q8;
1628  const int blocks_per_row = K / QK5_0;
1629 
1630  for (int m = 0; m < M; m++) {
1631  const block_q8_0 *input_row = &inputs[m * blocks_per_row];
1632 
1633  for (int n = 0; n < N; n++) {
1634  const block_q5_0 *weight_row = &weights[n * blocks_per_row];
1635  float *out = &C[m * N + n];
1636 
1637  /* Dispatches to vec_dot_q5_0_q8_0_avx (2x block unrolled) on AVX */
1638  vec_dot_q5_0_q8_0(K, out, weight_row, input_row);
1639 
1640  if (bias) {
1641  *out += bias[n];
1642  }
1643  }
1644  }
1645 }
CPU feature detection and dispatch macros.
Quantization block structures for weight-only quantization.
#define QK5_0
Definition: ckernel_quant.h:67
#define CK_FP16_TO_FP32(x)
void gemm_nt_q5_0_q8_0_unroll_avx(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q5_0_sse_v2(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_q5_0_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
void gemv_q5_0_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient.
void dequant_q5_0_block(const block_q5_0 *block, float *output)
Dequantize a single Q5_0 block to FP32.
void dequant_q5_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_0 row (multiple blocks)
void gemv_q5_0_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q5_0 weights (scalar reference)
void gemv_q5_0_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.
void gemv_q5_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q5_0 weights and Q8_0 input.
void gemv_q5_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q5_0 weights based on CPU features.
void vec_dot_q5_0_q8_0_ref(int n, float *s, const void *vx, const void *vy)
Quantized dot product: Q5_0 weights x Q8_0 input (scalar reference)
void gemm_nt_q5_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
void gemm_q5_0(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q5_0 weights.
void gemm_nt_q5_0_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_0 weights: C = A @ B^T.
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
void gemv_q5_0_parallel(float *y, const void *W, const float *x, int M, int K, int ith, int nth)
Parallel reference GEMV for Q5_0 × FP32.
float dot_q5_0(const void *w_q5_0, const float *x, int K)
void gemv_q5_0_q8_0_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q5_0 x Q8_0 with prefetching.
void gemv_q5_0_parallel_simd(float *y, const void *W, const float *x, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q5_0 × FP32 with prefetching.
#define C(color)
Definition: show_config.c:39
ck_half d
Definition: ckernel_quant.h:70
uint8_t qh[4]
Definition: ckernel_quant.h:71
uint8_t qs[32/2]
Definition: ckernel_quant.h:72
int8_t qs[32]
int32_t int32_t int32_t int32_t int32_t mask
Definition: tokenizer.h:233