← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q6k_q8k.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_q6k_q8k.c
3  * @brief Q6_K (weights) x Q8_K (activations) kernels for inference
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Implements decode-style matvec/matmul where weights are Q6_K and the
15  * activations are quantized on-the-fly to Q8_K. This is inference-only;
16  * no backward pass is provided here.
17  *
18  * Q6_K Format (256 weights per block):
19  * - d: FP16 super-block scale
20  * - ql: 128 bytes (low 4 bits of each weight)
21  * - qh: 64 bytes (high 2 bits of each weight)
22  * - scales: 16 int8 sub-block scales
23  *
24  * Q8_K Format (256 weights per block):
25  * - d: FP32 scale
26  * - qs: 256 int8 values
27  * - bsums: 16 int16 block sums
28  */
29 
30 #include <assert.h>
31 #include <math.h>
32 #include <string.h>
33 #include <stdint.h>
34 #include <stddef.h>
35 
36 #include "ckernel_quant.h"
37 
38 /* Include SIMD headers based on available extensions */
39 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__) || defined(__SSSE3__)
40 #include <immintrin.h>
41 #endif
42 
43 /* Forward declarations for SIMD implementations */
44 void gemv_q6_k_q8_k_avx512(float *y, const void *W, const void *x_q8, int M, int K);
45 void gemv_q6_k_q8_k_avx512_vbmi(float *y, const void *W, const void *x_q8, int M, int K);
46 void gemv_q6_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K);
47 void gemv_q6_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K);
48 void gemv_q6_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K);
49 
50 /* ============================================================================
51  * Reference Implementation
52  * ============================================================================ */
53 
54 /**
55  * @brief Scalar dot product for Q6_K x Q8_K
56  *
57  * Q6_K layout: 256 weights per block
58  * - ql[0..127]: low 4 bits for all 256 weights (packed 2 per byte)
59  * - qh[0..63]: high 2 bits for all 256 weights (packed 4 per byte)
60  * - scales[0..15]: int8 scale for each 16-weight sub-block
61  * - d: FP16 super-block scale
62  *
63  * The dequantization formula for each weight is:
64  * weight = d * scale[sub] * (q6_value - 32)
65  * where q6_value is the 6-bit unsigned value (0..63).
66  */
67 static float dot_q6_k_q8_k_ref(const block_q6_K *w,
68  const block_q8_K *x,
69  int K)
70 {
71  const int nb = K / QK_K;
72  float sumf = 0.0f;
73 
74  for (int i = 0; i < nb; ++i) {
75  const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;
76 
77  const uint8_t *ql = w[i].ql;
78  const uint8_t *qh = w[i].qh;
79  const int8_t *sc = w[i].scales;
80  const int8_t *q8 = x[i].qs;
81 
82  /* Process 256 weights in 2 iterations of 128 */
83  for (int n = 0; n < QK_K; n += 128) {
84  /* Each iteration processes 128 weights:
85  * - ql[0..63] contains low 4 bits
86  * - qh[0..31] contains high 2 bits
87  * - Interleaved pattern: weights 0-31, 32-63, 64-95, 96-127
88  */
89  for (int l = 0; l < 32; ++l) {
90  /* Sub-block index: each scale covers 16 weights */
91  const int is = l / 16;
92 
93  /* Extract 6-bit values from packed format */
94  /* q1: weights l+0 (low nibble of ql[l], bits 0-1 of qh[l]) */
95  const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
96  /* q2: weights l+32 (low nibble of ql[l+32], bits 2-3 of qh[l]) */
97  const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
98  /* q3: weights l+64 (high nibble of ql[l], bits 4-5 of qh[l]) */
99  const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
100  /* q4: weights l+96 (high nibble of ql[l+32], bits 6-7 of qh[l]) */
101  const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
102 
103  /* Accumulate: d * scale * q6 * q8 */
104  sumf += d * (float)sc[is + 0] * (float)q1 * (float)q8[l + 0];
105  sumf += d * (float)sc[is + 2] * (float)q2 * (float)q8[l + 32];
106  sumf += d * (float)sc[is + 4] * (float)q3 * (float)q8[l + 64];
107  sumf += d * (float)sc[is + 6] * (float)q4 * (float)q8[l + 96];
108  }
109  q8 += 128;
110  ql += 64;
111  qh += 32;
112  sc += 8;
113  }
114  }
115 
116  return sumf;
117 }
118 
119 void gemv_q6_k_q8_k_ref(float *y,
120  const void *W,
121  const void *x_q8,
122  int M, int K)
123 {
124  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
125  return;
126  }
127 
128  const block_q6_K *blocks = (const block_q6_K *)W;
129  const block_q8_K *x = (const block_q8_K *)x_q8;
130  const int blocks_per_row = K / QK_K;
131 
132  for (int row = 0; row < M; ++row) {
133  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
134  y[row] = dot_q6_k_q8_k_ref(w_row, x, K);
135  }
136 }
137 
138 /* ============================================================================
139  * SSE4.1 Implementation (for Ivy Bridge and older AVX-without-AVX2 CPUs)
140  *
141  * Uses 128-bit SSE operations with maddubs for integer multiply-add.
142  * Handles the -32 offset using bsums from Q8_K.
143  * ============================================================================ */
144 
145 #if defined(__SSSE3__)
146 
147 /* Scale shuffle indices for Q6_K - maps scale index to 16-byte shuffle pattern */
/* Scale shuffle indices for Q6_K: row `is` is a pshufb mask that broadcasts
 * scales[2*is] over the low 8 bytes and scales[2*is+1] over the high 8 bytes
 * of a 16-byte register. One scale then lines up with each 16-bit maddubs
 * product lane covering its 16-weight sub-block. Shared by the SSE path here
 * and the AVX fallback path further below. */
