#if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__) || defined(__SSSE3__)
#include <immintrin.h>
    const int nb = K / QK_K;

    float sumf = 0.0f;

    for (int i = 0; i < nb; ++i) {
        const uint8_t *ql = w[i].ql;
        const uint8_t *qh = w[i].qh;
        const int8_t  *sc = w[i].scales;
        const int8_t  *q8 = x[i].qs;

        // Combined super-block scale: Q6_K block scale times Q8_K block scale
        // (d fields per the standard Q6_K/Q8_K block layout).
        const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;

        for (int n = 0; n < QK_K; n += 128) {
            for (int l = 0; l < 32; ++l) {
                const int is = l / 16;

                // Rebuild each 6-bit weight from its low 4 bits (ql) and high 2 bits (qh),
                // then remove the +32 storage bias.
                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                const int8_t q3 = (int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                const int8_t q4 = (int8_t)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;

                // Accumulate, applying the per-16-element sub-block scale.
                sumf += d * (float)sc[is + 0] * (float)q1 * (float)q8[l +  0];
                sumf += d * (float)sc[is + 2] * (float)q2 * (float)q8[l + 32];
                sumf += d * (float)sc[is + 4] * (float)q3 * (float)q8[l + 64];
                sumf += d * (float)sc[is + 6] * (float)q4 * (float)q8[l + 96];
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
        return;
    }

    const int blocks_per_row = K / QK_K;

    for (int row = 0; row < M; ++row) {
        const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
        y[row] = dot_q6_k_q8_k_ref(w_row, x, K);
#if defined(__SSSE3__)

// Shuffle patterns: row k broadcasts scale byte 2k across the low 8 lanes and
// scale byte 2k+1 across the high 8 lanes of a 128-bit register.
static const int8_t q6k_scale_shuffle[8][16] = {
    { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1 },
    { 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3 },
    { 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5 },
    { 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7 },
    { 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9 },
    {10,10,10,10,10,10,10,10,11,11,11,11,11,11,11,11 },
    {12,12,12,12,12,12,12,12,13,13,13,13,13,13,13,13 },
    {14,14,14,14,14,14,14,14,15,15,15,15,15,15,15,15 },
};
static float dot_q6_k_q8_k_sse(const block_q6_K *w, const block_q8_K *x, int K) {
    const int nb = K / QK_K;

    const __m128i m3  = _mm_set1_epi8(3);
    const __m128i m15 = _mm_set1_epi8(15);

    __m128 acc = _mm_setzero_ps();

    for (int i = 0; i < nb; ++i) {
        // Combined super-block scale: Q6_K block scale times Q8_K block scale.
        const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;

        const uint8_t *ql = w[i].ql;
        const uint8_t *qh = w[i].qh;
        const int8_t  *q8 = x[i].qs;

        const __m128i scales   = _mm_loadu_si128((const __m128i *)w[i].scales);
        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i *)x[i].bsums);
        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i *)x[i].bsums + 1);

        // The 6-bit values are multiplied while still carrying their +32 storage bias.
        // That bias contributes 32 * scale * sum(q8) per 16-element sub-block, which is
        // precomputed here from the Q8_K block sums and subtracted after the inner loop
        // (the shift by 5 is the multiplication by 32).
        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
        const __m128i q8sclsub_0  = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
        const __m128i q8sclsub_1  = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);

        __m128i sumi_0 = _mm_setzero_si128();
        __m128i sumi_1 = _mm_setzero_si128();

        int is = 0;
        for (int j = 0; j < QK_K / 128; ++j) {
            // 32 bytes of high bits: 2 bits per weight for the next 128 weights.
            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i *)qh); qh += 16;
            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i *)qh); qh += 16;

            // Isolate each 2-bit high field and move it into bit positions 4..5.
            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);

            // 64 bytes of low bits: two 4-bit halves per byte.
            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i *)ql); ql += 16;
            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i *)ql); ql += 16;
            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i *)ql); ql += 16;
            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i *)ql); ql += 16;

            // Combine low and high bits into biased 6-bit values (0..63).
            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);

            // 128 bytes of Q8_K activations for this half of the super-block.
            const __m128i q8_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_3 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_4 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_5 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_6 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_7 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            // Products of biased 6-bit weights with Q8 activations (u8 x s8 pairs summed to s16).
            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);

            // Broadcast the per-sub-block scales to match the product layout.
            const __m128i scale_0 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 0]));
            const __m128i scale_1 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 1]));
            const __m128i scale_2 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 2]));
            const __m128i scale_3 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 3]));
            is += 4;

            // Apply the scales and widen to 32-bit partial sums.
            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);

            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
        }

        // Remove the +32 bias contribution, then apply the combined block scale.
        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);

        __m128i sumi = _mm_add_epi32(sumi_0, sumi_1);
        __m128 sumf_vec = _mm_mul_ps(_mm_set1_ps(d), _mm_cvtepi32_ps(sumi));

        // Horizontal reduction of the four partial sums into lane 0 of acc.
        sumf_vec = _mm_hadd_ps(sumf_vec, sumf_vec);
        sumf_vec = _mm_hadd_ps(sumf_vec, sumf_vec);
        acc = _mm_add_ss(acc, sumf_vec);
    }

    return _mm_cvtss_f32(acc);
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
        return;
    }

    const int blocks_per_row = K / QK_K;

    for (int row = 0; row < M; ++row) {
        const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
        y[row] = dot_q6_k_q8_k_sse(w_row, x, K);
#if defined(__AVX__) && !defined(__AVX2__)

static float dot_q6_k_q8_k_avx(const block_q6_K *w, const block_q8_K *x, int K) {
    const int nb = K / QK_K;

    const __m128i m3  = _mm_set1_epi8(3);
    const __m128i m15 = _mm_set1_epi8(15);

    __m128 acc = _mm_setzero_ps();

    for (int i = 0; i < nb; ++i) {
        // Combined super-block scale: Q6_K block scale times Q8_K block scale.
        const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;

        const uint8_t *ql = w[i].ql;
        const uint8_t *qh = w[i].qh;
        const int8_t  *q8 = x[i].qs;

        // Prefetch the next super-block of weights and activations.
        _mm_prefetch((const char *)&w[i + 1], _MM_HINT_T0);
        _mm_prefetch((const char *)&x[i + 1], _MM_HINT_T0);

        const __m128i scales   = _mm_loadu_si128((const __m128i *)w[i].scales);
        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i *)x[i].bsums);
        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i *)x[i].bsums + 1);

        // Precompute the +32 bias correction from the Q8_K block sums (see the SSE kernel).
        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
        const __m128i q8sclsub_0  = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
        const __m128i q8sclsub_1  = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);

        __m128i sumi_0 = _mm_setzero_si128();
        __m128i sumi_1 = _mm_setzero_si128();

        int is = 0;
        for (int j = 0; j < QK_K / 128; ++j) {
            // 32 bytes of high bits: 2 bits per weight for the next 128 weights.
            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i *)qh); qh += 16;
            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i *)qh); qh += 16;

            // Isolate each 2-bit high field and move it into bit positions 4..5.
            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);

            // 64 bytes of low bits: two 4-bit halves per byte.
            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i *)ql); ql += 16;
            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i *)ql); ql += 16;
            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i *)ql); ql += 16;
            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i *)ql); ql += 16;

            // Combine low and high bits into biased 6-bit values (0..63).
            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);

            // 128 bytes of Q8_K activations for this half of the super-block.
            const __m128i q8_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_2 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_3 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_4 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_5 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_6 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8_7 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            // Products of biased 6-bit weights with Q8 activations (u8 x s8 pairs summed to s16).
            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);

            // Broadcast the per-sub-block scales to match the product layout.
            const __m128i scale_0 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 0]));
            const __m128i scale_1 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 1]));
            const __m128i scale_2 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 2]));
            const __m128i scale_3 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)q6k_scale_shuffle[is + 3]));
            is += 4;

            // Apply the scales and widen to 32-bit partial sums.
            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);

            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
        }

        // Remove the +32 bias contribution, then apply the combined block scale.
        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);

        __m128i sumi = _mm_add_epi32(sumi_0, sumi_1);
        __m128 sumf_vec = _mm_mul_ps(_mm_set1_ps(d), _mm_cvtepi32_ps(sumi));

        // Horizontal reduction of the four partial sums into lane 0 of acc.
        sumf_vec = _mm_hadd_ps(sumf_vec, sumf_vec);
        sumf_vec = _mm_hadd_ps(sumf_vec, sumf_vec);
        acc = _mm_add_ss(acc, sumf_vec);
    }

    return _mm_cvtss_f32(acc);
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
        return;
    }

    const int blocks_per_row = K / QK_K;

    for (int row = 0; row < M; ++row) {
        const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
        y[row] = dot_q6_k_q8_k_avx(w_row, x, K);
#if defined(__AVX2__)

// 256-bit variant of the scale-broadcast patterns (one scale byte per 16-byte half).
static const int8_t q6k_scale_shuffle_avx2[4][32] = {
    { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 },
    { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3 },
    { 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5 },
    { 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 },
};

// Returns the 128-bit shuffle pattern that pairs scale bytes 2i and 2i+1 (see q6k_scale_shuffle).
static inline __m128i get_scale_shuffle_avx2(int i) {
    static const uint8_t patterns[8][16] = {
        { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1 },
        { 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3 },
        { 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5 },
        { 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7 },
        { 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9 },
        {10,10,10,10,10,10,10,10,11,11,11,11,11,11,11,11 },
        {12,12,12,12,12,12,12,12,13,13,13,13,13,13,13,13 },
        {14,14,14,14,14,14,14,14,15,15,15,15,15,15,15,15 },
    };
    return _mm_loadu_si128((const __m128i *)patterns[i]);
}
static float dot_q6_k_q8_k_avx2(const block_q6_K *w, const block_q8_K *x, int K) {
    const int nb = K / QK_K;

    const __m256i m4   = _mm256_set1_epi8(0xF);
    const __m256i m2   = _mm256_set1_epi8(3);
    const __m256i m32s = _mm256_set1_epi8(32);

    __m256 acc = _mm256_setzero_ps();

    for (int i = 0; i < nb; ++i) {
        // Combined super-block scale: Q6_K block scale times Q8_K block scale.
        const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;

        const uint8_t *q4 = w[i].ql;
        const uint8_t *qh = w[i].qh;
        const int8_t  *q8 = x[i].qs;

        const __m128i scales = _mm_loadu_si128((const __m128i *)w[i].scales);

        __m256i sumi = _mm256_setzero_si256();

        int is = 0;

        for (int j = 0; j < QK_K / 128; ++j) {
            // Per-sub-block scales for the four 32-element groups processed below.
            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle_avx2(is + 0));
            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle_avx2(is + 1));
            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle_avx2(is + 2));
            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle_avx2(is + 3));
            is += 4;
            // 64 bytes of low bits and 32 bytes of high bits for 128 weights.
            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i *)q4); q4 += 32;
            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i *)q4); q4 += 32;
            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i *)qh); qh += 32;

            // Move each 2-bit high field into bit positions 4..5.
            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);

            // Combine into biased 6-bit values (0..63).
            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);

            // 128 bytes of Q8_K activations.
            const __m256i q8_0 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            // Bias terms (32 * q8 for each pair), removed from the products below so the
            // result corresponds to the unbiased signed weights (value - 32).
            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);

            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);

            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

            // Apply the per-sub-block scales and widen to 32-bit partial sums.
            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);

            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
        }

        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
    }

    // Horizontal reduction of the eight partial sums.
    __m128 hi = _mm256_extractf128_ps(acc, 1);
    __m128 lo = _mm256_castps256_ps128(acc);
    __m128 sum128 = _mm_add_ps(hi, lo);
    sum128 = _mm_hadd_ps(sum128, sum128);
    sum128 = _mm_hadd_ps(sum128, sum128);
    return _mm_cvtss_f32(sum128);
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
        return;
    }

    const int blocks_per_row = K / QK_K;

    for (int row = 0; row < M; ++row) {
        const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
        y[row] = dot_q6_k_q8_k_avx2(w_row, x, K);
#if defined(__AVX512F__) && defined(__AVX512BW__) && defined(__AVX512VBMI__)

static float dot_q6_k_q8_k_avx512_vbmi(const block_q6_K *w, const block_q8_K *x, int K) {
    const int nb = K / QK_K;

    const __m512i m4   = _mm512_set1_epi8(0xF);
    const __m512i m2   = _mm512_set1_epi8(3);
    const __m512i m32s = _mm512_set1_epi8(32);

    __m512 acc = _mm512_setzero_ps();

    for (int i = 0; i < nb; ++i) {
        // Combined super-block scale: Q6_K block scale times Q8_K block scale.
        const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;

        const uint8_t *ql = w[i].ql;
        const uint8_t *qh = w[i].qh;
        const int8_t  *q8 = x[i].qs;
        const int8_t  *sc = w[i].scales;

        __m512i sumi = _mm512_setzero_si512();
        // One whole super-block (256 weights) is processed per iteration:
        // 128 bytes of low bits, 64 bytes of high bits, 256 bytes of activations.
        const __m512i q4bits1 = _mm512_loadu_si512((const __m512i *)ql);
        const __m512i q4bits2 = _mm512_loadu_si512((const __m512i *)(ql + 64));
        const __m512i q4bitsH = _mm512_loadu_si512((const __m512i *)qh);

        // Move each 2-bit high field into bit positions 4..5.
        const __m512i q4h_0 = _mm512_slli_epi16(_mm512_and_si512(q4bitsH, m2), 4);
        const __m512i q4h_1 = _mm512_slli_epi16(_mm512_and_si512(_mm512_srli_epi16(q4bitsH, 2), m2), 4);
        const __m512i q4h_2 = _mm512_slli_epi16(_mm512_and_si512(_mm512_srli_epi16(q4bitsH, 4), m2), 4);
        const __m512i q4h_3 = _mm512_slli_epi16(_mm512_and_si512(_mm512_srli_epi16(q4bitsH, 6), m2), 4);

        // Combine into biased 6-bit values (0..63).
        const __m512i q6_0 = _mm512_or_si512(_mm512_and_si512(q4bits1, m4), q4h_0);
        const __m512i q6_1 = _mm512_or_si512(_mm512_and_si512(q4bits2, m4), q4h_1);
        const __m512i q6_2 = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(q4bits1, 4), m4), q4h_2);
        const __m512i q6_3 = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(q4bits2, 4), m4), q4h_3);

        const __m512i q8_0 = _mm512_loadu_si512((const __m512i *)q8);
        const __m512i q8_1 = _mm512_loadu_si512((const __m512i *)(q8 + 64));
        const __m512i q8_2 = _mm512_loadu_si512((const __m512i *)(q8 + 128));
        const __m512i q8_3 = _mm512_loadu_si512((const __m512i *)(q8 + 192));
        // Bias terms (32 * q8) to subtract from the biased products below.
        __m512i q8s_0 = _mm512_maddubs_epi16(m32s, q8_0);
        __m512i q8s_1 = _mm512_maddubs_epi16(m32s, q8_1);
        __m512i q8s_2 = _mm512_maddubs_epi16(m32s, q8_2);
        __m512i q8s_3 = _mm512_maddubs_epi16(m32s, q8_3);

        __m512i p16_0 = _mm512_maddubs_epi16(q6_0, q8_0);
        __m512i p16_1 = _mm512_maddubs_epi16(q6_1, q8_1);
        __m512i p16_2 = _mm512_maddubs_epi16(q6_2, q8_2);
        __m512i p16_3 = _mm512_maddubs_epi16(q6_3, q8_3);

        p16_0 = _mm512_sub_epi16(p16_0, q8s_0);
        p16_1 = _mm512_sub_epi16(p16_1, q8s_1);
        p16_2 = _mm512_sub_epi16(p16_2, q8s_2);
        p16_3 = _mm512_sub_epi16(p16_3, q8s_3);

        const __m128i scales_128 = _mm_loadu_si128((const __m128i *)sc);

        // Byte-permute indices that replicate each scale byte across a 16-byte group.
        const __m512i scale_idx_0 = _mm512_set_epi8(
            3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,
            2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
            1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
            0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0);
        const __m512i scale_idx_1 = _mm512_set_epi8(
            7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,
            6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,
            5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,
            4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4);
        const __m512i scale_idx_2 = _mm512_set_epi8(
            11,11,11,11,11,11,11,11,11,11,11,11,11,11,11,11,
            10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,
            9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,
            8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8);
        const __m512i scale_idx_3 = _mm512_set_epi8(
            15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,
            14,14,14,14,14,14,14,14,14,14,14,14,14,14,14,14,
            13,13,13,13,13,13,13,13,13,13,13,13,13,13,13,13,
            12,12,12,12,12,12,12,12,12,12,12,12,12,12,12,12);

        // Broadcast the 16 scale bytes into every 128-bit lane, then pick per-lane scales (VBMI).
        const __m512i scales_512 = _mm512_broadcast_i32x4(scales_128);
        const __m512i sc_0 = _mm512_permutexvar_epi8(scale_idx_0, scales_512);
        const __m512i sc_1 = _mm512_permutexvar_epi8(scale_idx_1, scales_512);
        const __m512i sc_2 = _mm512_permutexvar_epi8(scale_idx_2, scales_512);
        const __m512i sc_3 = _mm512_permutexvar_epi8(scale_idx_3, scales_512);

        // Apply the scales (low 32 bytes of each permuted vector) and widen to 32-bit sums.
        __m512i p32_0 = _mm512_madd_epi16(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(sc_0)), p16_0);
        __m512i p32_1 = _mm512_madd_epi16(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(sc_1)), p16_1);
        __m512i p32_2 = _mm512_madd_epi16(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(sc_2)), p16_2);
        __m512i p32_3 = _mm512_madd_epi16(_mm512_cvtepi8_epi16(_mm512_castsi512_si256(sc_3)), p16_3);

        sumi = _mm512_add_epi32(sumi, p32_0);
        sumi = _mm512_add_epi32(sumi, p32_1);
        sumi = _mm512_add_epi32(sumi, p32_2);
        sumi = _mm512_add_epi32(sumi, p32_3);

        acc = _mm512_fmadd_ps(_mm512_set1_ps(d), _mm512_cvtepi32_ps(sumi), acc);
    }

    return _mm512_reduce_add_ps(acc);
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
        return;
    }

    const int blocks_per_row = K / QK_K;

    for (int row = 0; row < M; ++row) {
        const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
        y[row] = dot_q6_k_q8_k_avx512_vbmi(w_row, x, K);
#if defined(__AVX512F__) && defined(__AVX512BW__)

static float dot_q6_k_q8_k_avx512(const block_q6_K *w, const block_q8_K *x, int K) {
    const int nb = K / QK_K;

    const __m256i m4   = _mm256_set1_epi8(0xF);
    const __m256i m2   = _mm256_set1_epi8(3);
    const __m256i m32s = _mm256_set1_epi8(32);

    __m256 acc = _mm256_setzero_ps();

    for (int i = 0; i < nb; ++i) {
        // Combined super-block scale: Q6_K block scale times Q8_K block scale.
        const float d = GGML_FP16_TO_FP32(w[i].d) * x[i].d;

        const uint8_t *q4 = w[i].ql;
        const uint8_t *qh = w[i].qh;
        const int8_t  *q8 = x[i].qs;

        const __m128i scales = _mm_loadu_si128((const __m128i *)w[i].scales);

        __m256i sumi = _mm256_setzero_si256();

        int is = 0;

        for (int j = 0; j < QK_K / 128; ++j) {
            static const uint8_t patterns[8][16] = {
                { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1 },
                { 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3 },
                { 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5 },
                { 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7 },
                { 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9 },
                {10,10,10,10,10,10,10,10,11,11,11,11,11,11,11,11 },
                {12,12,12,12,12,12,12,12,13,13,13,13,13,13,13,13 },
                {14,14,14,14,14,14,14,14,15,15,15,15,15,15,15,15 },
            };

            // Per-sub-block scales for the four 32-element groups processed below.
            const __m128i scale_0 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)patterns[is + 0]));
            const __m128i scale_1 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)patterns[is + 1]));
            const __m128i scale_2 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)patterns[is + 2]));
            const __m128i scale_3 = _mm_shuffle_epi8(scales, _mm_loadu_si128((const __m128i *)patterns[is + 3]));
            is += 4;
            // 64 bytes of low bits and 32 bytes of high bits for 128 weights.
            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i *)q4); q4 += 32;
            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i *)q4); q4 += 32;
            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i *)qh); qh += 32;

            // Move each 2-bit high field into bit positions 4..5.
            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);

            // Combine into biased 6-bit values (0..63).
            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);

            // 128 bytes of Q8_K activations.
            const __m256i q8_0 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            // Bias terms (32 * q8), removed from the products below.
            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);

            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);

            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

            // Apply the per-sub-block scales and widen to 32-bit partial sums.
            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);

            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
        }

        acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), acc);
    }

    // Horizontal reduction of the eight partial sums.
    __m128 hi = _mm256_extractf128_ps(acc, 1);
    __m128 lo = _mm256_castps256_ps128(acc);
    __m128 sum128 = _mm_add_ps(hi, lo);
    sum128 = _mm_hadd_ps(sum128, sum128);
    sum128 = _mm_hadd_ps(sum128, sum128);
    return _mm_cvtss_f32(sum128);
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
        return;
    }

    const int blocks_per_row = K / QK_K;

    for (int row = 0; row < M; ++row) {
        const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
        y[row] = dot_q6_k_q8_k_avx512(w_row, x, K);
    if (!s || !vx || !vy || n <= 0) {
        return;
    }

    // Dispatch to the widest SIMD kernel available at compile time.
#if defined(__AVX512F__) && defined(__AVX512BW__)
    *s = dot_q6_k_q8_k_avx512(x, y, n);
#elif defined(__AVX2__)
    *s = dot_q6_k_q8_k_avx2(x, y, n);
#elif defined(__AVX__) && !defined(__AVX2__)
    *s = dot_q6_k_q8_k_avx(x, y, n);
#elif defined(__SSSE3__)
    *s = dot_q6_k_q8_k_sse(x, y, n);
#if defined(__AVX512F__) && defined(__AVX512BW__)
#elif defined(__AVX2__)
#elif defined(__AVX__)
#elif defined(__SSSE3__)
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
    if (ith < 0 || nth <= 0 || ith >= nth) return;

    // Split the M output rows evenly across nth workers; this worker handles [r0, r1).
    const int dr = (M + nth - 1) / nth;
    const int r0 = dr * ith;
    const int r1 = (r0 + dr < M) ? (r0 + dr) : M;

    if (r0 >= M) return;

    const int blocks_per_row = K / QK_K;

    for (int row = r0; row < r1; ++row) {
        const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
    if (ith < 0 || nth <= 0 || ith >= nth) return;

    // Same even row split as the reference parallel GEMV.
    const int dr = (M + nth - 1) / nth;
    const int r0 = dr * ith;
    const int r1 = (r0 + dr < M) ? (r0 + dr) : M;

    if (r0 >= M) return;

    const int blocks_per_row = K / QK_K;

#if defined(__AVX__) || defined(__SSSE3__)
    // Warm the cache with the first few weight rows this worker will touch.
    const int PREFETCH_ROWS = 4;
    for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
        const char *row_ptr = (const char *)(blocks + (r0 + p) * blocks_per_row);
        _mm_prefetch(row_ptr, _MM_HINT_T0);
        _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
    }

    for (int row = r0; row < r1; ++row) {
        // Keep prefetching a few rows ahead of the row being processed.
        if (row + PREFETCH_ROWS < r1) {
            const char *prefetch_ptr = (const char *)(blocks + (row + PREFETCH_ROWS) * blocks_per_row);
            _mm_prefetch(prefetch_ptr, _MM_HINT_T0);
            _mm_prefetch(prefetch_ptr + 64, _MM_HINT_T0);
        }

        const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
#if defined(__AVX2__)
        y[row] = dot_q6_k_q8_k_avx2(w_row, x, K);
#elif defined(__AVX__)
        y[row] = dot_q6_k_q8_k_avx(w_row, x, K);
#else
        y[row] = dot_q6_k_q8_k_sse(w_row, x, K);
#endif
    }
#else
    // Scalar fallback when no SIMD path is compiled in.
    for (int row = r0; row < r1; ++row) {
        const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
void gemm_q6_k_q8_k(float *Y, const void *W, const void *X_q8,
                    int M, int N, int K)
{
    if (!Y || !W || !X_q8 || M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    const int blocks_per_vec = K / QK_K;

    // One GEMV per quantized input vector (the n-th row of X).
    for (int n = 0; n < N; ++n) {
        const block_q8_K *x_row = X + (size_t)n * (size_t)blocks_per_vec;
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C,
                       int M, int N, int K)
{
    if (!A_q8 || !B || !C) {
        return;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    for (int i = 0; i < M; ++i) {
        float *row = C + (size_t)i * (size_t)N;
        for (int j = 0; j < N; ++j) {
Quantization block structures for weight-only quantization.

#define GGML_FP16_TO_FP32

static float dot_q6_k_q8_k_ref(const block_q6_K *w, const block_q8_K *x, int K)
    Scalar dot product for Q6_K x Q8_K.

void vec_dot_q6_k_q8_k(int n, float *s, const void *vx, const void *vy)
    Q6_K x Q8_K dot product (single row).

void gemv_q6_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
    GEMV: y = W @ x where W is Q6_K and x is Q8_K (a usage sketch follows this index).

void gemv_q6_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx512(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx512_vbmi(float *y, const void *W, const void *x_q8, int M, int K)

void gemv_q6_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
    Parallel reference GEMV for Q6_K x Q8_K.

void gemv_q6_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
    Parallel SIMD GEMV for Q6_K x Q8_K.

void gemm_q6_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)
    GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K.

void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
    NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
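For orientation, a minimal usage sketch of the GEMV entry points listed above. It assumes the caller already holds the quantized buffers: M rows of K/QK_K block_q6_K super-blocks for the weights, and K/QK_K block_q8_K super-blocks for the input vector. The wrapper names (example_matvec, example_matvec_mt), the buffer names W_q6/x_q8, and the serial loop standing in for a thread pool are illustrative only, not part of the file.

/* Hypothetical caller: W_q6 and x_q8 are assumed to be already-quantized buffers. */
void example_matvec(float *y, const void *W_q6, const void *x_q8, int M, int K) {
    /* Single-call path; the library picks its SIMD kernel internally. */
    gemv_q6_k_q8_k(y, W_q6, x_q8, M, K);
}

void example_matvec_mt(float *y, const void *W_q6, const void *x_q8,
                       int M, int K, int nth) {
    /* Each worker ith in [0, nth) writes a disjoint slice of y, so no locking is
     * needed; in a real caller each iteration would run on its own thread. */
    for (int ith = 0; ith < nth; ++ith) {
        gemv_q6_k_q8_k_parallel_simd(y, W_q6, x_q8, M, K, ith, nth);
    }
}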