28 #include <immintrin.h>
31 __m128i shuf = _mm_shuffle_epi32(v, _MM_SHUFFLE(1, 0, 3, 2));
32 __m128i sums = _mm_add_epi32(v, shuf);
33 shuf = _mm_shuffle_epi32(sums, _MM_SHUFFLE(2, 3, 0, 1));
34 sums = _mm_add_epi32(sums, shuf);
35 return _mm_cvtsi128_si32(sums);
45 const int nb = K /
QK_K;
47 const __m128i mask_low = _mm_set1_epi8(0x0F);
48 const __m128i ones = _mm_set1_epi16(1);
50 for (
int row = 0; row < M; ++row) {
54 for (
int i = 0; i < nb; ++i) {
57 _mm_prefetch((
const char *)&x[i + 1], _MM_HINT_T0);
58 _mm_prefetch((
const char *)&bx[i + 1], _MM_HINT_T0);
62 uint8_t sc[8], m_val[8];
68 const uint8_t *q4 = x[i].
qs;
69 const int8_t *q8 = bx[i].
qs;
74 for (
int j = 0; j <
QK_K; j += 64) {
75 __m128i acc_lo = _mm_setzero_si128();
76 __m128i acc_hi = _mm_setzero_si128();
79 for (
int l = 0; l < 32; l += 16) {
81 __m128i q4_vec = _mm_loadu_si128((
const __m128i *)(q4 + l));
84 __m128i q4_lo = _mm_and_si128(q4_vec, mask_low);
85 __m128i q4_hi = _mm_and_si128(_mm_srli_epi16(q4_vec, 4), mask_low);
88 __m128i q8_lo_vec = _mm_loadu_si128((
const __m128i *)(q8 + j + l));
89 __m128i q8_hi_vec = _mm_loadu_si128((
const __m128i *)(q8 + j + 32 + l));
94 __m128i prod_lo = _mm_maddubs_epi16(q4_lo, q8_lo_vec);
95 __m128i prod_hi = _mm_maddubs_epi16(q4_hi, q8_hi_vec);
98 acc_lo = _mm_add_epi32(acc_lo, _mm_madd_epi16(prod_lo, ones));
99 acc_hi = _mm_add_epi32(acc_hi, _mm_madd_epi16(prod_hi, ones));
106 int32_t bsum_lo = (int32_t)bx[i].bsums[j / 16] +
107 (int32_t)bx[i].
bsums[j / 16 + 1];
108 int32_t bsum_hi = (int32_t)bx[i].bsums[(j + 32) / 16] +
109 (int32_t)bx[i].bsums[(j + 32) / 16 + 1];
111 sumf += d * (float)sc[is] * (
float)sum_q4q8_lo;
112 sumf -= dmin * (float)m_val[is] * (
float)bsum_lo;
113 sumf += d * (float)sc[is + 1] * (
float)sum_q4q8_hi;
114 sumf -= dmin * (float)m_val[is + 1] * (
float)bsum_hi;
142 if (!y || !W || !x_q8 || M <= 0 || K <= 0)
return;
143 if (ith < 0 || nth <= 0 || ith >= nth)
return;
146 const int dr = (M + nth - 1) / nth;
147 const int r0 = dr * ith;
148 const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
154 const int nb = K /
QK_K;
155 const size_t bytes_per_row = (size_t)nb *
sizeof(
block_q4_K);
157 const __m128i mask_low = _mm_set1_epi8(0x0F);
158 const __m128i ones = _mm_set1_epi16(1);
161 const int PREFETCH_ROWS = 4;
162 for (
int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
163 const char *row_ptr = (
const char *)(blocks + (r0 + p) * nb);
164 _mm_prefetch(row_ptr, _MM_HINT_T0);
165 _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
168 for (
int row = r0; row < r1; ++row) {
170 if (row + PREFETCH_ROWS < r1) {
171 const char *prefetch_ptr = (
const char *)(blocks + (row + PREFETCH_ROWS) * nb);
172 _mm_prefetch(prefetch_ptr, _MM_HINT_T0);
173 _mm_prefetch(prefetch_ptr + 64, _MM_HINT_T0);
174 _mm_prefetch(prefetch_ptr + 128, _MM_HINT_T0);
180 for (
int i = 0; i < nb; ++i) {
183 _mm_prefetch((
const char *)&x[i + 1], _MM_HINT_T0);
187 uint8_t sc[8], m_val[8];
193 const uint8_t *q4 = x[i].
qs;
194 const int8_t *q8 = bx[i].
qs;
199 for (
int j = 0; j <
QK_K; j += 64) {
200 __m128i acc_lo = _mm_setzero_si128();
201 __m128i acc_hi = _mm_setzero_si128();
204 for (
int l = 0; l < 32; l += 16) {
206 __m128i q4_vec = _mm_loadu_si128((
const __m128i *)(q4 + l));
209 __m128i q4_lo = _mm_and_si128(q4_vec, mask_low);
210 __m128i q4_hi = _mm_and_si128(_mm_srli_epi16(q4_vec, 4), mask_low);
213 __m128i q8_lo_vec = _mm_loadu_si128((
const __m128i *)(q8 + j + l));
214 __m128i q8_hi_vec = _mm_loadu_si128((
const __m128i *)(q8 + j + 32 + l));
217 __m128i prod_lo = _mm_maddubs_epi16(q4_lo, q8_lo_vec);
218 __m128i prod_hi = _mm_maddubs_epi16(q4_hi, q8_hi_vec);
221 acc_lo = _mm_add_epi32(acc_lo, _mm_madd_epi16(prod_lo, ones));
222 acc_hi = _mm_add_epi32(acc_hi, _mm_madd_epi16(prod_hi, ones));
229 int32_t bsum_lo = (int32_t)bx[i].bsums[j / 16] +
230 (int32_t)bx[i].
bsums[j / 16 + 1];
231 int32_t bsum_hi = (int32_t)bx[i].bsums[(j + 32) / 16] +
232 (int32_t)bx[i].bsums[(j + 32) / 16 + 1];
234 sumf += d * (float)sc[is] * (
float)sum_q4q8_lo;
235 sumf -= dmin * (float)m_val[is] * (
float)bsum_lo;
236 sumf += d * (float)sc[is + 1] * (
float)sum_q4q8_hi;
237 sumf -= dmin * (float)m_val[is + 1] * (
float)bsum_hi;
261 int M,
int K,
int ith,
int nth);
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
static void unpack_q4_k_scales(const uint8_t *scales, uint8_t *sc, uint8_t *m)
Unpack Q4_K sub-block scales and mins.
void gemv_q4_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
static int32_t hsum_epi32_sse(__m128i v)