static const int8_t q6k_scale_shuffle[8][16] = {
    { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1 }, /* is=0: scales[0,1] */
    { 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3 }, /* is=1: scales[2,3] */
    { 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5 }, /* is=2: scales[4,5] */
    { 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7 }, /* is=3: scales[6,7] */
    { 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9 }, /* is=4: scales[8,9] */
    {10,10,10,10,10,10,10,10,11,11,11,11,11,11,11,11 }, /* is=5: scales[10,11] */
    {12,12,12,12,12,12,12,12,13,13,13,13,13,13,13,13 }, /* is=6: scales[12,13] */
    {14,14,14,14,14,14,14,14,15,15,15,15,15,15,15,15 }, /* is=7: scales[14,15] */
};
158 
/**
 * @brief SSE dot product for one row of Q6_K weights against Q8_K activations.
 *
 * Works in 128-bit registers (16 weights at a time). The 6-bit weights are
 * kept unsigned (0..63) so _mm_maddubs_epi16 (unsigned x signed) can be used;
 * the implicit -32 offset is folded in at the end via the precomputed Q8_K
 * per-sub-block sums:  sum((q6-32)*q8) = sum(q6*q8) - 32*sum(q8).
 *
 * NOTE(review): _mm_cvtepi8_epi16 is an SSE4.1 intrinsic, but this section is
 * guarded by __SSSE3__ only — a build targeting SSSE3 without SSE4.1 would
 * fail to compile. Consider guarding with __SSE4_1__; confirm build matrix.
 *
 * @param w  Row of K/QK_K Q6_K blocks
 * @param x  K/QK_K Q8_K activation blocks
 * @param K  Element count; must be a multiple of QK_K
 * @return   Dot product as float
 */
static float dot_q6_k_q8_k_sse(const block_q6_K *w,
                               const block_q8_K *x,
                               int K)
{
    const int nb = K / QK_K;
    const __m128i m3 = _mm_set1_epi8(3);    /* mask for 2-bit high fields */
    const __m128i m15 = _mm_set1_epi8(15);  /* mask for 4-bit low fields */

    __m128 acc = _mm_setzero_ps();

    for (int i = 0; i < nb; ++i) {
        /* Combined super-block scale */
        const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;

        const uint8_t *ql = w[i].ql;
        const uint8_t *qh = w[i].qh;
        const int8_t *q8 = x[i].qs;

        /* Load scales and precompute the -32 offset contribution using bsums */
        const __m128i scales = _mm_loadu_si128((const __m128i *)w[i].scales);
        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i *)x[i].bsums);
        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i *)x[i].bsums + 1);

        /* Compute: sum(scale * bsum) * 32 for the -32 offset (shift by 5 == *32) */
        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
        const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
        const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);

        __m128i sumi_0 = _mm_setzero_si128();
        __m128i sumi_1 = _mm_setzero_si128();

        int is = 0;

        /* Process 256 weights in 2 iterations of 128 */
        for (int j = 0; j < QK_K / 128; ++j) {
            /* Load high bits: 32 bytes cover the high 2 bits of 128 weights */
            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i *)qh);
            qh += 16;
            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i *)qh);
            qh += 16;

            /* Extract each 2-bit field and shift into bit positions 4..5 */
            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
            /* -64 == 0xC0: top 2-bit field, shifted down into bits 4..5 */
            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);

            /* Load low bits: 64 bytes hold the low nibbles of 128 weights */
            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i *)ql);
            ql += 16;
            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i *)ql);
            ql += 16;
            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i *)ql);
            ql += 16;
            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i *)ql);
            ql += 16;

            /* Combine low and high bits to get 6-bit values (unsigned 0..63) */
            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);

            /* Load Q8_K values (128 int8 activations for this half-block) */
            const __m128i q8_0 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_1 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_2 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_3 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_4 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_5 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_6 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_7 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;

            /* Multiply: maddubs treats first arg as unsigned, second as signed.
             * Max pair value 2*63*127 fits in int16, so no saturation occurs. */
            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);

            /* Broadcast the per-sub-block scales for this iteration */
            const __m128i scale_0 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 0]));
            const __m128i scale_1 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 1]));
            const __m128i scale_2 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 2]));
            const __m128i scale_3 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 3]));
            is += 4;

            /* Scale the products and widen to 32-bit */
            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);

            /* Accumulate into two running 32-bit sums */
            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
        }

        /* Subtract the -32 offset contribution */
        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);

        /* Combine and convert to float */
        __m128i sumi = _mm_add_epi32(sumi_0, sumi_1);
        __m128 sumf_vec = _mm_mul_ps(_mm_set1_ps(d), _mm_cvtepi32_ps(sumi));

        /* Horizontal sum */
        sumf_vec = _mm_hadd_ps(sumf_vec, sumf_vec);
        sumf_vec = _mm_hadd_ps(sumf_vec, sumf_vec);
        acc = _mm_add_ss(acc, sumf_vec);
    }

    return _mm_cvtss_f32(acc);
}
298 
299 void gemv_q6_k_q8_k_sse(float *y,
300  const void *W,
301  const void *x_q8,
302  int M, int K)
303 {
304  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
305  return;
306  }
307 
308  const block_q6_K *blocks = (const block_q6_K *)W;
309  const block_q8_K *x = (const block_q8_K *)x_q8;
310  const int blocks_per_row = K / QK_K;
311 
312  for (int row = 0; row < M; ++row) {
313  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
314  y[row] = dot_q6_k_q8_k_sse(w_row, x, K);
315  }
316 }
317 #endif /* __SSSE3__ */
318 
319 /* ============================================================================
320  * AVX Implementation (for Sandy/Ivy Bridge - AVX without AVX2)
321  *
322  * Same as SSE but with prefetching for next block.
323  * Uses 128-bit integer ops (AVX doesn't add 256-bit int ops).
324  * ============================================================================ */
325 
326 #if defined(__AVX__) && !defined(__AVX2__)
327 
/**
 * @brief AVX (Sandy/Ivy Bridge) dot product for Q6_K x Q8_K.
 *
 * Identical math to the SSE path — AVX1 adds no 256-bit integer ops, so all
 * integer work stays in 128-bit registers — plus a software prefetch of the
 * next weight/activation block. The -32 weight offset is folded in via the
 * Q8_K bsums: sum((q6-32)*q8) = sum(q6*q8) - 32*sum(q8).
 *
 * NOTE(review): uses q6k_scale_shuffle from the __SSSE3__-guarded section
 * above; this assumes __SSSE3__ is always defined when __AVX__ is (true for
 * GCC/Clang target flags) — confirm for other toolchains.
 *
 * @param w  Row of K/QK_K Q6_K blocks
 * @param x  K/QK_K Q8_K activation blocks
 * @param K  Element count; must be a multiple of QK_K
 * @return   Dot product as float
 */
