37 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
38 #include <immintrin.h>
/*
 * NOTE(review): interior fragment of the scalar Q5_0 GEMV reference loop
 * (presumably gemv_q5_0_ref). The enclosing signature and the declarations
 * of `blocks`, `qh`, `d` and `sum` fall outside this excerpt — note the gaps
 * in the embedded original line numbers. TODO confirm against full file.
 */
70 const int blocks_per_row = K /
QK5_0;
72 for (
int row = 0; row < M; row++) {
75 for (
int b = 0; b < blocks_per_row; b++) {
76 const block_q5_0 *block = &blocks[row * blocks_per_row + b];
/* xp: 32-element slice of x that this quantized block multiplies. */
78 const float *xp = &x[b *
QK5_0];
/* memcpy, not a cast: unaligned- and strict-aliasing-safe load of qh. */
82 memcpy(&qh, block->
qh,
sizeof(qh));
89 for (
int j = 0; j <
QK5_0 / 2; j++) {
90 const uint8_t packed = block->
qs[j];
93 const int lo = (packed & 0x0F);
94 const int hi = (packed >> 4);
/* 5th bit of each weight: bit j of qh for the low nibble, bit j+16
   (>> (j+12) lands it directly at bit position 4) for the high nibble. */
97 const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
98 const int xh_1 = ((qh >> (j + 12))) & 0x10;
/* Recenter 5-bit codes [0,31] to signed [-16,15]. */
101 const int q0 = (lo | xh_0) - 16;
102 const int q1 = (hi | xh_1) - 16;
105 sum += d * (float)q0 * xp[j];
106 sum += d * (float)q1 * xp[j + 16];
/*
 * AVX-512 GEMV with Q5_0 weights: y = W * x, one 16-lane dot-product
 * accumulator per output row.
 * NOTE(review): the parameter list after `float *y` and the definitions of
 * `blocks`, `qh`, `d` and `vscale` are missing from this excerpt (gaps in
 * the embedded line numbers) — confirm against the full file. The constant
 * `one` below is unused in the visible code.
 */
118 void gemv_q5_0_avx512(
float *y,
124 const int blocks_per_row = K /
QK5_0;
125 const __m512i offset = _mm512_set1_epi32(16);
126 const __m512i mask_lo = _mm512_set1_epi32(0x0F);
127 const __m512i one = _mm512_set1_epi32(1);
129 for (
int row = 0; row < M; row++) {
130 __m512 acc = _mm512_setzero_ps();
132 for (
int b = 0; b < blocks_per_row; b++) {
133 const block_q5_0 *block = &blocks[row * blocks_per_row + b];
135 const float *xp = &x[b *
QK5_0];
139 memcpy(&qh, block->
qh,
sizeof(qh));
/* 16 packed bytes -> 16 zero-extended 32-bit lanes. */
142 __m128i packed = _mm_loadu_si128((
const __m128i *)block->
qs);
143 __m512i bytes = _mm512_cvtepu8_epi32(packed);
/* hi_shift needs no 0x0F mask: each lane is an 8-bit value, so >>4
   leaves exactly the high nibble. */
146 __m512i lo = _mm512_and_epi32(bytes, mask_lo);
147 __m512i hi_shift = _mm512_srli_epi32(bytes, 4);
/* Scatter qh bits 0..15 (5th bit of the low-nibble weights) into bit 4
   of each lane; _mm512_set_epi32 lists lanes high-to-low. */
154 __m512i qh_lo = _mm512_set_epi32(
155 ((qh >> 15) & 1) << 4, ((qh >> 14) & 1) << 4,
156 ((qh >> 13) & 1) << 4, ((qh >> 12) & 1) << 4,
157 ((qh >> 11) & 1) << 4, ((qh >> 10) & 1) << 4,
158 ((qh >> 9) & 1) << 4, ((qh >> 8) & 1) << 4,
159 ((qh >> 7) & 1) << 4, ((qh >> 6) & 1) << 4,
160 ((qh >> 5) & 1) << 4, ((qh >> 4) & 1) << 4,
161 ((qh >> 3) & 1) << 4, ((qh >> 2) & 1) << 4,
162 ((qh >> 1) & 1) << 4, ((qh >> 0) & 1) << 4
/* qh bits 16..31 serve the high-nibble weights. */
169 __m512i qh_hi = _mm512_set_epi32(
170 ((qh >> 31) & 1) << 4, ((qh >> 30) & 1) << 4,
171 ((qh >> 29) & 1) << 4, ((qh >> 28) & 1) << 4,
172 ((qh >> 27) & 1) << 4, ((qh >> 26) & 1) << 4,
173 ((qh >> 25) & 1) << 4, ((qh >> 24) & 1) << 4,
174 ((qh >> 23) & 1) << 4, ((qh >> 22) & 1) << 4,
175 ((qh >> 21) & 1) << 4, ((qh >> 20) & 1) << 4,
176 ((qh >> 19) & 1) << 4, ((qh >> 18) & 1) << 4,
177 ((qh >> 17) & 1) << 4, ((qh >> 16) & 1) << 4
/* Combine to 5-bit codes and recenter [0,31] -> [-16,15]. */
181 __m512i q_lo = _mm512_sub_epi32(_mm512_or_epi32(lo, qh_lo), offset);
182 __m512i q_hi = _mm512_sub_epi32(_mm512_or_epi32(hi_shift, qh_hi), offset);
185 __m512 w_lo = _mm512_mul_ps(_mm512_cvtepi32_ps(q_lo), vscale);
186 __m512 w_hi = _mm512_mul_ps(_mm512_cvtepi32_ps(q_hi), vscale);
189 __m512 x_first = _mm512_loadu_ps(&xp[0]);
190 __m512 x_second = _mm512_loadu_ps(&xp[16]);
192 acc = _mm512_fmadd_ps(w_lo, x_first, acc);
193 acc = _mm512_fmadd_ps(w_hi, x_second, acc);
196 y[row] = _mm512_reduce_add_ps(acc);
213 #if defined(__AVX2__) && !defined(__AVX512F__)
/* Horizontal sum of all 8 float lanes of an AVX vector. */
static inline float hsum_avx2(__m256 v) {
    /* Fold the upper 128-bit half onto the lower half. */
    __m128 half = _mm_add_ps(_mm256_castps256_ps128(v),
                             _mm256_extractf128_ps(v, 1));
    /* Pairwise swap-and-add: lanes now hold (0+1, 1+0, 2+3, 3+2). */
    __m128 swapped = _mm_shuffle_ps(half, half, _MM_SHUFFLE(2, 3, 0, 1));
    __m128 partial = _mm_add_ps(half, swapped);
    /* Bring the high pair down and add it into lane 0. */
    __m128 high = _mm_movehl_ps(swapped, partial);
    return _mm_cvtss_f32(_mm_add_ss(partial, high));
}
/*
 * AVX2 GEMV with Q5_0 weights. Each 32-weight block is processed in four
 * groups of 8: low nibbles of qs[0..7] and qs[8..15] (qh bits 0..15) paired
 * with x[0..15], then high nibbles of the same bytes (qh bits 16..31)
 * paired with x[16..31].
 * NOTE(review): the remaining parameters after `float *y` and the
 * declarations of `blocks`, `d`, `qh` and the scratch array `hb[]` are
 * missing from this excerpt — confirm against the full file.
 */
233 void gemv_q5_0_avx2(
float *y,
239 const int blocks_per_row = K /
QK5_0;
241 for (
int row = 0; row < M; row++) {
242 __m256 acc = _mm256_setzero_ps();
244 for (
int b = 0; b < blocks_per_row; b++) {
245 const block_q5_0 *block = &blocks[row * blocks_per_row + b];
247 const __m256 vscale = _mm256_set1_ps(d);
248 const float *xp = &x[b *
QK5_0];
252 memcpy(&qh, block->
qh,
sizeof(qh));
/* Group 1: low nibbles of qs[0..7], 5th bits from qh bits 0..7. */
263 __m128i qs8 = _mm_loadl_epi64((
const __m128i *)block->
qs);
264 __m128i lo = _mm_and_si128(qs8, _mm_set1_epi8(0x0F));
268 for (
int i = 0; i < 8; i++) {
269 hb[i] = ((qh >> i) << 4) & 0x10;
271 __m128i hi = _mm_loadl_epi64((
const __m128i *)hb);
274 __m128i q5 = _mm_or_si128(lo, hi);
275 __m128i offset = _mm_set1_epi8(16);
276 __m128i q5_signed = _mm_sub_epi8(q5, offset);
279 __m256i q32 = _mm256_cvtepi8_epi32(q5_signed);
280 __m256 wf = _mm256_cvtepi32_ps(q32);
281 wf = _mm256_mul_ps(wf, vscale);
284 __m256 xv = _mm256_loadu_ps(&xp[0]);
285 acc = _mm256_fmadd_ps(wf, xv, acc);
/* Group 2: low nibbles of qs[8..15], qh bits 8..15, x[8..15]. */
290 __m128i qs8 = _mm_loadl_epi64((
const __m128i *)(block->
qs + 8));
291 __m128i lo = _mm_and_si128(qs8, _mm_set1_epi8(0x0F));
294 for (
int i = 0; i < 8; i++) {
295 hb[i] = ((qh >> (8 + i)) << 4) & 0x10;
297 __m128i hi = _mm_loadl_epi64((
const __m128i *)hb);
299 __m128i q5 = _mm_or_si128(lo, hi);
300 __m128i offset = _mm_set1_epi8(16);
301 __m128i q5_signed = _mm_sub_epi8(q5, offset);
303 __m256i q32 = _mm256_cvtepi8_epi32(q5_signed);
304 __m256 wf = _mm256_cvtepi32_ps(q32);
305 wf = _mm256_mul_ps(wf, vscale);
307 __m256 xv = _mm256_loadu_ps(&xp[8]);
308 acc = _mm256_fmadd_ps(wf, xv, acc);
/* Group 3: high nibbles of qs[0..7], qh bits 16..23, x[16..23]. */
313 __m128i qs8 = _mm_loadl_epi64((
const __m128i *)block->
qs);
314 __m128i hi_nib = _mm_and_si128(_mm_srli_epi16(qs8, 4), _mm_set1_epi8(0x0F));
318 for (
int i = 0; i < 8; i++) {
319 hb[i] = ((qh >> (16 + i)) & 1) << 4;
321 __m128i hi = _mm_loadl_epi64((
const __m128i *)hb);
323 __m128i q5 = _mm_or_si128(hi_nib, hi);
324 __m128i offset = _mm_set1_epi8(16);
325 __m128i q5_signed = _mm_sub_epi8(q5, offset);
327 __m256i q32 = _mm256_cvtepi8_epi32(q5_signed);
328 __m256 wf = _mm256_cvtepi32_ps(q32);
329 wf = _mm256_mul_ps(wf, vscale);
331 __m256 xv = _mm256_loadu_ps(&xp[16]);
332 acc = _mm256_fmadd_ps(wf, xv, acc);
/* Group 4: high nibbles of qs[8..15], qh bits 24..31, x[24..31]. */
337 __m128i qs8 = _mm_loadl_epi64((
const __m128i *)(block->
qs + 8));
338 __m128i hi_nib = _mm_and_si128(_mm_srli_epi16(qs8, 4), _mm_set1_epi8(0x0F));
341 for (
int i = 0; i < 8; i++) {
342 hb[i] = ((qh >> (24 + i)) & 1) << 4;
344 __m128i hi = _mm_loadl_epi64((
const __m128i *)hb);
346 __m128i q5 = _mm_or_si128(hi_nib, hi);
347 __m128i offset = _mm_set1_epi8(16);
348 __m128i q5_signed = _mm_sub_epi8(q5, offset);
350 __m256i q32 = _mm256_cvtepi8_epi32(q5_signed);
351 __m256 wf = _mm256_cvtepi32_ps(q32);
352 wf = _mm256_mul_ps(wf, vscale);
354 __m256 xv = _mm256_loadu_ps(&xp[24]);
355 acc = _mm256_fmadd_ps(wf, xv, acc);
359 y[row] = hsum_avx2(acc);
382 #if defined(__AVX__) && !defined(__AVX2__) && !defined(__AVX512F__)
/* Isolate the low 4 bits of each of the 16 bytes in `packed`. */
static inline __m128i extract_low_nibbles(__m128i packed) {
    const __m128i low_mask = _mm_set1_epi8(0x0F);
    return _mm_and_si128(packed, low_mask);
}
/* Isolate the high 4 bits of each byte, shifted down to the low nibble. */
static inline __m128i extract_high_nibbles(__m128i packed) {
    /* SSE has no 8-bit shift: shift 16-bit lanes, then mask off the bits
       that bled in from the neighboring byte. */
    const __m128i low_mask = _mm_set1_epi8(0x0F);
    __m128i shifted = _mm_srli_epi16(packed, 4);
    return _mm_and_si128(shifted, low_mask);
}
/* Horizontal sum of the 4 float lanes of an SSE vector. */
static inline float hsum_sse(__m128 v) {
    /* Swap lanes pairwise and add: (0+1, 1+0, 2+3, 3+2). */
    __m128 swapped = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1));
    __m128 pair_sums = _mm_add_ps(v, swapped);
    /* Move the high pair down and fold it into lane 0. */
    __m128 high = _mm_movehl_ps(swapped, pair_sums);
    return _mm_cvtss_f32(_mm_add_ss(pair_sums, high));
}
/*
 * Dot product of eight int8 weights (low 8 bytes of q8_lo, sign-extended)
 * against eight floats at x, with every weight scaled by `scale`.
 * Requires SSE4.1 (_mm_cvtepi8_epi16); reduction uses hsum_sse.
 */
static inline float dot_int8_float8_sse(__m128i q8_lo, const float *x, float scale) {
    /* Sign-extend 8 x int8 -> 8 x int16, then split into two int32 quads. */
    const __m128i w16 = _mm_cvtepi8_epi16(q8_lo);
    const __m128i w32_a = _mm_cvtepi16_epi32(w16);
    const __m128i w32_b = _mm_cvtepi16_epi32(_mm_srli_si128(w16, 8));
    /* Convert to float and apply the per-block scale. */
    const __m128 s = _mm_set1_ps(scale);
    const __m128 wf_a = _mm_mul_ps(_mm_cvtepi32_ps(w32_a), s);
    const __m128 wf_b = _mm_mul_ps(_mm_cvtepi32_ps(w32_b), s);
    /* Multiply by the activations and reduce the 8 partial products. */
    const __m128 acc = _mm_add_ps(_mm_mul_ps(wf_a, _mm_loadu_ps(x)),
                                  _mm_mul_ps(wf_b, _mm_loadu_ps(x + 4)));
    return hsum_sse(acc);
}
/*
 * AVX GEMV with Q5_0 weights: unpacks each block scalar-wise into int8
 * scratch `w8[]`, then uses dot_int8_float8_sse for 8 weights at a time.
 * NOTE(review): the remaining parameters after `float *y` and the
 * declarations of `blocks`, `qh`, `d`, `sum` and `w8[]` are missing from
 * this excerpt. The vectors lo_nibbles/hi_nibbles (and mask_10) computed
 * below appear unused by the visible code — possibly dead, verify in full
 * file.
 */
437 void gemv_q5_0_avx(
float *y,
443 const int blocks_per_row = K /
QK5_0;
445 const __m128i mask_0f = _mm_set1_epi8(0x0F);
446 const __m128i mask_10 = _mm_set1_epi8(0x10);
448 for (
int row = 0; row < M; row++) {
451 for (
int b = 0; b < blocks_per_row; b++) {
452 const block_q5_0 *block = &blocks[row * blocks_per_row + b];
454 const float *xp = &x[b *
QK5_0];
457 __m128i qs = _mm_loadu_si128((
const __m128i *)block->
qs);
460 __m128i lo_nibbles = _mm_and_si128(qs, mask_0f);
461 __m128i hi_nibbles = _mm_and_si128(_mm_srli_epi16(qs, 4), mask_0f);
465 memcpy(&qh, block->
qh,
sizeof(qh));
/* Low nibbles of qs[0..7]; 5th bits are qh bits 0..7. */
478 for (
int i = 0; i < 8; i++) {
479 int lo = block->
qs[i] & 0x0F;
480 int hb = ((qh >> i) << 4) & 0x10;
481 w8[i] = (lo | hb) - 16;
483 __m128i q8 = _mm_loadl_epi64((
const __m128i *)w8);
484 sum += dot_int8_float8_sse(q8, &xp[0], d);
/* Low nibbles of qs[8..15]; qh bits 8..15. */
490 for (
int i = 0; i < 8; i++) {
491 int lo = block->
qs[8 + i] & 0x0F;
492 int hb = ((qh >> (8 + i)) << 4) & 0x10;
493 w8[i] = (lo | hb) - 16;
495 __m128i q8 = _mm_loadl_epi64((
const __m128i *)w8);
496 sum += dot_int8_float8_sse(q8, &xp[8], d);
/* High nibbles of qs[0..7]; (qh >> (12+i)) & 0x10 extracts bit 16+i. */
502 for (
int i = 0; i < 8; i++) {
503 int hi = block->
qs[i] >> 4;
504 int hb = (qh >> (12 + i)) & 0x10;
505 w8[i] = (hi | hb) - 16;
507 __m128i q8 = _mm_loadl_epi64((
const __m128i *)w8);
508 sum += dot_int8_float8_sse(q8, &xp[16], d);
/* High nibbles of qs[8..15]; (qh >> (20+i)) & 0x10 extracts bit 24+i. */
514 for (
int i = 0; i < 8; i++) {
515 int hi = block->
qs[8 + i] >> 4;
516 int hb = (qh >> (20 + i)) & 0x10;
517 w8[i] = (hi | hb) - 16;
519 __m128i q8 = _mm_loadl_epi64((
const __m128i *)w8);
520 sum += dot_int8_float8_sse(q8, &xp[24], d);
/*
 * NOTE(review): compile-time dispatch fragment of gemv_q5_0 — selects the
 * widest ISA variant available; the scalar fallback branch (and enclosing
 * function) lies outside this excerpt.
 */
553 #if defined(__AVX512F__)
554 gemv_q5_0_avx512(y, W, x, M, K);
555 #elif defined(__AVX2__)
556 gemv_q5_0_avx2(y, W, x, M, K);
557 #elif defined(__AVX__)
558 gemv_q5_0_avx(y, W, x, M, K);
559 #elif defined(__SSE4_1__)
560 gemv_q5_0_sse_v2(y, W, x, M, K);
/*
 * Parallel scalar GEMV fragment (presumably gemv_q5_0_parallel): thread
 * `ith` of `nth` handles rows [r0, r1). Rows are divided in contiguous
 * chunks of ceil(M/nth).
 * NOTE(review): signature and declarations of `blocks`, `qh`, `d`, `sum`
 * are outside this excerpt.
 */
582 if (!y || !W || !x || M <= 0 || K <= 0)
return;
583 if (ith < 0 || nth <= 0 || ith >= nth)
return;
585 const int dr = (M + nth - 1) / nth;
586 const int r0 = dr * ith;
587 const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
592 const int blocks_per_row = K /
QK5_0;
594 for (
int row = r0; row < r1; row++) {
596 for (
int b = 0; b < blocks_per_row; b++) {
597 const block_q5_0 *block = &blocks[row * blocks_per_row + b];
599 const float *xp = &x[b *
QK5_0];
602 memcpy(&qh, block->
qh,
sizeof(qh));
604 for (
int j = 0; j <
QK5_0 / 2; j++) {
605 const uint8_t packed = block->
qs[j];
606 const int lo = (packed & 0x0F);
607 const int hi = (packed >> 4);
/* Same Q5_0 decode as the reference kernel: qh bit j for the low
   nibble, qh bit j+16 for the high nibble, recentered by -16. */
608 const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
609 const int xh_1 = ((qh >> (j + 12))) & 0x10;
610 const int w0 = (lo | xh_0) - 16;
611 const int w1 = (hi | xh_1) - 16;
612 sum += d * (w0 * xp[j] + w1 * xp[j +
QK5_0/2]);
/*
 * Parallel SIMD GEMV fragment (presumably gemv_q5_0_parallel_simd):
 * partitions rows across threads, software-prefetches upcoming weight rows,
 * and delegates each row to the widest single-row SIMD kernel as a 1xK GEMV.
 * NOTE(review): signature and the scalar fallback after the SSE4.1 branch
 * are outside this excerpt.
 */
628 if (!y || !W || !x || M <= 0 || K <= 0)
return;
629 if (ith < 0 || nth <= 0 || ith >= nth)
return;
631 const int dr = (M + nth - 1) / nth;
632 const int r0 = dr * ith;
633 const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
638 const int blocks_per_row = K /
QK5_0;
640 #if defined(__AVX__) || defined(__SSE4_1__)
/* Warm the cache with the first few rows of this thread's slice. */
642 const int PREFETCH_ROWS = 4;
643 for (
int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
644 const char *row_ptr = (
const char *)(blocks + (r0 + p) * blocks_per_row);
645 _mm_prefetch(row_ptr, _MM_HINT_T0);
646 _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
649 for (
int row = r0; row < r1; ++row) {
/* Stay PREFETCH_ROWS ahead of the row being computed. */
651 if (row + PREFETCH_ROWS < r1) {
652 const char *prefetch_ptr = (
const char *)(blocks + (row + PREFETCH_ROWS) * blocks_per_row);
653 _mm_prefetch(prefetch_ptr, _MM_HINT_T0);
654 _mm_prefetch(prefetch_ptr + 64, _MM_HINT_T0);
658 #if defined(__AVX512F__)
/* Each call computes one output element: a 1-row GEMV over row `row`. */
660 gemv_q5_0_avx512(&y[row], (
const char *)blocks + row * blocks_per_row *
sizeof(
block_q5_0), x, 1, K);
661 #elif defined(__AVX2__)
662 gemv_q5_0_avx2(&y[row], (
const char *)blocks + row * blocks_per_row *
sizeof(
block_q5_0), x, 1, K);
663 #elif defined(__AVX__)
664 gemv_q5_0_avx(&y[row], (
const char *)blocks + row * blocks_per_row *
sizeof(
block_q5_0), x, 1, K);
666 gemv_q5_0_sse_v2(&y[row], (
const char *)blocks + row * blocks_per_row *
sizeof(
block_q5_0), x, 1, K);
/*
 * NOTE(review): fragment of gemm_q5_0 — implements GEMM as N independent
 * GEMVs, writing M outputs per input column at Y + n*M.
 */
687 for (
int n = 0; n < N; n++) {
688 gemv_q5_0(&Y[n * M], W, &X[n * K], M, K);
/*
 * Backward-pass fragment (presumably gemv_q5_0_backward_ref):
 * dX = W^T * dY, accumulating each dequantized weight times the upstream
 * gradient of its row. dX is zeroed first since rows accumulate into it.
 * NOTE(review): signature and declarations of `blocks`, `qh`, `d` are
 * outside this excerpt.
 */
711 const int blocks_per_row = K /
QK5_0;
714 memset(dX, 0, K *
sizeof(
float));
717 for (
int row = 0; row < M; row++) {
718 const float dy = dY[row];
720 for (
int b = 0; b < blocks_per_row; b++) {
721 const block_q5_0 *block = &blocks[row * blocks_per_row + b];
723 float *dxp = &dX[b *
QK5_0];
727 memcpy(&qh, block->
qh,
sizeof(qh));
730 for (
int j = 0; j <
QK5_0 / 2; j++) {
731 const uint8_t packed = block->
qs[j];
734 const int lo = (packed & 0x0F);
735 const int hi = (packed >> 4);
736 const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
737 const int xh_1 = ((qh >> (j + 12))) & 0x10;
738 const int q0 = (lo | xh_0) - 16;
739 const int q1 = (hi | xh_1) - 16;
741 dxp[j] += d * (float)q0 * dy;
742 dxp[j + 16] += d * (float)q1 * dy;
767 for (
int n = 0; n < N; n++) {
/*
 * Scalar GEMM fragment with transposed quantized weights (presumably
 * gemm_nt_q5_0_ref): C[m][n] = dot(A row m, dequant(B row n)) + bias[n].
 * NOTE(review): signature and declarations of `blocks`, `qh`, `d`, `sum`
 * are outside this excerpt.
 */
795 const int blocks_per_row = K /
QK5_0;
797 for (
int m = 0; m < M; m++) {
798 const float *a_row = &A[m * K];
800 for (
int n = 0; n < N; n++) {
803 for (
int b = 0; b < blocks_per_row; b++) {
804 const block_q5_0 *block = &blocks[n * blocks_per_row + b];
806 const float *ap = &a_row[b *
QK5_0];
809 memcpy(&qh, block->
qh,
sizeof(qh));
812 for (
int j = 0; j <
QK5_0 / 2; j++) {
813 const uint8_t packed = block->
qs[j];
814 const int lo = (packed & 0x0F);
815 const int hi = (packed >> 4);
816 const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
817 const int xh_1 = ((qh >> (j + 12))) & 0x10;
818 const int q0 = (lo | xh_0) - 16;
819 const int q1 = (hi | xh_1) - 16;
821 sum += d * (float)q0 * ap[j];
822 sum += d * (float)q1 * ap[j + 16];
/* bias is optional (may be NULL). */
826 C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
843 for (
int n = 0; n < N; n++) {
857 for (
int m = 0; m < M; m++) {
858 float *row =
C + (size_t)m * (
size_t)N;
859 for (
int n = 0; n < N; n++) {
/*
 * NOTE(review): this excerpt splices the dot_q5_0 signature (original line
 * 870) onto body lines from a *different* function (original lines 901+,
 * presumably vec_dot_q5_0_q8_0_ref): the body reads `n`, `x[ib]`, `y[ib]`,
 * `sumi0`/`sumi1`, none of which match dot_q5_0's parameters. Do not treat
 * this as one function — consult the full file.
 */
870 float dot_q5_0(
const void *w_q5_0,
const float *x,
int K)
901 const int qk =
QK5_0;
902 const int nb = n / qk;
909 for (
int ib = 0; ib < nb; ib++) {
912 memcpy(&qh, x[ib].qh,
sizeof(qh));
917 for (
int j = 0; j < qk / 2; j++) {
/* Canonical Q5_0 decode: bit j of qh -> bit 4 of the low-nibble
   weight, bit j+16 -> bit 4 of the high-nibble weight. */
919 const uint8_t xh_0 = ((qh & (1u << (j + 0))) >> (j + 0)) << 4;
920 const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
923 const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
924 const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
/* Integer dot against the q8_0 activations. */
927 sumi0 += x0 * y[ib].
qs[j];
928 sumi1 += x1 * y[ib].
qs[j + qk / 2];
931 int sumi = sumi0 + sumi1;
/*
 * AVX-512 Q5_0 x Q8_0 integer dot product over n elements (nb blocks).
 * NOTE(review): the casts defining `x`/`y`, the declarations of `qh`,
 * `sumf`, the per-block scale `d`, and the final store to *s are outside
 * this excerpt.
 */
942 void vec_dot_q5_0_q8_0_avx512(
int n,
float *s,
const void *vx,
const void *vy)
944 const int qk =
QK5_0;
945 const int nb = n / qk;
952 for (
int ib = 0; ib < nb; ib++) {
957 memcpy(&qh, x[ib].qh,
sizeof(qh));
960 __m128i qs = _mm_loadu_si128((
const __m128i *)x[ib].qs);
/* Low nibbles of all 16 packed bytes, one per 32-bit lane. */
963 __m512i lo_nibbles = _mm512_cvtepu8_epi32(qs);
964 lo_nibbles = _mm512_and_epi32(lo_nibbles, _mm512_set1_epi32(0x0F));
/* qh bits 0..15 placed at bit 4 of each lane (set_epi32 is high-to-low). */
967 __m512i qh_lo = _mm512_set_epi32(
968 ((qh >> 15) & 1) << 4, ((qh >> 14) & 1) << 4,
969 ((qh >> 13) & 1) << 4, ((qh >> 12) & 1) << 4,
970 ((qh >> 11) & 1) << 4, ((qh >> 10) & 1) << 4,
971 ((qh >> 9) & 1) << 4, ((qh >> 8) & 1) << 4,
972 ((qh >> 7) & 1) << 4, ((qh >> 6) & 1) << 4,
973 ((qh >> 5) & 1) << 4, ((qh >> 4) & 1) << 4,
974 ((qh >> 3) & 1) << 4, ((qh >> 2) & 1) << 4,
975 ((qh >> 1) & 1) << 4, ((qh >> 0) & 1) << 4
/* Recenter [0,31] -> [-16,15]. */
979 __m512i q5_lo = _mm512_sub_epi32(_mm512_or_epi32(lo_nibbles, qh_lo),
980 _mm512_set1_epi32(16));
983 __m128i y8_lo = _mm_loadu_si128((
const __m128i *)&y[ib].qs[0]);
984 __m512i y32_lo = _mm512_cvtepi8_epi32(y8_lo);
987 __m512i prod_lo = _mm512_mullo_epi32(q5_lo, y32_lo);
/* High nibbles: zero-extended bytes >> 4 need no mask. */
990 __m512i hi_nibbles = _mm512_cvtepu8_epi32(qs);
991 hi_nibbles = _mm512_srli_epi32(hi_nibbles, 4);
994 __m512i qh_hi = _mm512_set_epi32(
995 ((qh >> 31) & 1) << 4, ((qh >> 30) & 1) << 4,
996 ((qh >> 29) & 1) << 4, ((qh >> 28) & 1) << 4,
997 ((qh >> 27) & 1) << 4, ((qh >> 26) & 1) << 4,
998 ((qh >> 25) & 1) << 4, ((qh >> 24) & 1) << 4,
999 ((qh >> 23) & 1) << 4, ((qh >> 22) & 1) << 4,
1000 ((qh >> 21) & 1) << 4, ((qh >> 20) & 1) << 4,
1001 ((qh >> 19) & 1) << 4, ((qh >> 18) & 1) << 4,
1002 ((qh >> 17) & 1) << 4, ((qh >> 16) & 1) << 4
1005 __m512i q5_hi = _mm512_sub_epi32(_mm512_or_epi32(hi_nibbles, qh_hi),
1006 _mm512_set1_epi32(16));
1009 __m128i y8_hi = _mm_loadu_si128((
const __m128i *)&y[ib].qs[16]);
1010 __m512i y32_hi = _mm512_cvtepi8_epi32(y8_hi);
1012 __m512i prod_hi = _mm512_mullo_epi32(q5_hi, y32_hi);
/* Exact integer reduction, then one float scale per block. */
1015 int sumi = _mm512_reduce_add_epi32(_mm512_add_epi32(prod_lo, prod_hi));
1018 sumf += d * (float)sumi;
1025 #if defined(__SSSE3__)
/*
 * Expand 32 packed bits at qh into 32 bytes across two vectors: each output
 * byte is 0xFF if the corresponding bit was set, 0x00 otherwise. *out_lo
 * receives bits 0..15, *out_hi bits 16..31. Requires SSSE3.
 *
 * Fix: the visible code used `x32` without declaring it; the declaration is
 * restored here.
 */
static inline void bytes_from_bits_32_sse(__m128i *out_lo, __m128i *out_hi,
                                          const uint8_t *qh)
{
    uint32_t x32;
    memcpy(&x32, qh, sizeof(uint32_t)); /* unaligned- and aliasing-safe load */
    /* Broadcast source byte k to the 8 lanes that test bits 8k..8k+7. */
    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101LL, 0x0000000000000000LL);
    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303LL, 0x0202020202020202LL);
    __m128i bytes_lo = _mm_shuffle_epi8(_mm_set1_epi32((int)x32), shuf_maskl);
    __m128i bytes_hi = _mm_shuffle_epi8(_mm_set1_epi32((int)x32), shuf_maskh);
    /* bit_mask has every bit set except each lane's tested bit; after the
       OR, a lane equals -1 exactly when its bit was set in x32. */
    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfeLL);
    bytes_lo = _mm_or_si128(bytes_lo, bit_mask);
    bytes_hi = _mm_or_si128(bytes_hi, bit_mask);
    *out_lo = _mm_cmpeq_epi8(bytes_lo, _mm_set1_epi64x(-1LL));
    *out_hi = _mm_cmpeq_epi8(bytes_hi, _mm_set1_epi64x(-1LL));
}
/*
 * Multiply corresponding signed int8 lanes of x and y and sum adjacent
 * pairs into four int32 partial sums. _mm_maddubs_epi16 requires an
 * unsigned first operand, so x's sign is transferred onto y first:
 * |x| * (sign(x)*y) == x*y lane-wise. Requires SSSE3.
 */
static inline __m128i mul_sum_i8_pairs_sse(const __m128i x, const __m128i y)
{
    const __m128i abs_x = _mm_sign_epi8(x, x);    /* |x|, fits unsigned op */
    const __m128i signed_y = _mm_sign_epi8(y, x); /* y carrying x's sign   */
    const __m128i dot16 = _mm_maddubs_epi16(abs_x, signed_y);
    const __m128i ones = _mm_set1_epi16(1);
    return _mm_madd_epi16(dot16, ones);           /* widen pairs to int32  */
}
/*
 * SSE Q5_0 x Q8_0 dot product using the ANDNOT/0xF0 trick: where the 5th
 * bit is clear, 0xF0 is OR-ed into the byte so its signed int8 value is
 * (nibble - 16); where it is set, the byte stays nibble|0x10 = q5 - 16 + 16
 * ... i.e. the combined byte is already the signed, recentered weight.
 * NOTE(review): the casts defining `x`/`y`, `d`, `sumf` and the final store
 * to *s are outside this excerpt.
 */
1084 void vec_dot_q5_0_q8_0_sse(
int n,
float *s,
const void *vx,
const void *vy)
1086 const int qk =
QK5_0;
1087 const int nb = n / qk;
1094 const __m128i mask_0f = _mm_set1_epi8(0x0F);
1095 const __m128i mask_f0 = _mm_set1_epi8((
char)0xF0);
1097 for (
int ib = 0; ib < nb; ib++) {
1101 __m128i qs = _mm_loadu_si128((
const __m128i *)x[ib].qs);
1104 __m128i bx_lo = _mm_and_si128(qs, mask_0f);
1105 __m128i bx_hi = _mm_and_si128(_mm_srli_epi16(qs, 4), mask_0f);
/* 0xFF where the corresponding qh bit is set, 0x00 otherwise. */
1108 __m128i bxhi_lo, bxhi_hi;
1109 bytes_from_bits_32_sse(&bxhi_lo, &bxhi_hi, x[ib].qh);
/* ANDNOT leaves 0xF0 only in bytes whose 5th bit is clear. */
1114 bxhi_lo = _mm_andnot_si128(bxhi_lo, mask_f0);
1115 bxhi_hi = _mm_andnot_si128(bxhi_hi, mask_f0);
1118 bx_lo = _mm_or_si128(bx_lo, bxhi_lo);
1119 bx_hi = _mm_or_si128(bx_hi, bxhi_hi);
1122 __m128i by_lo = _mm_loadu_si128((
const __m128i *)y[ib].qs);
1123 __m128i by_hi = _mm_loadu_si128((
const __m128i *)(y[ib].qs + 16));
1126 __m128i p_lo = mul_sum_i8_pairs_sse(bx_lo, by_lo);
1127 __m128i p_hi = mul_sum_i8_pairs_sse(bx_hi, by_hi);
1130 __m128i sum = _mm_add_epi32(p_lo, p_hi);
/* Horizontal reduce 4 x int32 -> scalar. */
1133 __m128i hi64 = _mm_unpackhi_epi64(sum, sum);
1134 __m128i sum64 = _mm_add_epi32(hi64, sum);
1135 __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
1136 int32_t sumi = _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
1139 sumf += d * (float)sumi;
1146 #if defined(__AVX__) && !defined(__AVX512F__)
1149 #define MM256_SET_M128I(hi, lo) _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1)
/*
 * Expand 32 packed bits at qh into a 32-byte AVX vector: output byte i is
 * 0xFF if bit i of the little-endian 32-bit value at qh is set, else 0x00.
 * Requires AVX + SSSE3.
 *
 * Fix: the visible code used `x32` without declaring it; the declaration is
 * restored here. The final combine uses the insertf128 intrinsic directly
 * (identical to the MM256_SET_M128I macro's expansion).
 */
static inline __m256i bytes_from_bits_32_avx(const uint8_t *qh)
{
    uint32_t x32;
    memcpy(&x32, qh, sizeof(uint32_t)); /* unaligned- and aliasing-safe load */
    /* Broadcast source byte k to the 8 lanes testing bits 8k..8k+7. */
    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101LL, 0x0000000000000000LL);
    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303LL, 0x0202020202020202LL);
    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32((int)x32), shuf_maskl);
    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32((int)x32), shuf_maskh);
    /* All bits set except each lane's tested bit: lane == -1 iff bit set. */
    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfeLL);
    bytesl = _mm_or_si128(bytesl, bit_mask);
    bytesh = _mm_or_si128(bytesh, bit_mask);
    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1LL));
    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1LL));
    return _mm256_insertf128_si256(_mm256_castsi128_si256(bytesl), bytesh, 1);
}
/*
 * Unpack 16 packed bytes at qs into 32 nibble-valued bytes: low nibbles in
 * the lower 128-bit lane, high nibbles in the upper lane.
 */
static inline __m256i bytes_from_nibbles_32_avx(const uint8_t *qs)
{
    const __m128i nibble_mask = _mm_set1_epi8(0x0F);
    const __m128i packed = _mm_loadu_si128((const __m128i *)qs);
    const __m128i lo = _mm_and_si128(nibble_mask, packed);
    /* 16-bit shift + mask stands in for the missing 8-bit shift. */
    const __m128i hi = _mm_and_si128(nibble_mask, _mm_srli_epi16(packed, 4));
    return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
}
/*
 * Signed int8 pairwise multiply-accumulate across all 32 lanes of x and y,
 * returning the eight int32 pair-sums converted to float. Works per 128-bit
 * half with the SSSE3 sign trick (|x| * sign(x)*y == x*y) since
 * _mm_maddubs_epi16 needs an unsigned first operand.
 */
static inline __m256 mul_sum_i8_pairs_float_avx(const __m256i x, const __m256i y)
{
    /* Split both operands into 128-bit halves. */
    const __m128i x_lo = _mm256_castsi256_si128(x);
    const __m128i x_hi = _mm256_extractf128_si256(x, 1);
    const __m128i y_lo = _mm256_castsi256_si128(y);
    const __m128i y_hi = _mm256_extractf128_si256(y, 1);
    /* Move x's sign onto y so the unsigned*signed madd computes x*y. */
    const __m128i abs_x_lo = _mm_sign_epi8(x_lo, x_lo);
    const __m128i abs_x_hi = _mm_sign_epi8(x_hi, x_hi);
    const __m128i sy_lo = _mm_sign_epi8(y_lo, x_lo);
    const __m128i sy_hi = _mm_sign_epi8(y_hi, x_hi);
    /* Byte products summed in adjacent pairs to int16. */
    const __m128i dot_lo = _mm_maddubs_epi16(abs_x_lo, sy_lo);
    const __m128i dot_hi = _mm_maddubs_epi16(abs_x_hi, sy_hi);
    /* Widen int16 pairs to int32 via madd with 1. */
    const __m128i unit = _mm_set1_epi16(1);
    const __m128i pair_lo = _mm_madd_epi16(unit, dot_lo);
    const __m128i pair_hi = _mm_madd_epi16(unit, dot_hi);
    const __m256i pairs =
        _mm256_insertf128_si256(_mm256_castsi128_si256(pair_lo), pair_hi, 1);
    return _mm256_cvtepi32_ps(pairs);
}
/* Horizontal sum of the 8 float lanes of an AVX vector: fold 8 -> 4 -> 2 -> 1. */
static inline float hsum_float_8_avx(const __m256 x)
{
    __m128 quad = _mm_add_ps(_mm256_extractf128_ps(x, 1),
                             _mm256_castps256_ps128(x));
    __m128 dual = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    __m128 single = _mm_add_ss(dual, _mm_movehdup_ps(dual));
    return _mm_cvtss_f32(single);
}
/*
 * AVX Q5_0 x Q8_0 dot product, 2x unrolled over blocks with dual
 * accumulators, plus a single-block tail loop. Uses the ANDNOT/0xF0 trick:
 * bytes whose 5th bit is clear get 0xF0 OR-ed in, so the signed int8 value
 * is already the recentered weight.
 * NOTE(review): the casts defining `x`/`y`, the declaration/initialization
 * of `ib`, and the scale vectors `d0`/`d1`/`d` are outside this excerpt.
 */
1245 void vec_dot_q5_0_q8_0_avx(
int n,
float *s,
const void *vx,
const void *vy)
1247 const int qk =
QK5_0;
1248 const int nb = n / qk;
1253 __m256 acc0 = _mm256_setzero_ps();
1254 __m256 acc1 = _mm256_setzero_ps();
1255 const __m128i
mask = _mm_set1_epi8((
char)0xF0);
/* Main loop: two blocks per iteration. */
1259 for (; ib + 1 < nb; ib += 2) {
1261 _mm_prefetch((
const char *)&x[ib + 4], _MM_HINT_T0);
1262 _mm_prefetch((
const char *)&y[ib + 4], _MM_HINT_T0);
1268 __m256i bx0 = bytes_from_nibbles_32_avx(x[ib].qs);
1271 const __m256i bxhi0 = bytes_from_bits_32_avx(x[ib].qh);
1272 __m128i bxhil0 = _mm256_castsi256_si128(bxhi0);
1273 __m128i bxhih0 = _mm256_extractf128_si256(bxhi0, 1);
1278 __m256i bx1 = bytes_from_nibbles_32_avx(x[ib+1].qs);
1279 const __m256i bxhi1 = bytes_from_bits_32_avx(x[ib+1].qh);
1280 __m128i bxhil1 = _mm256_castsi256_si128(bxhi1);
1281 __m128i bxhih1 = _mm256_extractf128_si256(bxhi1, 1);
/* 0xF0 where the 5th bit is clear; OR makes bytes signed weights. */
1284 bxhil0 = _mm_andnot_si128(bxhil0,
mask);
1285 bxhih0 = _mm_andnot_si128(bxhih0,
mask);
1287 __m128i bxl0 = _mm256_castsi256_si128(bx0);
1288 __m128i bxh0 = _mm256_extractf128_si256(bx0, 1);
1289 bxl0 = _mm_or_si128(bxl0, bxhil0);
1290 bxh0 = _mm_or_si128(bxh0, bxhih0);
1291 bx0 = MM256_SET_M128I(bxh0, bxl0);
1293 const __m256i by0 = _mm256_loadu_si256((
const __m256i *)y[ib].qs);
1294 const __m256 q0 = mul_sum_i8_pairs_float_avx(bx0, by0);
1295 acc0 = _mm256_add_ps(_mm256_mul_ps(d0, q0), acc0);
1298 bxhil1 = _mm_andnot_si128(bxhil1,
mask);
1299 bxhih1 = _mm_andnot_si128(bxhih1,
mask);
1301 __m128i bxl1 = _mm256_castsi256_si128(bx1);
1302 __m128i bxh1 = _mm256_extractf128_si256(bx1, 1);
1303 bxl1 = _mm_or_si128(bxl1, bxhil1);
1304 bxh1 = _mm_or_si128(bxh1, bxhih1);
1305 bx1 = MM256_SET_M128I(bxh1, bxl1);
1307 const __m256i by1 = _mm256_loadu_si256((
const __m256i *)y[ib+1].qs);
1308 const __m256 q1 = mul_sum_i8_pairs_float_avx(bx1, by1);
1309 acc1 = _mm256_add_ps(_mm256_mul_ps(d1, q1), acc1);
/* Tail: remaining block when nb is odd. */
1313 for (; ib < nb; ib++) {
1316 __m256i bx_0 = bytes_from_nibbles_32_avx(x[ib].qs);
1317 const __m256i bxhi = bytes_from_bits_32_avx(x[ib].qh);
1318 __m128i bxhil = _mm256_castsi256_si128(bxhi);
1319 __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
1321 bxhil = _mm_andnot_si128(bxhil,
mask);
1322 bxhih = _mm_andnot_si128(bxhih,
mask);
1324 __m128i bxl = _mm256_castsi256_si128(bx_0);
1325 __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
1326 bxl = _mm_or_si128(bxl, bxhil);
1327 bxh = _mm_or_si128(bxh, bxhih);
1328 bx_0 = MM256_SET_M128I(bxh, bxl);
1330 const __m256i by_0 = _mm256_loadu_si256((
const __m256i *)y[ib].qs);
1331 const __m256 q = mul_sum_i8_pairs_float_avx(bx_0, by_0);
1332 acc0 = _mm256_add_ps(_mm256_mul_ps(d, q), acc0);
/* Merge the dual accumulators and reduce to the scalar result. */
1336 acc0 = _mm256_add_ps(acc0, acc1);
1337 *s = hsum_float_8_avx(acc0);
/*
 * NOTE(review): body of gemm_nt_q5_0_q8_0_unroll_avx (C = A_q8 @ B_q5^T,
 * optionally + bias). Only the trailing parameters of the signature are
 * visible; `a_blocks`, `b_blocks`, the declaration of `n`, and the
 * per-block scale `d` are defined outside this excerpt. Processes two
 * weight rows (output columns) per iteration, with a single-row tail.
 */
1369 int M,
int N,
int K)
1371 const int nb = K /
QK5_0;
1374 const __m128i
mask = _mm_set1_epi8((
char)0xF0);
1376 for (
int m = 0; m < M; m++) {
1377 const block_q8_0 *a_row = a_blocks + (size_t)m * nb;
/* Pairs of output columns n, n+1 share each loaded activation block. */
1381 for (; n + 1 < N; n += 2) {
1382 const block_q5_0 *w0 = b_blocks + (size_t)(n + 0) * nb;
1383 const block_q5_0 *w1 = b_blocks + (size_t)(n + 1) * nb;
1385 __m256 acc_n0 = _mm256_setzero_ps();
1386 __m256 acc_n1 = _mm256_setzero_ps();
1388 for (
int ib = 0; ib < nb; ib++) {
1391 _mm_prefetch((
const char *)&w0[ib + 2], _MM_HINT_T0);
1392 _mm_prefetch((
const char *)&w1[ib + 2], _MM_HINT_T0);
/* One activation load serves both weight rows. */
1396 const __m256i by = _mm256_loadu_si256((
const __m256i *)a_row[ib].qs);
/* Row n: decode Q5_0 via the ANDNOT/0xF0 trick, then int8 madd. */
1404 __m256i bx = bytes_from_nibbles_32_avx(w0[ib].qs);
1407 const __m256i bxhi = bytes_from_bits_32_avx(w0[ib].qh);
1408 __m128i bxhil = _mm256_castsi256_si128(bxhi);
1409 __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
1412 bxhil = _mm_andnot_si128(bxhil,
mask);
1413 bxhih = _mm_andnot_si128(bxhih,
mask);
1414 __m128i bxl = _mm256_castsi256_si128(bx);
1415 __m128i bxh = _mm256_extractf128_si256(bx, 1);
1416 bxl = _mm_or_si128(bxl, bxhil);
1417 bxh = _mm_or_si128(bxh, bxhih);
1418 bx = MM256_SET_M128I(bxh, bxl);
1421 const __m256 q = mul_sum_i8_pairs_float_avx(bx, by);
1422 acc_n0 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), q), acc_n0);
/* Row n+1: identical decode against the same activations. */
1429 __m256i bx = bytes_from_nibbles_32_avx(w1[ib].qs);
1430 const __m256i bxhi = bytes_from_bits_32_avx(w1[ib].qh);
1431 __m128i bxhil = _mm256_castsi256_si128(bxhi);
1432 __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
1434 bxhil = _mm_andnot_si128(bxhil,
mask);
1435 bxhih = _mm_andnot_si128(bxhih,
mask);
1436 __m128i bxl = _mm256_castsi256_si128(bx);
1437 __m128i bxh = _mm256_extractf128_si256(bx, 1);
1438 bxl = _mm_or_si128(bxl, bxhil);
1439 bxh = _mm_or_si128(bxh, bxhih);
1440 bx = MM256_SET_M128I(bxh, bxl);
1442 const __m256 q = mul_sum_i8_pairs_float_avx(bx, by);
1443 acc_n1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), q), acc_n1);
1448 float s0 = hsum_float_8_avx(acc_n0);
1449 float s1 = hsum_float_8_avx(acc_n1);
1450 if (bias) { s0 += bias[n]; s1 += bias[n + 1]; }
1451 C[(size_t)m * N + n] = s0;
1452 C[(size_t)m * N + n + 1] = s1;
/* Tail: last output column when N is odd. */
1456 for (; n < N; n++) {
1457 const block_q5_0 *w = b_blocks + (size_t)n * nb;
1458 __m256 acc = _mm256_setzero_ps();
1460 for (
int ib = 0; ib < nb; ib++) {
1461 const __m256i by = _mm256_loadu_si256((
const __m256i *)a_row[ib].qs);
1464 __m256i bx = bytes_from_nibbles_32_avx(w[ib].qs);
1465 const __m256i bxhi = bytes_from_bits_32_avx(w[ib].qh);
1466 __m128i bxhil = _mm256_castsi256_si128(bxhi);
1467 __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
1468 bxhil = _mm_andnot_si128(bxhil,
mask);
1469 bxhih = _mm_andnot_si128(bxhih,
mask);
1470 __m128i bxl = _mm256_castsi256_si128(bx);
1471 __m128i bxh = _mm256_extractf128_si256(bx, 1);
1472 bxl = _mm_or_si128(bxl, bxhil);
1473 bxh = _mm_or_si128(bxh, bxhih);
1474 bx = MM256_SET_M128I(bxh, bxl);
1476 const __m256 q = mul_sum_i8_pairs_float_avx(bx, by);
1477 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), q), acc);
1480 float s = hsum_float_8_avx(acc);
1481 if (bias) s += bias[n];
1482 C[(size_t)m * N + n] = s;
/*
 * NOTE(review): compile-time dispatch fragment of vec_dot_q5_0_q8_0; the
 * scalar fallback branch and enclosing signature are outside this excerpt.
 */
1500 #if defined(__AVX512F__)
1501 vec_dot_q5_0_q8_0_avx512(n, s, vx, vy);
1502 #elif defined(__AVX__)
1504 vec_dot_q5_0_q8_0_avx(n, s, vx, vy);
1505 #elif defined(__SSSE3__)
1507 vec_dot_q5_0_q8_0_sse(n, s, vx, vy);
/*
 * NOTE(review): fragment of gemv_q5_0_q8_0 — per-row loop delegating to a
 * dot-product routine whose call is truncated in this excerpt.
 */
1536 const int blocks_per_row = K /
QK5_0;
1538 for (
int row = 0; row < M; row++) {
1540 &w_blocks[row * blocks_per_row],
/*
 * Parallel SIMD GEMV fragment for Q5_0 x Q8_0 (presumably
 * gemv_q5_0_q8_0_parallel_simd): same row-partitioning and prefetch scheme
 * as the FP32 variant; the per-row dot-product calls are truncated here.
 * NOTE(review): signature and the call targets at the truncated lines are
 * outside this excerpt.
 */
1557 if (!y || !W || !x_q8 || M <= 0 || K <= 0)
return;
1558 if (ith < 0 || nth <= 0 || ith >= nth)
return;
1560 const int dr = (M + nth - 1) / nth;
1561 const int r0 = dr * ith;
1562 const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1564 if (r0 >= M)
return;
1568 const int blocks_per_row = K /
QK5_0;
1570 #if defined(__AVX__) || defined(__SSE4_1__)
1571 const int PREFETCH_ROWS = 4;
1572 for (
int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
1573 const char *row_ptr = (
const char *)(w_blocks + (r0 + p) * blocks_per_row);
1574 _mm_prefetch(row_ptr, _MM_HINT_T0);
1575 _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
1578 for (
int row = r0; row < r1; ++row) {
1579 if (row + PREFETCH_ROWS < r1) {
1580 const char *pf = (
const char *)(w_blocks + (row + PREFETCH_ROWS) * blocks_per_row);
1581 _mm_prefetch(pf, _MM_HINT_T0);
1582 _mm_prefetch(pf + 64, _MM_HINT_T0);
1586 &w_blocks[row * blocks_per_row],
1590 for (
int row = r0; row < r1; row++) {
1592 &w_blocks[row * blocks_per_row],
/*
 * NOTE(review): fragment of a Q5_0 x Q8_0 batch GEMM (presumably
 * gemm_nt_q5_0_q8_0): one output element per (m, n) pair, computed from
 * input row m and weight row n; the inner dot-product call is truncated.
 */
1628 const int blocks_per_row = K /
QK5_0;
1630 for (
int m = 0; m < M; m++) {
1631 const block_q8_0 *input_row = &inputs[m * blocks_per_row];
1633 for (
int n = 0; n < N; n++) {
1634 const block_q5_0 *weight_row = &weights[n * blocks_per_row];
1635 float *out = &
C[m * N + n];
CPU feature detection and dispatch macros.
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
void gemm_nt_q5_0_q8_0_unroll_avx(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q5_0_sse_v2(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_q5_0_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
void gemv_q5_0_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient.
void dequant_q5_0_block(const block_q5_0 *block, float *output)
Dequantize a single Q5_0 block to FP32.
void dequant_q5_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_0 row (multiple blocks)
void gemv_q5_0_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q5_0 weights (scalar reference)
void gemv_q5_0_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.
void gemv_q5_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q5_0 weights and Q8_0 input.
void gemv_q5_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q5_0 weights based on CPU features.
void vec_dot_q5_0_q8_0_ref(int n, float *s, const void *vx, const void *vy)
Quantized dot product: Q5_0 weights x Q8_0 input (scalar reference)
void gemm_nt_q5_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
void gemm_q5_0(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q5_0 weights.
void gemm_nt_q5_0_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_0 weights: C = A @ B^T.
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
void gemv_q5_0_parallel(float *y, const void *W, const float *x, int M, int K, int ith, int nth)
Parallel reference GEMV for Q5_0 x FP32.
float dot_q5_0(const void *w_q5_0, const float *x, int K)
void gemv_q5_0_q8_0_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q5_0 x Q8_0 with prefetching.
void gemv_q5_0_parallel_simd(float *y, const void *W, const float *x, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q5_0 x FP32 with prefetching.
int32_t int32_t int32_t int32_t int32_t mask