17 #include <immintrin.h>
26 __m128i shuf = _mm_shuffle_epi32(v, _MM_SHUFFLE(1, 0, 3, 2));
27 __m128i sums = _mm_add_epi32(v, shuf);
28 shuf = _mm_shuffle_epi32(sums, _MM_SHUFFLE(2, 3, 0, 1));
29 sums = _mm_add_epi32(sums, shuf);
30 return _mm_cvtsi128_si32(sums);
40 const int blocks_per_row = K /
QK_K;
42 const __m128i mask_low = _mm_set1_epi8(0x0F);
44 for (
int row = 0; row < M; ++row) {
46 const block_q4_K *w_row = blocks + row * blocks_per_row;
48 for (
int i = 0; i < blocks_per_row; ++i) {
53 uint8_t sc[8], m_val[8];
63 for (
int j = 0; j <
QK_K; j += 64) {
68 __m128i acc_lo = _mm_setzero_si128();
69 __m128i acc_hi = _mm_setzero_si128();
72 for (
int l = 0; l < 32; l += 16) {
74 __m128i q4_vec = _mm_loadu_si128((
const __m128i *)(b4->
qs + q_offset + l));
77 __m128i q4_lo = _mm_and_si128(q4_vec, mask_low);
80 __m128i q4_hi = _mm_and_si128(_mm_srli_epi16(q4_vec, 4), mask_low);
83 __m128i q8_lo_vec = _mm_loadu_si128((
const __m128i *)(b8->
qs + j + l));
84 __m128i q8_hi_vec = _mm_loadu_si128((
const __m128i *)(b8->
qs + j + 32 + l));
90 __m128i q4_lo_16_L = _mm_cvtepu8_epi16(q4_lo);
91 __m128i q8_lo_16_L = _mm_cvtepi8_epi16(q8_lo_vec);
92 __m128i prod_lo_L = _mm_madd_epi16(q4_lo_16_L, q8_lo_16_L);
93 acc_lo = _mm_add_epi32(acc_lo, prod_lo_L);
95 __m128i q4_lo_16_H = _mm_cvtepu8_epi16(_mm_srli_si128(q4_lo, 8));
96 __m128i q8_lo_16_H = _mm_cvtepi8_epi16(_mm_srli_si128(q8_lo_vec, 8));
97 __m128i prod_lo_H = _mm_madd_epi16(q4_lo_16_H, q8_lo_16_H);
98 acc_lo = _mm_add_epi32(acc_lo, prod_lo_H);
101 __m128i q4_hi_16_L = _mm_cvtepu8_epi16(q4_hi);
102 __m128i q8_hi_16_L = _mm_cvtepi8_epi16(q8_hi_vec);
103 __m128i prod_hi_L = _mm_madd_epi16(q4_hi_16_L, q8_hi_16_L);
104 acc_hi = _mm_add_epi32(acc_hi, prod_hi_L);
106 __m128i q4_hi_16_H = _mm_cvtepu8_epi16(_mm_srli_si128(q4_hi, 8));
107 __m128i q8_hi_16_H = _mm_cvtepi8_epi16(_mm_srli_si128(q8_hi_vec, 8));
108 __m128i prod_hi_H = _mm_madd_epi16(q4_hi_16_H, q8_hi_16_H);
109 acc_hi = _mm_add_epi32(acc_hi, prod_hi_H);
116 int32_t bsum_lo = (int32_t)b8->
bsums[j / 16] +
117 (int32_t)b8->
bsums[j / 16 + 1];
118 int32_t bsum_hi = (int32_t)b8->
bsums[(j + 32) / 16] +
119 (int32_t)b8->
bsums[(j + 32) / 16 + 1];
121 sumf += d * (float)sc[is] * (
float)sum_q4q8_lo;
122 sumf -= dmin * (float)m_val[is] * (
float)bsum_lo;
123 sumf += d * (float)sc[is + 1] * (
float)sum_q4q8_hi;
124 sumf -= dmin * (float)m_val[is + 1] * (
float)bsum_hi;
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
static void unpack_q4_k_scales(const uint8_t *scales, uint8_t *sc, uint8_t *m)
Unpack Q4_K sub-block scales and mins.
static int32_t hsum_epi32_sse(__m128i v)
void gemv_q4_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)