static float dot_q6_k_q8_k_avx(const block_q6_K *w,
                               const block_q8_K *x,
                               int K)
{
    const int nb = K / QK_K;
    const __m128i m3 = _mm_set1_epi8(3);    /* mask for 2-bit high fields */
    const __m128i m15 = _mm_set1_epi8(15);  /* mask for 4-bit low fields */

    __m128 acc = _mm_setzero_ps();

    for (int i = 0; i < nb; ++i) {
        /* Combined super-block scale */
        const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;

        const uint8_t *ql = w[i].ql;
        const uint8_t *qh = w[i].qh;
        const int8_t *q8 = x[i].qs;

        /* Prefetch next block to hide memory latency */
        if (i + 1 < nb) {
            _mm_prefetch((const char *)&w[i + 1], _MM_HINT_T0);
            _mm_prefetch((const char *)&x[i + 1], _MM_HINT_T0);
        }

        /* Load scales and precompute the -32 offset contribution using bsums */
        const __m128i scales = _mm_loadu_si128((const __m128i *)w[i].scales);
        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i *)x[i].bsums);
        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i *)x[i].bsums + 1);

        /* Compute: sum(scale * bsum) * 32 for the -32 offset (shift by 5 == *32) */
        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
        const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
        const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);

        __m128i sumi_0 = _mm_setzero_si128();
        __m128i sumi_1 = _mm_setzero_si128();

        int is = 0;

        /* Process 256 weights in 2 iterations of 128 */
        for (int j = 0; j < QK_K / 128; ++j) {
            /* Load high bits: 32 bytes cover the high 2 bits of 128 weights */
            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i *)qh);
            qh += 16;
            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i *)qh);
            qh += 16;

            /* Extract each 2-bit field and shift into bit positions 4..5 */
            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
            /* -64 == 0xC0: top 2-bit field, shifted down into bits 4..5 */
            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);

            /* Load low bits: 64 bytes hold the low nibbles of 128 weights */
            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i *)ql);
            ql += 16;
            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i *)ql);
            ql += 16;
            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i *)ql);
            ql += 16;
            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i *)ql);
            ql += 16;

            /* Combine low and high bits to get 6-bit values (unsigned 0..63) */
            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);

            /* Load Q8_K values (128 int8 activations for this half-block) */
            const __m128i q8_0 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_1 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_2 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_3 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_4 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_5 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_6 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;
            const __m128i q8_7 = _mm_loadu_si128((const __m128i *)q8);
            q8 += 16;

            /* Multiply: maddubs treats first arg as unsigned, second as signed.
             * Max pair value 2*63*127 fits in int16, so no saturation occurs. */
            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);

            /* Broadcast the per-sub-block scales for this iteration */
            const __m128i scale_0 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 0]));
            const __m128i scale_1 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 1]));
            const __m128i scale_2 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 2]));
            const __m128i scale_3 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 3]));
            is += 4;

            /* Scale the products and widen to 32-bit */
            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);

            /* Accumulate into two running 32-bit sums */
            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
        }

        /* Subtract the -32 offset contribution */
        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);

        /* Combine and convert to float */
        __m128i sumi = _mm_add_epi32(sumi_0, sumi_1);
        __m128 sumf_vec = _mm_mul_ps(_mm_set1_ps(d), _mm_cvtepi32_ps(sumi));

        /* Horizontal sum */
        sumf_vec = _mm_hadd_ps(sumf_vec, sumf_vec);
        sumf_vec = _mm_hadd_ps(sumf_vec, sumf_vec);
        acc = _mm_add_ss(acc, sumf_vec);
    }

    return _mm_cvtss_f32(acc);
}
473 
474 void gemv_q6_k_q8_k_avx(float *y,
475  const void *W,
476  const void *x_q8,
477  int M, int K)
478 {
479  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
480  return;
481  }
482 
483  const block_q6_K *blocks = (const block_q6_K *)W;
484  const block_q8_K *x = (const block_q8_K *)x_q8;
485  const int blocks_per_row = K / QK_K;
486 
487  for (int row = 0; row < M; ++row) {
488  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
489  y[row] = dot_q6_k_q8_k_avx(w_row, x, K);
490  }
491 }
492 
493 #endif /* __AVX__ && !__AVX2__ */
494 
495 /* ============================================================================
496  * AVX2 Implementation (for modern CPUs with AVX2)
497  * ============================================================================ */
498 
499 #if defined(__AVX2__)
500 
501 /* Scale shuffle for AVX2 - 32-byte version */
/* Scale shuffle for AVX2 - 32-byte version.
 * NOTE(review): this table appears unused — dot_q6_k_q8_k_avx2 below builds
 * its 16-byte masks via get_scale_shuffle_avx2() instead. Candidate for
 * removal (may trigger unused-const-variable warnings on some compilers);
 * confirm no other translation unit relies on it before deleting. */
