#if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
#include <immintrin.h>
    const int nb = k / QK8_0;

    const __m256 sign_bit = _mm256_set1_ps(-0.0f);
    const __m256 v_half = _mm256_set1_ps(0.5f);
    const __m256 v_min = _mm256_set1_ps(-127.0f);
    const __m256 v_max = _mm256_set1_ps(127.0f);

    for (int i = 0; i < nb; i++) {
        __m256 v0 = _mm256_loadu_ps(x + 0);
        __m256 v1 = _mm256_loadu_ps(x + 8);
        __m256 v2 = _mm256_loadu_ps(x + 16);
        __m256 v3 = _mm256_loadu_ps(x + 24);
        // per-block max |x|: clear the sign bit, then reduce with pairwise max
        __m256 max_abs = _mm256_andnot_ps(sign_bit, v0);
        max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v1));
        max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v2));
        max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v3));

        __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max_abs, 1),
                                 _mm256_castps256_ps128(max_abs));
        max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
        max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
        const float max_scalar = _mm_cvtss_f32(max4);
        // d: dequantization scale (amax maps to 127); id: inverse scale, guarded for all-zero blocks
        const float d = max_scalar / 127.0f;
        const float id = max_scalar != 0.0f ? 127.0f / max_scalar : 0.0f;
        const __m256 mul = _mm256_set1_ps(id);
        v0 = _mm256_mul_ps(v0, mul);
        v1 = _mm256_mul_ps(v1, mul);
        v2 = _mm256_mul_ps(v2, mul);
        v3 = _mm256_mul_ps(v3, mul);

        v0 = _mm256_min_ps(_mm256_max_ps(v0, v_min), v_max);
        v1 = _mm256_min_ps(_mm256_max_ps(v1, v_min), v_max);
        v2 = _mm256_min_ps(_mm256_max_ps(v2, v_min), v_max);
        v3 = _mm256_min_ps(_mm256_max_ps(v3, v_min), v_max);
        // round half away from zero: add +/-0.5 carrying each value's sign, then truncate below
        v0 = _mm256_add_ps(v0, _mm256_or_ps(_mm256_and_ps(v0, sign_bit), v_half));
        v1 = _mm256_add_ps(v1, _mm256_or_ps(_mm256_and_ps(v1, sign_bit), v_half));
        v2 = _mm256_add_ps(v2, _mm256_or_ps(_mm256_and_ps(v2, sign_bit), v_half));
        v3 = _mm256_add_ps(v3, _mm256_or_ps(_mm256_and_ps(v3, sign_bit), v_half));

        __m256i i0 = _mm256_cvttps_epi32(v0);
        __m256i i1 = _mm256_cvttps_epi32(v1);
        __m256i i2 = _mm256_cvttps_epi32(v2);
        __m256i i3 = _mm256_cvttps_epi32(v3);
#if defined(__AVX2__)
        // pack 32-bit -> 16-bit -> 8-bit; the 256-bit packs work per 128-bit lane,
        // so a cross-lane permute restores the original element order before storing
        i0 = _mm256_packs_epi32(i0, i1);
        i2 = _mm256_packs_epi32(i2, i3);
        i0 = _mm256_packs_epi16(i0, i2);

        const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
        i0 = _mm256_permutevar8x32_epi32(i0, perm);
        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
        // AVX without AVX2: split into 128-bit halves and pack with SSE
        __m128i ni0 = _mm256_castsi256_si128(i0);
        __m128i ni1 = _mm256_extractf128_si256(i0, 1);
        __m128i ni2 = _mm256_castsi256_si128(i1);
        __m128i ni3 = _mm256_extractf128_si256(i1, 1);
        __m128i ni4 = _mm256_castsi256_si128(i2);
        __m128i ni5 = _mm256_extractf128_si256(i2, 1);
        __m128i ni6 = _mm256_castsi256_si128(i3);
        __m128i ni7 = _mm256_extractf128_si256(i3, 1);

        ni0 = _mm_packs_epi32(ni0, ni1);
        ni2 = _mm_packs_epi32(ni2, ni3);
        ni4 = _mm_packs_epi32(ni4, ni5);
        ni6 = _mm_packs_epi32(ni6, ni7);

        ni0 = _mm_packs_epi16(ni0, ni2);
        ni4 = _mm_packs_epi16(ni4, ni6);

        _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
    // scalar quantization path
    for (int i = 0; i < nb; i++) {
        const float *xb = x + i * QK8_0;
        float amax = 0.0f;

        for (int j = 0; j < QK8_0; j++) {
            float av = xb[j] >= 0 ? xb[j] : -xb[j];
            if (av > amax) amax = av;
        }

        float d = amax / 127.0f;
        float id = d != 0.0f ? 127.0f / amax : 0.0f;

        for (int j = 0; j < QK8_0; j++) {
            float v = xb[j] * id;
            // round half away from zero and clamp to the int8 range used by Q8_0
            int q = (int)(v + (v >= 0 ? 0.5f : -0.5f));
            if (q > 127) q = 127;
            if (q < -127) q = -127;
            y[i].qs[j] = (int8_t)q;
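/*
 * Worked example of the scalar path (values illustrative): if a block's
 * largest magnitude is amax = 1.27, then d = 1.27/127 = 0.01 and id = 100.
 * An input of 0.635 scales to 63.5, rounds away from zero to q = 64, and is
 * clamped to [-127, 127]; dequantization later recovers d * q = 0.64.
 */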
    const size_t row_bytes_in = (size_t)k * sizeof(float);
    const size_t row_bytes_out = (size_t)(k / QK8_0) * sizeof(block_q8_0);

    uint8_t *out = (uint8_t *)vy;
    const uint8_t *in = (const uint8_t *)x;

    for (int row = 0; row < num_rows; ++row) {
        quantize_row_q8_0(
            (const float *)(in + row * row_bytes_in),
            (void *)(out + row * row_bytes_out),
            k);
    }
    const size_t row_bytes_in = (size_t)k * sizeof(float);
    const size_t row_bytes_out = (size_t)(k / 256) * sizeof(block_q8_K);

    uint8_t *out = (uint8_t *)vy;
    const uint8_t *in = (const uint8_t *)x;

    for (int row = 0; row < num_rows; ++row) {
        quantize_row_q8_k(
            (const float *)(in + row * row_bytes_in),
            (void *)(out + row * row_bytes_out),
            k);
    }
    const int blocks_per_row = K / QK8_0;

    for (int row = 0; row < M; row++) {
        for (int b = 0; b < blocks_per_row; b++) {
            const block_q8_0 *block = &blocks[row * blocks_per_row + b];
            const float *xp = &x[b * QK8_0];
            for (int i = 0; i < QK8_0; i++) {
                sum += d * (float)block->qs[i] * xp[i];
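/*
 * Each output element is an on-the-fly dequantized dot product:
 *   y[row] = sum over blocks b and elements i of d_b * qs[i] * x[b*QK8_0 + i]
 * The SIMD variants below compute the same quantity with wider accumulators.
 */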
void gemv_q8_0_avx512(float *y, const void *W, const float *x, int M, int K) {
    const block_q8_0 *blocks = (const block_q8_0 *)W;
    const int blocks_per_row = K / QK8_0;

    for (int row = 0; row < M; row++) {
        __m512 acc = _mm512_setzero_ps();

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q8_0 *block = &blocks[row * blocks_per_row + b];
            const float *xp = &x[b * QK8_0];

            // two 16-wide chunks cover the 32 quants of one block
            for (int chunk = 0; chunk < 2; chunk++) {
                // sign-extend 16 int8 weights to int32
                __m128i q8 = _mm_loadu_si128((const __m128i *)&block->qs[chunk * 16]);
                __m512i q32 = _mm512_cvtepi8_epi32(q8);

                // dequantize and fuse multiply-add with the matching input chunk
                __m512 w = _mm512_mul_ps(_mm512_cvtepi32_ps(q32), vscale);
                __m512 x_vec = _mm512_loadu_ps(&xp[chunk * 16]);
                acc = _mm512_fmadd_ps(w, x_vec, acc);
            }
        }

        y[row] = _mm512_reduce_add_ps(acc);
#if defined(__AVX2__) && !defined(__AVX512F__)

// horizontal sum of the 8 lanes of a __m256 accumulator
static inline float hsum_avx2_q8(__m256 v) {
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 hi = _mm256_extractf128_ps(v, 1);
    lo = _mm_add_ps(lo, hi);
    __m128 shuf = _mm_shuffle_ps(lo, lo, _MM_SHUFFLE(2, 3, 0, 1));
    __m128 sums = _mm_add_ps(lo, shuf);
    shuf = _mm_movehl_ps(shuf, sums);
    sums = _mm_add_ss(sums, shuf);
    return _mm_cvtss_f32(sums);
}
void gemv_q8_0_avx2(float *y, const void *W, const float *x, int M, int K) {
    const block_q8_0 *blocks = (const block_q8_0 *)W;
    const int blocks_per_row = K / QK8_0;

    for (int row = 0; row < M; row++) {
        __m256 acc = _mm256_setzero_ps();

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q8_0 *block = &blocks[row * blocks_per_row + b];
            const __m256 vscale = _mm256_set1_ps(d);
            const float *xp = &x[b * QK8_0];

            { // elements 0..7
                __m128i q8 = _mm_loadl_epi64((const __m128i *)&block->qs[0]);
                __m256i q32 = _mm256_cvtepi8_epi32(q8);
                __m256 wf = _mm256_mul_ps(_mm256_cvtepi32_ps(q32), vscale);
                __m256 xv = _mm256_loadu_ps(&xp[0]);
                acc = _mm256_fmadd_ps(wf, xv, acc);
            }
            { // elements 8..15
                __m128i q8 = _mm_loadl_epi64((const __m128i *)&block->qs[8]);
                __m256i q32 = _mm256_cvtepi8_epi32(q8);
                __m256 wf = _mm256_mul_ps(_mm256_cvtepi32_ps(q32), vscale);
                __m256 xv = _mm256_loadu_ps(&xp[8]);
                acc = _mm256_fmadd_ps(wf, xv, acc);
            }
            { // elements 16..23
                __m128i q8 = _mm_loadl_epi64((const __m128i *)&block->qs[16]);
                __m256i q32 = _mm256_cvtepi8_epi32(q8);
                __m256 wf = _mm256_mul_ps(_mm256_cvtepi32_ps(q32), vscale);
                __m256 xv = _mm256_loadu_ps(&xp[16]);
                acc = _mm256_fmadd_ps(wf, xv, acc);
            }
            { // elements 24..31
                __m128i q8 = _mm_loadl_epi64((const __m128i *)&block->qs[24]);
                __m256i q32 = _mm256_cvtepi8_epi32(q8);
                __m256 wf = _mm256_mul_ps(_mm256_cvtepi32_ps(q32), vscale);
                __m256 xv = _mm256_loadu_ps(&xp[24]);
                acc = _mm256_fmadd_ps(wf, xv, acc);
            }
        }

        y[row] = hsum_avx2_q8(acc);
#if defined(__AVX__) && !defined(__AVX2__) && !defined(__AVX512F__)

// horizontal sum of the 4 lanes of a __m128 accumulator
static inline float hsum_sse_q8(__m128 v) {
    __m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1));
    __m128 sums = _mm_add_ps(v, shuf);
    shuf = _mm_movehl_ps(shuf, sums);
    sums = _mm_add_ss(sums, shuf);
    return _mm_cvtss_f32(sums);
}
void gemv_q8_0_avx(float *y, const void *W, const float *x, int M, int K) {
    const block_q8_0 *blocks = (const block_q8_0 *)W;
    const int blocks_per_row = K / QK8_0;

    for (int row = 0; row < M; row++) {
        // four accumulators, combined after the loop
        __m128 acc0 = _mm_setzero_ps();
        __m128 acc1 = _mm_setzero_ps();
        __m128 acc2 = _mm_setzero_ps();
        __m128 acc3 = _mm_setzero_ps();

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q8_0 *block = &blocks[row * blocks_per_row + b];
            const float *xp = &x[b * QK8_0];
            const __m128 vscale = _mm_set1_ps(d);

            __m128i q8_0 = _mm_loadu_si128((const __m128i *)&block->qs[0]);
            __m128i q8_1 = _mm_loadu_si128((const __m128i *)&block->qs[16]);

            { // elements 0..3
                __m128i q16 = _mm_cvtepi8_epi16(q8_0);
                __m128i q32 = _mm_cvtepi16_epi32(q16);
                __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
                __m128 vx = _mm_loadu_ps(&xp[0]);
                acc0 = _mm_add_ps(acc0, _mm_mul_ps(w, vx));
            }
            { // elements 4..7
                __m128i q16 = _mm_cvtepi8_epi16(q8_0);
                __m128i q32 = _mm_cvtepi16_epi32(_mm_srli_si128(q16, 8));
                __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
                __m128 vx = _mm_loadu_ps(&xp[4]);
                acc1 = _mm_add_ps(acc1, _mm_mul_ps(w, vx));
            }
            { // elements 8..11
                __m128i q8_shifted = _mm_srli_si128(q8_0, 8);
                __m128i q16 = _mm_cvtepi8_epi16(q8_shifted);
                __m128i q32 = _mm_cvtepi16_epi32(q16);
                __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
                __m128 vx = _mm_loadu_ps(&xp[8]);
                acc2 = _mm_add_ps(acc2, _mm_mul_ps(w, vx));
            }
            { // elements 12..15
                __m128i q8_shifted = _mm_srli_si128(q8_0, 8);
                __m128i q16 = _mm_cvtepi8_epi16(q8_shifted);
                __m128i q32 = _mm_cvtepi16_epi32(_mm_srli_si128(q16, 8));
                __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
                __m128 vx = _mm_loadu_ps(&xp[12]);
                acc3 = _mm_add_ps(acc3, _mm_mul_ps(w, vx));
            }
            { // elements 16..19
                __m128i q16 = _mm_cvtepi8_epi16(q8_1);
                __m128i q32 = _mm_cvtepi16_epi32(q16);
                __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
                __m128 vx = _mm_loadu_ps(&xp[16]);
                acc0 = _mm_add_ps(acc0, _mm_mul_ps(w, vx));
            }
            { // elements 20..23
                __m128i q16 = _mm_cvtepi8_epi16(q8_1);
                __m128i q32 = _mm_cvtepi16_epi32(_mm_srli_si128(q16, 8));
                __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
                __m128 vx = _mm_loadu_ps(&xp[20]);
                acc1 = _mm_add_ps(acc1, _mm_mul_ps(w, vx));
            }
            { // elements 24..27
                __m128i q8_shifted = _mm_srli_si128(q8_1, 8);
                __m128i q16 = _mm_cvtepi8_epi16(q8_shifted);
                __m128i q32 = _mm_cvtepi16_epi32(q16);
                __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
                __m128 vx = _mm_loadu_ps(&xp[24]);
                acc2 = _mm_add_ps(acc2, _mm_mul_ps(w, vx));
            }
            { // elements 28..31
                __m128i q8_shifted = _mm_srli_si128(q8_1, 8);
                __m128i q16 = _mm_cvtepi8_epi16(q8_shifted);
                __m128i q32 = _mm_cvtepi16_epi32(_mm_srli_si128(q16, 8));
                __m128 w = _mm_mul_ps(_mm_cvtepi32_ps(q32), vscale);
                __m128 vx = _mm_loadu_ps(&xp[28]);
                acc3 = _mm_add_ps(acc3, _mm_mul_ps(w, vx));
            }
        }

        __m128 sum01 = _mm_add_ps(acc0, acc1);
        __m128 sum23 = _mm_add_ps(acc2, acc3);
        __m128 sum = _mm_add_ps(sum01, sum23);

        y[row] = hsum_sse_q8(sum);
#if defined(__SSE4_1__)
#include <immintrin.h>

// dequantize 4 int8 weights starting at byte `offset` of q8_reg, multiply by
// the matching 4 input floats, scale by d_val and accumulate into acc
#define SSE_Q8_BLOCK(q8_reg, offset, xp, d_val, acc) do { \
    __m128 vx = _mm_loadu_ps(&(xp)[offset]); \
    __m128i qw = _mm_cvtepi8_epi32(_mm_srli_si128(q8_reg, offset)); \
    __m128 vw = _mm_cvtepi32_ps(qw); \
    acc = _mm_add_ps(acc, _mm_mul_ps(_mm_mul_ps(vw, vx), _mm_set1_ps(d_val))); \
} while (0)

void gemv_q8_0_sse(float *y, const void *W, const float *x, int M, int K) {
    const block_q8_0 *blocks = (const block_q8_0 *)W;
    const int blocks_per_row = K / QK8_0;

    for (int row = 0; row < M; row++) {
        __m128 acc = _mm_setzero_ps();

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q8_0 *block = &blocks[row * blocks_per_row + b];
            const float *xp = &x[b * QK8_0];

            __m128i q8_0 = _mm_loadu_si128((const __m128i *)&block->qs[0]);
            __m128i q8_1 = _mm_loadu_si128((const __m128i *)&block->qs[16]);

            // first 16 quants
            SSE_Q8_BLOCK(q8_0, 0, xp, d_val, acc);
            SSE_Q8_BLOCK(q8_0, 4, xp, d_val, acc);
            SSE_Q8_BLOCK(q8_0, 8, xp, d_val, acc);
            SSE_Q8_BLOCK(q8_0, 12, xp, d_val, acc);

            // last 16 quants
            const float *xp1 = xp + 16;
            SSE_Q8_BLOCK(q8_1, 0, xp1, d_val, acc);
            SSE_Q8_BLOCK(q8_1, 4, xp1, d_val, acc);
            SSE_Q8_BLOCK(q8_1, 8, xp1, d_val, acc);
            SSE_Q8_BLOCK(q8_1, 12, xp1, d_val, acc);
        }

        // horizontal sum of the 4 accumulator lanes
        acc = _mm_add_ps(acc, _mm_shuffle_ps(acc, acc, _MM_SHUFFLE(1, 0, 3, 2)));
        acc = _mm_add_ps(acc, _mm_shuffle_ps(acc, acc, _MM_SHUFFLE(0, 1, 0, 1)));
        _mm_store_ss(&y[row], acc);
#if defined(__AVX512F__)
    gemv_q8_0_avx512(y, W, x, M, K);
#elif defined(__AVX2__)
    gemv_q8_0_avx2(y, W, x, M, K);
#elif defined(__AVX__)
    gemv_q8_0_avx(y, W, x, M, K);
#elif defined(__SSE4_1__)
    gemv_q8_0_sse(y, W, x, M, K);
    for (int n = 0; n < N; n++) {
        gemv_q8_0(&Y[n * M], W, &X[n * K], M, K);
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++)
            C[m * N + n] += bias[n];
    const int blocks_per_row = K / QK8_0;

    for (int m = 0; m < M; m++) {
        const float *a_row = &A[m * K];

        for (int n = 0; n < N; n++) {
            for (int b = 0; b < blocks_per_row; b++) {
                const block_q8_0 *block = &blocks[n * blocks_per_row + b];
                const float *ap = &a_row[b * QK8_0];

                for (int i = 0; i < QK8_0; i++) {
                    sum += d * (float)block->qs[i] * ap[i];
                }
            }

            C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
    const int blocks_per_row = K / QK8_0;

    memset(dX, 0, K * sizeof(float));

    for (int row = 0; row < M; row++) {
        const float dy = dY[row];

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q8_0 *block = &blocks[row * blocks_per_row + b];
            float *dxp = &dX[b * QK8_0];

            for (int i = 0; i < QK8_0; i++) {
                dxp[i] += d * (float)block->qs[i] * dy;
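/*
 * The backward pass accumulates dX = W^T * dY: each output row contributes
 * dY[row] times that row's dequantized weights (d * qs[i]) to the input gradient.
 */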
void gemv_q8_0_backward_avx512(float *dX, const void *W, const float *dY, int M, int K) {
    const block_q8_0 *blocks = (const block_q8_0 *)W;
    const int blocks_per_row = K / QK8_0;

    memset(dX, 0, K * sizeof(float));

    for (int row = 0; row < M; row++) {
        const __m512 vdy = _mm512_set1_ps(dY[row]);

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q8_0 *block = &blocks[row * blocks_per_row + b];
            float *dxp = &dX[b * QK8_0];

            for (int chunk = 0; chunk < 2; chunk++) {
                // dequantize 16 weights
                __m128i q8 = _mm_loadu_si128((const __m128i *)&block->qs[chunk * 16]);
                __m512i q32 = _mm512_cvtepi8_epi32(q8);
                __m512 w = _mm512_mul_ps(_mm512_cvtepi32_ps(q32), vscale);

                // this row's contribution to the input gradient
                __m512 grad = _mm512_mul_ps(w, vdy);

                // accumulate into dX
                __m512 dx_cur = _mm512_loadu_ps(&dxp[chunk * 16]);
                _mm512_storeu_ps(&dxp[chunk * 16], _mm512_add_ps(dx_cur, grad));
    gemv_q8_0_backward_avx512(dX, W, dY, M, K);
    for (int n = 0; n < N; n++) {
float dot_q8_0(const void *w_q8_0, const float *x, int K)
    const int qk = QK8_0;
    const int nb = n / qk;

    for (int ib = 0; ib < nb; ib++) {
        for (int j = 0; j < qk; j++) {
            sumi += x[ib].qs[j] * y[ib].qs[j];
void vec_dot_q8_0_q8_0_avx512(int n, float *s, const void *vx, const void *vy) {
    const int qk = QK8_0;
    const int nb = n / qk;

    for (int ib = 0; ib < nb; ib++) {
        // load the 32 quants of each operand in two 16-byte halves
        __m128i x8_lo = _mm_loadu_si128((const __m128i *)&x[ib].qs[0]);
        __m128i x8_hi = _mm_loadu_si128((const __m128i *)&x[ib].qs[16]);
        __m128i y8_lo = _mm_loadu_si128((const __m128i *)&y[ib].qs[0]);
        __m128i y8_hi = _mm_loadu_si128((const __m128i *)&y[ib].qs[16]);

        // sign-extend to int32 and multiply element-wise
        __m512i x32_lo = _mm512_cvtepi8_epi32(x8_lo);
        __m512i x32_hi = _mm512_cvtepi8_epi32(x8_hi);
        __m512i y32_lo = _mm512_cvtepi8_epi32(y8_lo);
        __m512i y32_hi = _mm512_cvtepi8_epi32(y8_hi);

        __m512i prod_lo = _mm512_mullo_epi32(x32_lo, y32_lo);
        __m512i prod_hi = _mm512_mullo_epi32(x32_hi, y32_hi);

        // reduce the 32 int32 products to a single integer sum
        int sumi = _mm512_reduce_add_epi32(_mm512_add_epi32(prod_lo, prod_hi));

        sumf += d * (float)sumi;
#if defined(__AVX__) && !defined(__AVX512F__)

void vec_dot_q8_0_q8_0_avx(int n, float *s, const void *vx, const void *vy) {
    const int qk = QK8_0;
    const int nb = n / qk;

    for (int ib = 0; ib < nb; ib++) {
        for (int j = 0; j < qk; j++) {
            sumi += x[ib].qs[j] * y[ib].qs[j];
        }

        sumf += d * (float)sumi;
#if defined(__SSE4_1__) && !defined(__AVX__)

void vec_dot_q8_0_q8_0_sse(int n, float *s, const void *vx, const void *vy) {
    const int qk = QK8_0;
    const int nb = n / qk;

    for (int ib = 0; ib < nb; ib++) {
        __m128i acc_lo = _mm_setzero_si128();
        __m128i acc_hi = _mm_setzero_si128();

        // 8 int8 pairs per iteration: widen to int16, then madd gives pairwise int32 sums
        for (int j = 0; j < 32; j += 8) {
            __m128i x8 = _mm_loadl_epi64((const __m128i *)&x[ib].qs[j]);
            __m128i y8 = _mm_loadl_epi64((const __m128i *)&y[ib].qs[j]);

            __m128i x16 = _mm_cvtepi8_epi16(x8);
            __m128i y16 = _mm_cvtepi8_epi16(y8);

            __m128i prod = _mm_madd_epi16(x16, y16);

            acc_lo = _mm_add_epi32(acc_lo, prod);
        }

        // horizontal sum of the four int32 lanes
        acc_lo = _mm_add_epi32(acc_lo, _mm_shuffle_epi32(acc_lo, _MM_SHUFFLE(1, 0, 3, 2)));
        acc_lo = _mm_add_epi32(acc_lo, _mm_shuffle_epi32(acc_lo, _MM_SHUFFLE(0, 1, 0, 1)));
        int sumi = _mm_extract_epi32(acc_lo, 0);

        sumf += d * (float)sumi;
    vec_dot_q8_0_q8_0_avx512(n, s, vx, vy);
#elif defined(__AVX__)
    vec_dot_q8_0_q8_0_avx(n, s, vx, vy);
#elif defined(__SSE4_1__)
    vec_dot_q8_0_q8_0_sse(n, s, vx, vy);
    const int blocks_per_row = K / QK8_0;

    for (int row = 0; row < M; row++) {
        vec_dot_q8_0_q8_0(K, &y[row],
            &w_blocks[row * blocks_per_row],
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
    if (ith < 0 || nth <= 0 || ith >= nth) return;

    // split the M output rows evenly across nth workers; worker ith owns [r0, r1)
    const int dr = (M + nth - 1) / nth;
    const int r0 = dr * ith;
    const int r1 = (r0 + dr < M) ? (r0 + dr) : M;

    if (r0 >= M) return;

    const int blocks_per_row = K / QK8_0;

    for (int row = r0; row < r1; row++) {
        vec_dot_q8_0_q8_0_ref(K, &y[row],
            &w_blocks[row * blocks_per_row],
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
    if (ith < 0 || nth <= 0 || ith >= nth) return;

    const int dr = (M + nth - 1) / nth;
    const int r0 = dr * ith;
    const int r1 = (r0 + dr < M) ? (r0 + dr) : M;

    if (r0 >= M) return;

    const int blocks_per_row = K / QK8_0;

#if defined(__AVX__) || defined(__SSE4_1__)
    // warm the cache with the first few weight rows this worker will touch
    const int PREFETCH_ROWS = 4;
    for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
        const char *row_ptr = (const char *)(w_blocks + (r0 + p) * blocks_per_row);
        _mm_prefetch(row_ptr, _MM_HINT_T0);
        _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
    }

    for (int row = r0; row < r1; ++row) {
        // prefetch a row a few iterations ahead of its use
        if (row + PREFETCH_ROWS < r1) {
            const char *pf = (const char *)(w_blocks + (row + PREFETCH_ROWS) * blocks_per_row);
            _mm_prefetch(pf, _MM_HINT_T0);
            _mm_prefetch(pf + 64, _MM_HINT_T0);
        }
        vec_dot_q8_0_q8_0(K, &y[row],
            &w_blocks[row * blocks_per_row],
#else
    for (int row = r0; row < r1; row++) {
        vec_dot_q8_0_q8_0_ref(K, &y[row],
            &w_blocks[row * blocks_per_row],
    if (!y || !W || !x || M <= 0 || K <= 0) return;
    if (ith < 0 || nth <= 0 || ith >= nth) return;

    const int dr = (M + nth - 1) / nth;
    const int r0 = dr * ith;
    const int r1 = (r0 + dr < M) ? (r0 + dr) : M;

    if (r0 >= M) return;

    const int blocks_per_row = K / QK8_0;

#if defined(__AVX__) || defined(__SSE4_1__)
    const int PREFETCH_ROWS = 4;
    for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
        const char *row_ptr = (const char *)(blocks + (r0 + p) * blocks_per_row);
        _mm_prefetch(row_ptr, _MM_HINT_T0);
        _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
    }

    for (int row = r0; row < r1; ++row) {
        if (row + PREFETCH_ROWS < r1) {
            const char *pf = (const char *)(blocks + (row + PREFETCH_ROWS) * blocks_per_row);
            _mm_prefetch(pf, _MM_HINT_T0);
            _mm_prefetch(pf + 64, _MM_HINT_T0);
        }
#if defined(__AVX512F__)
        gemv_q8_0_avx512(&y[row],
            (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
#elif defined(__AVX2__)
        gemv_q8_0_avx2(&y[row],
            (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
#elif defined(__AVX__)
        gemv_q8_0_avx(&y[row],
            (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
#elif defined(__SSE4_1__)
        gemv_q8_0_sse(&y[row],
            (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
#else
        gemv_q8_0_ref(&y[row],
            (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
#else
    for (int row = r0; row < r1; row++) {
        gemv_q8_0_ref(&y[row],
            (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
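/*
 * A minimal sketch of driving the ith/nth interface above from OpenMP (the
 * pragma, omp_get_thread_num()/omp_get_num_threads() and the wrapper name
 * run_gemv_q8_0_omp are illustrative assumptions; any thread pool that hands
 * each worker a distinct (ith, nth) pair works the same way).
 */
#include <omp.h>

void gemv_q8_0_parallel_simd(float *y, const void *W, const float *x, int M, int K, int ith, int nth);

void run_gemv_q8_0_omp(float *y, const void *W, const float *x, int M, int K) {
    #pragma omp parallel
    {
        // each worker computes its own contiguous slice of output rows
        const int ith = omp_get_thread_num();
        const int nth = omp_get_num_threads();
        gemv_q8_0_parallel_simd(y, W, x, M, K, ith, nth);
    }
}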
Function index:

CPU feature detection and dispatch macros.
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
#define CK_FP32_TO_FP16(x)

void quantize_row_q8_0(const float *x, void *vy, int k)
    Quantize FP32 to Q8_0 format (scalar reference).
void quantize_row_q8_k(const float *x, void *vy, int k)
void quantize_batch_q8_0(const float *x, void *vy, int num_rows, int k)
    Batch quantize FP32 to Q8_0 format (row-major output).
void quantize_batch_q8_k(const float *x, void *vy, int num_rows, int k)
    Batch quantize FP32 to Q8_K format (row-major output).

float dot_q8_0(const void *w_q8_0, const float *x, int K)
void vec_dot_q8_0_q8_0_ref(int n, float *s, const void *vx, const void *vy)
    Quantized dot product: Q8_0 weights x Q8_0 input (scalar reference).
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
    Auto-dispatch quantized dot product Q8_0 x Q8_0.

void gemv_q8_0_ref(float *y, const void *W, const float *x, int M, int K)
    Matrix-vector multiply with Q8_0 weights (scalar reference).
void gemv_q8_0(float *y, const void *W, const float *x, int M, int K)
    Auto-dispatch GEMV for Q8_0 weights based on CPU features.
void gemv_q8_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
    Matrix-vector multiply with Q8_0 weights and Q8_0 input.
void gemv_q8_0_q8_0_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
    Parallel reference GEMV for Q8_0 x Q8_0.
void gemv_q8_0_q8_0_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
    Parallel SIMD GEMV for Q8_0 x Q8_0 with prefetching.
void gemv_q8_0_parallel_simd(float *y, const void *W, const float *x, int M, int K, int ith, int nth)
    Parallel SIMD GEMV for Q8_0 weights x FP32 input with prefetching.

void gemm_q8_0(float *Y, const void *W, const float *X, int M, int N, int K)
    Matrix-matrix multiply with Q8_0 weights.
void gemm_nt_q8_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
    Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.

void gemv_q8_0_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
    Backward pass: compute input gradient (scalar reference).
void gemv_q8_0_backward(float *dX, const void *W, const float *dY, int M, int K)
    Auto-dispatch backward.
void gemm_q8_0_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
    Batched backward pass.
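A minimal end-to-end sketch of the intended call pattern, assuming a header (here called ck_quant.h, name illustrative) that declares QK8_0, block_q8_0 and the prototypes listed above: quantize each FP32 weight row into Q8_0 blocks, then run the auto-dispatch GEMV.

#include <stdlib.h>
#include "ck_quant.h"   /* assumed header exposing QK8_0, block_q8_0 and the kernels above */

int main(void) {
    const int M = 4, K = 64;                 /* K must be a multiple of QK8_0 */
    const int blocks_per_row = K / QK8_0;

    float *Wf = malloc((size_t)M * K * sizeof(float));   /* dense FP32 weights */
    float *x  = malloc((size_t)K * sizeof(float));
    float  y[4];

    for (int i = 0; i < M * K; i++) Wf[i] = 0.01f * (float)(i % 200 - 100);
    for (int j = 0; j < K; j++)     x[j]  = 1.0f;

    /* quantize each weight row into its own run of Q8_0 blocks */
    block_q8_0 *W = malloc((size_t)M * blocks_per_row * sizeof(block_q8_0));
    for (int r = 0; r < M; r++)
        quantize_row_q8_0(&Wf[r * K], &W[r * blocks_per_row], K);

    gemv_q8_0(y, W, x, M, K);   /* y[M] = W[M,K] * x[K], dispatched by CPU features */

    free(Wf); free(x); free(W);
    return 0;
}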