static const int8_t q6k_scale_shuffle_avx2[4][32] = {
    { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 },
    { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3 },
    { 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5 },
    { 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 },
};
508 
/**
 * @brief Build the 16-byte pshufb mask selecting scale pair i (i in 0..7).
 *
 * Entry i is the byte pattern {2i x8, 2i+1 x8}: scale 2i is broadcast over
 * the low 8 lanes and scale 2i+1 over the high 8 lanes. Computed
 * arithmetically (base + per-lane increment) instead of a lookup table.
 */
static inline __m128i get_scale_shuffle_avx2(int i) {
    /* High 8 bytes get +1 so they select the odd scale of the pair
     * (_mm_set_epi8 lists bytes from lane 15 down to lane 0). */
    const __m128i hi_plus_one = _mm_set_epi8(1, 1, 1, 1, 1, 1, 1, 1,
                                             0, 0, 0, 0, 0, 0, 0, 0);
    return _mm_add_epi8(_mm_set1_epi8((char)(2 * i)), hi_plus_one);
}
522 
/**
 * @brief AVX2 dot product for Q6_K x Q8_K using 256-bit integer ops.
 *
 * Processes 32 weights per register. Unlike the SSE/AVX paths, the -32
 * offset is handled inline per chunk: (q6-32)*q8 = q6*q8 - 32*q8, with the
 * 32*q8 term computed by maddubs against a constant-32 vector.
 *
 * NOTE(review): _mm256_fmadd_ps requires FMA, which __AVX2__ does not
 * formally imply (all known AVX2 CPUs do have FMA, but a -mavx2 build
 * without -mfma would fail) — confirm the build flags used for this file.
 *
 * @param w  Row of K/QK_K Q6_K blocks
 * @param x  K/QK_K Q8_K activation blocks
 * @param K  Element count; must be a multiple of QK_K
 * @return   Dot product as float
 */
static float dot_q6_k_q8_k_avx2(const block_q6_K *w,
                                const block_q8_K *x,
                                int K)
{
    const int nb = K / QK_K;
    const __m256i m4 = _mm256_set1_epi8(0xF);  /* low-nibble mask */
    const __m256i m2 = _mm256_set1_epi8(3);    /* 2-bit high-field mask */
    const __m256i m32s = _mm256_set1_epi8(32); /* offset constant */

    __m256 acc = _mm256_setzero_ps();

    for (int i = 0; i < nb; ++i) {
        /* Combined super-block scale */
        const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;

        const uint8_t *q4 = w[i].ql;
        const uint8_t *qh = w[i].qh;
        const int8_t *q8 = x[i].qs;

        const __m128i scales = _mm_loadu_si128((const __m128i *)w[i].scales);

        __m256i sumi = _mm256_setzero_si256();
        int is = 0;

        /* Two half-blocks of 128 weights each */
        for (int j = 0; j < QK_K / 128; ++j) {
            /* Broadcast per-sub-block scales for the four 32-weight chunks */
            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle_avx2(is + 0));
            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle_avx2(is + 1));
            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle_avx2(is + 2));
            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle_avx2(is + 3));
            is += 4;

            /* 64 bytes of low nibbles + 32 bytes of high 2-bit fields */
            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i *)q4);
            q4 += 32;
            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i *)q4);
            q4 += 32;
            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i *)qh);
            qh += 32;

            /* Shift each 2-bit field into bit positions 4..5 */
            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);

            /* Combine to unsigned 6-bit weights (0..63) */
            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);

            /* 128 int8 activations for this half-block */
            const __m256i q8_0 = _mm256_loadu_si256((const __m256i *)q8);
            q8 += 32;
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8);
            q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8);
            q8 += 32;
            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8);
            q8 += 32;

            /* Compute -32 * q8 contribution */
            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);

            /* Multiply q4 * q8 (unsigned x signed, pairwise add to int16) */
            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);

            /* Subtract offset: (q4 - 32) * q8 = q4*q8 - 32*q8 */
            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

            /* Apply scales and widen to 32-bit */
            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);

            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
        }

        /* acc += d * sumi */
        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
    }

    /* Horizontal sum of the 8 float lanes */
    __m128 hi = _mm256_extractf128_ps(acc, 1);
    __m128 lo = _mm256_castps256_ps128(acc);
    __m128 sum128 = _mm_add_ps(hi, lo);
    sum128 = _mm_hadd_ps(sum128, sum128);
    sum128 = _mm_hadd_ps(sum128, sum128);
    return _mm_cvtss_f32(sum128);
}
618 
619 void gemv_q6_k_q8_k_avx2(float *y,
620  const void *W,
621  const void *x_q8,
622  int M, int K)
623 {
624  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
625  return;
626  }
627 
628  const block_q6_K *blocks = (const block_q6_K *)W;
629  const block_q8_K *x = (const block_q8_K *)x_q8;
630  const int blocks_per_row = K / QK_K;
631 
632  for (int row = 0; row < M; ++row) {
633  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
634  y[row] = dot_q6_k_q8_k_avx2(w_row, x, K);
635  }
636 }
637 #endif /* __AVX2__ */
638 
639 /* ============================================================================
640  * AVX-512 Implementation
641  *
642  * Uses 512-bit ZMM registers to process 64 bytes at a time.
643  * Processes entire 256-element Q6_K block in fewer iterations.
644  * ============================================================================ */
645 
646 #if defined(__AVX512F__) && defined(__AVX512BW__) && defined(__AVX512VBMI__)
647 
648 /**
649  * @brief AVX-512 dot product for Q6_K x Q8_K with VBMI
650  *
651  * Uses AVX-512 VBMI for efficient byte permutation.
652  */
653 static float dot_q6_k_q8_k_avx512_vbmi(const block_q6_K *w,
654  const block_q8_K *x,
655  int K)
656 {
657  const int nb = K / QK_K;
658  const __m512i m4 = _mm512_set1_epi8(0xF);
659  const __m512i m2 = _mm512_set1_epi8(3);
660  const __m512i m32s = _mm512_set1_epi8(32);
661 
662  __m512 acc = _mm512_setzero_ps();
663 
664  for (int i = 0; i < nb; ++i) {
665  const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;
666 
667  const uint8_t *ql = w[i].ql;
668  const uint8_t *qh = w[i].qh;
669  const int8_t *q8 = x[i].qs;
670  const int8_t *sc = w[i].scales;
671 
672  __m512i sumi = _mm512_setzero_si512();
673 
674  /* Process 256 weights in one iteration using AVX-512 */
675  /* Load 64 bytes of low bits (covers 128 weights, need 2 loads for full block) */
676  const __m512i q4bits1 = _mm512_loadu_si512((const __m512i *)ql); /* ql[0..63] */
677  const __m512i q4bits2 = _mm512_loadu_si512((const __m512i *)(ql + 64)); /* ql[64..127] */
678 
679  /* Load 64 bytes of high bits */
680  const __m512i q4bitsH = _mm512_loadu_si512((const __m512i *)qh);
681 
682  /* Extract high 2-bit contributions for each group of 32 weights */
683  /* Group 0: bits 0-1 of qh -> weights 0-31 */
684  const __m512i q4h_0 = _mm512_slli_epi16(_mm512_and_si512(q4bitsH, m2), 4);
685  /* Group 1: bits 2-3 of qh -> weights 32-63 */
686  const __m512i q4h_1 = _mm512_slli_epi16(_mm512_and_si512(_mm512_srli_epi16(q4bitsH, 2), m2), 4);
687  /* Group 2: bits 4-5 of qh -> weights 64-95 */
688  const __m512i q4h_2 = _mm512_slli_epi16(_mm512_and_si512(_mm512_srli_epi16(q4bitsH, 4), m2), 4);
689  /* Group 3: bits 6-7 of qh -> weights 96-127 */
690  const __m512i q4h_3 = _mm512_slli_epi16(_mm512_and_si512(_mm512_srli_epi16(q4bitsH, 6), m2), 4);
691 
692  /* Combine low nibbles with high bits to get 6-bit values (0-63) */
693  /* First 64 weights: low nibbles of ql[0..63] */
694  const __m512i q6_0 = _mm512_or_si512(_mm512_and_si512(q4bits1, m4), q4h_0);
695  const __m512i q6_1 = _mm512_or_si512(_mm512_and_si512(q4bits2, m4), q4h_1);
696  /* Second 64 weights: high nibbles of ql[0..63] */
697  const __m512i q6_2 = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(q4bits1, 4), m4), q4h_2);
698  const __m512i q6_3 = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(q4bits2, 4), m4), q4h_3);
699 
700  /* Load Q8_K values (256 int8 values = 4 x 64) */
701  const __m512i q8_0 = _mm512_loadu_si512((const __m512i *)q8);
702  const __m512i q8_1 = _mm512_loadu_si512((const __m512i *)(q8 + 64));
703  const __m512i q8_2 = _mm512_loadu_si512((const __m512i *)(q8 + 128));
704  const __m512i q8_3 = _mm512_loadu_si512((const __m512i *)(q8 + 192));
705 
706  /* Compute 32 * q8 for the offset subtraction */
707  __m512i q8s_0 = _mm512_maddubs_epi16(m32s, q8_0);
708  __m512i q8s_1 = _mm512_maddubs_epi16(m32s, q8_1);
709  __m512i q8s_2 = _mm512_maddubs_epi16(m32s, q8_2);
710  __m512i q8s_3 = _mm512_maddubs_epi16(m32s, q8_3);
711 
712  /* Multiply unsigned q6 * signed q8 */
713  __m512i p16_0 = _mm512_maddubs_epi16(q6_0, q8_0);
714  __m512i p16_1 = _mm512_maddubs_epi16(q6_1, q8_1);
715  __m512i p16_2 = _mm512_maddubs_epi16(q6_2, q8_2);
716  __m512i p16_3 = _mm512_maddubs_epi16(q6_3, q8_3);
717 
718  /* Subtract offset: (q6 - 32) * q8 = q6*q8 - 32*q8 */
719  p16_0 = _mm512_sub_epi16(p16_0, q8s_0);
720  p16_1 = _mm512_sub_epi16(p16_1, q8s_1);
721  p16_2 = _mm512_sub_epi16(p16_2, q8s_2);
722  p16_3 = _mm512_sub_epi16(p16_3, q8s_3);
723 
724  /* Load and broadcast scales using VBMI permute
725  * Each scale applies to 16 weights, so we need to broadcast appropriately
726  * scales[0..15] for the 16 sub-blocks */
727  const __m128i scales_128 = _mm_loadu_si128((const __m128i *)sc);
728 
729  /* Create scale broadcast patterns for 64 weights (4 scales per 64 weights) */
730  /* Pattern: each scale repeated 16 times for 16 weights */
731  const __m512i scale_idx_0 = _mm512_set_epi8(
732  3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,
733  2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
734  1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
735  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0);
736  const __m512i scale_idx_1 = _mm512_set_epi8(
737  7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,
738  6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,
739  5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,
740  4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4);
741  const __m512i scale_idx_2 = _mm512_set_epi8(
742  11,11,11,11,11,11,11,11,11,11,11,11,11,11,11,11,
743  10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,
744  9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,
745  8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8);
746  const __m512i scale_idx_3 = _mm512_set_epi8(
747  15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,
748  14,14,14,14,14,14,14,14,14,14,14,14,14,14,14,14,
749  13,13,13,13,13,13,13,13,13,13,13,13,13,13,13,13,
750  12,12,12,12,12,12,12,12,12,12,12,12,12,12,12,12);
751 
752  /* Broadcast scales to 512-bit using VBMI permutexvar */
753  const __m512i scales_512 = _mm512_broadcast_i32x4(scales_128);
754  const __m512i sc_0 = _mm512_permutexvar_epi8(scale_idx_0, scales_512);
755  const __m512i sc_1 = _mm512_permutexvar_epi8(scale_idx_1, scales_512);
756  const __m512i sc_2 = _mm512_permutexvar_epi8(scale_idx_2, scales_512);
757  const __m512i sc_3 = _mm512_permutexvar_epi8(scale_idx_3, scales_512);
758 
759  /* Sign-extend scales to 16-bit and multiply with products */
760  /* For efficiency, we process in two halves (low and high 256 bits) */
761  __m512i p32_0 = _mm512_madd_epi16(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(sc_0)), p16_0);
762  __m512i p32_1 = _mm512_madd_epi16(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(sc_1)), p16_1);
763  __m512i p32_2 = _mm512_madd_epi16(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(sc_2)), p16_2);
764  __m512i p32_3 = _mm512_madd_epi16(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(sc_3)), p16_3);
765 
766  /* Accumulate */
767  sumi = _mm512_add_epi32(sumi, p32_0);
768  sumi = _mm512_add_epi32(sumi, p32_1);
769  sumi = _mm512_add_epi32(sumi, p32_2);
770  sumi = _mm512_add_epi32(sumi, p32_3);
771 
772  /* Scale by d and accumulate */
773  acc = _mm512_fmadd_ps(_mm512_set1_ps(d), _mm512_cvtepi32_ps(sumi), acc);
774  }
775 
776  return _mm512_reduce_add_ps(acc);
777 }
778 
779 void gemv_q6_k_q8_k_avx512_vbmi(float *y,
780  const void *W,
781  const void *x_q8,
782  int M, int K)
783 {
784  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
785  return;
786  }
787 
788  const block_q6_K *blocks = (const block_q6_K *)W;
789  const block_q8_K *x = (const block_q8_K *)x_q8;
790  const int blocks_per_row = K / QK_K;
791 
792  for (int row = 0; row < M; ++row) {
793  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
794  y[row] = dot_q6_k_q8_k_avx512_vbmi(w_row, x, K);
795  }
796 }
797 
798 #endif /* __AVX512F__ && __AVX512BW__ && __AVX512VBMI__ */
799 
800 #if defined(__AVX512F__) && defined(__AVX512BW__)
801 
802 /**
803  * @brief AVX-512 dot product for Q6_K x Q8_K
804  *
805  * Works on all AVX-512 CPUs (Skylake-X and newer).
806  * Uses same algorithm as AVX2, but benefits from AVX-512's wider FMA
807  * and efficient horizontal reduction.
808  */
809 static float dot_q6_k_q8_k_avx512(const block_q6_K *w,
810  const block_q8_K *x,
811  int K)
812 {
813  const int nb = K / QK_K;
814  const __m256i m4 = _mm256_set1_epi8(0xF);
815  const __m256i m2 = _mm256_set1_epi8(3);
816  const __m256i m32s = _mm256_set1_epi8(32);
817 
818  /* Use 256-bit float accumulator, same as AVX2 */
819  __m256 acc = _mm256_setzero_ps();
820 
821  for (int i = 0; i < nb; ++i) {
822  const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;
823 
824  const uint8_t *q4 = w[i].ql;
825  const uint8_t *qh = w[i].qh;
826  const int8_t *q8 = x[i].qs;
827 
828  const __m128i scales = _mm_loadu_si128((const __m128i *)w[i].scales);
829 
830  /* Use 256-bit integer accumulator, same as AVX2 */
831  __m256i sumi = _mm256_setzero_si256();
832  int is = 0;
833 
834  /* Process 256 weights in 2 iterations of 128 (same structure as AVX2) */
835  for (int j = 0; j < QK_K / 128; ++j) {
836  /* Get scale shuffle patterns - identical to AVX2 */
837  static const uint8_t patterns[8][16] = {
838  { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1 },
839  { 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3 },
840  { 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5 },
841  { 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7 },
842  { 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9 },
843  {10,10,10,10,10,10,10,10,11,11,11,11,11,11,11,11 },
844  {12,12,12,12,12,12,12,12,13,13,13,13,13,13,13,13 },
845  {14,14,14,14,14,14,14,14,15,15,15,15,15,15,15,15 },
846  };
847 
848  const __m128i scale_0 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)patterns[is + 0]));
849  const __m128i scale_1 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)patterns[is + 1]));
850  const __m128i scale_2 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)patterns[is + 2]));
851  const __m128i scale_3 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)patterns[is + 3]));
852  is += 4;
853 
854  /* Load low bits */
855  const __m256i q4bits1 = _mm256_loadu_si256((const __m256i *)q4);
856  q4 += 32;
857  const __m256i q4bits2 = _mm256_loadu_si256((const __m256i *)q4);
858  q4 += 32;
859  const __m256i q4bitsH = _mm256_loadu_si256((const __m256i *)qh);
860  qh += 32;
861 
862  /* Extract high 2-bit contributions */
863  const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
864  const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
865  const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
866  const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
867 
868  /* Combine low + high bits to get 6-bit values */
869  const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
870  const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
871  const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
872  const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
873 
874  /* Load Q8_K values */
875  const __m256i q8_0 = _mm256_loadu_si256((const __m256i *)q8);
876  q8 += 32;
877  const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8);
878  q8 += 32;
879  const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8);
880  q8 += 32;
881  const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8);
882  q8 += 32;
883 
884  /* Compute 32 * q8 for offset */
885  __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
886  __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
887  __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
888  __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
889 
890  /* Multiply q4 * q8 (unsigned * signed) */
891  __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
892  __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
893  __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
894  __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
895 
896  /* Subtract offset: (q4 - 32) * q8 = q4*q8 - 32*q8 */
897  p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
898  p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
899  p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
900  p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
901 
902  /* Apply scales - produces 8 int32 each */
903  p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
904  p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
905  p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
906  p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
907 
908  /* Accumulate all 4 into sumi (same as AVX2) */
909  sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
910  sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
911  }
912 
913  /* Scale by d and accumulate */
914  acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), acc);
915  }
916 
917  /* Horizontal sum - use AVX-512 reduce for efficiency */
918  __m128 hi = _mm256_extractf128_ps(acc, 1);
919  __m128 lo = _mm256_castps256_ps128(acc);
920  __m128 sum128 = _mm_add_ps(hi, lo);
921  sum128 = _mm_hadd_ps(sum128, sum128);
922  sum128 = _mm_hadd_ps(sum128, sum128);
923  return _mm_cvtss_f32(sum128);
924 }
925 
926 void gemv_q6_k_q8_k_avx512(float *y,
927  const void *W,
928  const void *x_q8,
929  int M, int K)
930 {
931  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
932  return;
933  }
934 
935  const block_q6_K *blocks = (const block_q6_K *)W;
936  const block_q8_K *x = (const block_q8_K *)x_q8;
937  const int blocks_per_row = K / QK_K;
938 
939  for (int row = 0; row < M; ++row) {
940  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
941  y[row] = dot_q6_k_q8_k_avx512(w_row, x, K);
942  }
943 }
944 
945 #endif /* __AVX512F__ && __AVX512BW__ */
946 
947 /* ============================================================================
948  * Dispatch Functions
949  * ============================================================================ */
950 
951 /**
952  * @brief Q6_K x Q8_K dot product (single row)
953  */
954 void vec_dot_q6_k_q8_k(int n, float *s, const void *vx, const void *vy)
955 {
956  if (!s || !vx || !vy || n <= 0) {
957  return;
958  }
959 
960  const block_q6_K *x = (const block_q6_K *)vx;
961  const block_q8_K *y = (const block_q8_K *)vy;
962 
963  /* Dispatch based on available SIMD */
964 #if defined(__AVX512F__) && defined(__AVX512BW__)
965  *s = dot_q6_k_q8_k_avx512(x, y, n);
966 #elif defined(__AVX2__)
967  *s = dot_q6_k_q8_k_avx2(x, y, n);
968 #elif defined(__AVX__) && !defined(__AVX2__)
969  *s = dot_q6_k_q8_k_avx(x, y, n);
970 #elif defined(__SSSE3__)
971  *s = dot_q6_k_q8_k_sse(x, y, n);
972 #else
973  *s = dot_q6_k_q8_k_ref(x, y, n);
974 #endif
975 }
976 
977 /**
978  * @brief GEMV: y = W @ x where W is Q6_K and x is Q8_K
979  */
980 void gemv_q6_k_q8_k(float *y,
981  const void *W,
982  const void *x_q8,
983  int M, int K)
984 {
985  /* AVX-512 uses same algorithm as AVX2 (matches llama.cpp) */
986 #if defined(__AVX512F__) && defined(__AVX512BW__)
987  gemv_q6_k_q8_k_avx512(y, W, x_q8, M, K);
988 #elif defined(__AVX2__)
989  gemv_q6_k_q8_k_avx2(y, W, x_q8, M, K);
990 #elif defined(__AVX__)
991  gemv_q6_k_q8_k_avx(y, W, x_q8, M, K);
992 #elif defined(__SSSE3__)
993  gemv_q6_k_q8_k_sse(y, W, x_q8, M, K);
994 #else
995  gemv_q6_k_q8_k_ref(y, W, x_q8, M, K);
996 #endif
997 }
998 
999 /* ============================================================================
1000  * PARALLEL VERSIONS (for parallel orchestration)
1001  *
1002  * These receive ith (thread index) and nth (total threads) from orchestration.
1003  * OpenMP lives in orchestration layer, NOT here.
1004  *
1005  * Naming: *_parallel = receives ith/nth, processes only its portion
1006  * ============================================================================ */
1007 
1008 /**
1009  * @brief Parallel reference GEMV for Q6_K × Q8_K
1010  *
1011  * Caller provides ith (thread index) and nth (total threads).
1012  * Each thread processes rows [r0, r1).
1013  */
1015  const void *W,
1016  const void *x_q8,
1017  int M, int K,
1018  int ith, int nth)
1019 {
1020  if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
1021  if (ith < 0 || nth <= 0 || ith >= nth) return;
1022 
1023  /* Compute row range for this thread */
1024  const int dr = (M + nth - 1) / nth;
1025  const int r0 = dr * ith;
1026  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1027 
1028  if (r0 >= M) return;
1029 
1030  const block_q6_K *blocks = (const block_q6_K *)W;
1031  const block_q8_K *x = (const block_q8_K *)x_q8;
1032  const int blocks_per_row = K / QK_K;
1033 
1034  for (int row = r0; row < r1; ++row) {
1035  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
1036  y[row] = dot_q6_k_q8_k_ref(w_row, x, K);
1037  }
1038 }
1039 
1040 /**
1041  * @brief Parallel SIMD GEMV for Q6_K × Q8_K
1042  *
1043  * Uses best available SIMD (AVX/SSE) with row prefetching.
1044  * Caller provides ith/nth from OpenMP region.
1045  */
1047  const void *W,
1048  const void *x_q8,
1049  int M, int K,
1050  int ith, int nth)
1051 {
1052  if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
1053  if (ith < 0 || nth <= 0 || ith >= nth) return;
1054 
1055  const int dr = (M + nth - 1) / nth;
1056  const int r0 = dr * ith;
1057  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1058 
1059  if (r0 >= M) return;
1060 
1061  const block_q6_K *blocks = (const block_q6_K *)W;
1062  const block_q8_K *x = (const block_q8_K *)x_q8;
1063  const int blocks_per_row = K / QK_K;
1064 
1065 #if defined(__AVX__) || defined(__SSSE3__)
1066  /* Prefetch first few rows */
1067  const int PREFETCH_ROWS = 4;
1068  for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
1069  const char *row_ptr = (const char *)(blocks + (r0 + p) * blocks_per_row);
1070  _mm_prefetch(row_ptr, _MM_HINT_T0);
1071  _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
1072  }
1073 
1074  for (int row = r0; row < r1; ++row) {
1075  /* Prefetch rows ahead */
1076  if (row + PREFETCH_ROWS < r1) {
1077  const char *prefetch_ptr = (const char *)(blocks + (row + PREFETCH_ROWS) * blocks_per_row);
1078  _mm_prefetch(prefetch_ptr, _MM_HINT_T0);
1079  _mm_prefetch(prefetch_ptr + 64, _MM_HINT_T0);
1080  }
1081 
1082  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
1083 #if defined(__AVX2__)
1084  y[row] = dot_q6_k_q8_k_avx2(w_row, x, K);
1085 #elif defined(__AVX__)
1086  y[row] = dot_q6_k_q8_k_avx(w_row, x, K);
1087 #else
1088  y[row] = dot_q6_k_q8_k_sse(w_row, x, K);
1089 #endif
1090  }
1091 #else
1092  /* Fallback to reference */
1093  for (int row = r0; row < r1; ++row) {
1094  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
1095  y[row] = dot_q6_k_q8_k_ref(w_row, x, K);
1096  }
1097 #endif
1098 }
1099 
1100 /**
1101  * @brief GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K
1102  *
1103  * @param Y Output matrix [N x M] in row-major
1104  * @param W Weight matrix in Q6_K format [M x K]
1105  * @param X_q8 Input matrix in Q8_K format [N x K]
1106  * @param M Number of output rows (output dim)
1107  * @param N Number of input vectors (batch size)
1108  * @param K Input dimension
1109  */
1110 void gemm_q6_k_q8_k(float *Y,
1111  const void *W,
1112  const void *X_q8,
1113  int M, int N, int K)
1114 {
1115  if (!Y || !W || !X_q8 || M <= 0 || N <= 0 || K <= 0) {
1116  return;
1117  }
1118 
1119  const block_q8_K *X = (const block_q8_K *)X_q8;
1120  const int blocks_per_vec = K / QK_K;
1121 
1122  for (int n = 0; n < N; ++n) {
1123  const block_q8_K *x_row = X + (size_t)n * (size_t)blocks_per_vec;
1124  gemv_q6_k_q8_k(&Y[n * M], W, x_row, M, K);
1125  }
1126 }
1127 
1128 /**
1129  * @brief NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K
1130  *
1131  * This is the typical inference pattern:
1132  * - A: Activations in Q8_K format [M x K]
1133  * - B: Weights in Q6_K format [N x K]
1134  * - C: Output [M x N]
1135  *
1136  * @param A_q8 Input activations in Q8_K format
1137  * @param B Weight matrix in Q6_K format
1138  * @param bias Optional bias vector [N]
1139  * @param C Output matrix
1140  * @param M Batch size (number of tokens)
1141  * @param N Output dimension
1142  * @param K Input dimension
1143  */
1144 void gemm_nt_q6_k_q8_k(const void *A_q8,
1145  const void *B,
1146  const float *bias,
1147  float *C,
1148  int M, int N, int K)
1149 {
1150  if (!A_q8 || !B || !C) {
1151  return;
1152  }
1153  if (M <= 0 || N <= 0 || K <= 0) {
1154  return;
1155  }
1156 
1157  gemm_q6_k_q8_k(C, B, A_q8, /*M_out=*/N, /*N_batch=*/M, K);
1158 
1159  if (!bias) {
1160  return;
1161  }
1162 
1163  for (int i = 0; i < M; ++i) {
1164  float *row = C + (size_t)i * (size_t)N;
1165  for (int j = 0; j < N; ++j) {
1166  row[j] += bias[j];
1167  }
1168  }
1169 }
Quantization block structures for weight-only quantization.
#define GGML_FP16_TO_FP32
#define QK_K
void gemv_q6_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
void gemm_q6_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)
GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K.
void gemv_q6_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
GEMV: y = W @ x where W is Q6_K and x is Q8_K.
void gemv_q6_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void vec_dot_q6_k_q8_k(int n, float *s, const void *vx, const void *vy)
Q6_K x Q8_K dot product (single row)
void gemv_q6_k_q8_k_avx512_vbmi(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
Parallel reference GEMV for Q6_K × Q8_K.
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
void gemv_q6_k_q8_k_avx512(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q6_K × Q8_K.
static float dot_q6_k_q8_k_ref(const block_q6_K *w, const block_q8_K *x, int K)
Scalar dot product for Q6_K x Q8_K.
#define C(color)
Definition: show_config.c:39
uint8_t ql[256/2]
int8_t scales[256/16]
uint8_t qh[256/4]
int8_t qs[256]