#if defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
#include <immintrin.h>
/* Round to nearest int, half away from zero (matches quantize_row_q8_0). */
static int ck_round_nearest(float v)
{
    return (int)(v + (v >= 0.0f ? 0.5f : -0.5f));
}
#ifndef MM256_SET_M128I_DEFINED
#define MM256_SET_M128I_DEFINED
#define MM256_SET_M128I(hi, lo) \
    _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1)
#endif
/* Expand the 32 packed high bits (qh) into 32 bytes that are 0xFF where the bit is set
 * and 0x00 where it is clear. */
static inline __m256i fused_bytes_from_bits_32(const uint8_t *qh)
{
    uint32_t x32;
    memcpy(&x32, qh, sizeof(uint32_t));

    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101LL, 0x0000000000000000LL);
    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303LL, 0x0202020202020202LL);

    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);

    /* Each byte keeps every bit set except the one it is testing; OR then compare to -1. */
    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfeLL);
    bytesl = _mm_or_si128(bytesl, bit_mask);
    bytesh = _mm_or_si128(bytesh, bit_mask);

    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1LL));
    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1LL));

    return MM256_SET_M128I(bytesh, bytesl);
}
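/* Illustrative scalar reference for fused_bytes_from_bits_32 (an added sketch, not used by
 * the kernels below): byte j of the result is 0xFF when bit j of the packed 32-bit qh word
 * is set, and 0x00 otherwise. */
static inline void fused_bytes_from_bits_32_ref(const uint8_t *qh, uint8_t out[32])
{
    uint32_t x32;
    memcpy(&x32, qh, sizeof(uint32_t));
    for (int j = 0; j < 32; j++) {
        out[j] = ((x32 >> j) & 1u) ? 0xFF : 0x00;
    }
}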
/* Multiply signed int8 lanes of x and y and accumulate adjacent products into int32 lanes. */
static inline __m128i fused_mul_sum_i8_pairs(__m128i x, __m128i y)
{
    /* _mm_maddubs_epi16 wants an unsigned first operand: use |x| and move x's sign onto y. */
    const __m128i ax = _mm_sign_epi8(x, x);
    const __m128i sy = _mm_sign_epi8(y, x);
    const __m128i dot = _mm_maddubs_epi16(ax, sy);
    return _mm_madd_epi16(dot, _mm_set1_epi16(1));
}
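/* Illustrative scalar equivalent of fused_mul_sum_i8_pairs (an added sketch, not used by the
 * kernels): each 32-bit output lane is the sum of four adjacent int8 products x[i]*y[i],
 * ignoring the int8 edge case at -128 which the quantizers below never produce. */
static inline void fused_mul_sum_i8_pairs_ref(const int8_t x[16], const int8_t y[16],
                                              int32_t out[4])
{
    for (int lane = 0; lane < 4; lane++) {
        int32_t s = 0;
        for (int k = 0; k < 4; k++) {
            s += (int32_t)x[lane * 4 + k] * (int32_t)y[lane * 4 + k];
        }
        out[lane] = s;
    }
}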
/* Horizontal sum of the four 32-bit lanes of v. */
static inline int32_t fused_hsum_i32_sse(__m128i v)
{
    __m128i hi64  = _mm_unpackhi_epi64(v, v);
    __m128i sum64 = _mm_add_epi32(hi64, v);
    __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
/* Quantize one block of 32 floats to int8 with AVX: writes the two 16-byte halves to
 * *qa_lo / *qa_hi, stores the block scale through out_d, and returns it. */
static inline float fused_quantize_block_avx(const float *xp, __m128i *qa_lo,
                                             __m128i *qa_hi, float *out_d)
{
    const __m256 sign_bit = _mm256_set1_ps(-0.0f);
    const __m256 v_half   = _mm256_set1_ps(0.5f);
    const __m256 v_min    = _mm256_set1_ps(-127.0f);
    const __m256 v_max    = _mm256_set1_ps(127.0f);

    __m256 vx0 = _mm256_loadu_ps(&xp[0]);
    __m256 vx1 = _mm256_loadu_ps(&xp[8]);
    __m256 vx2 = _mm256_loadu_ps(&xp[16]);
    __m256 vx3 = _mm256_loadu_ps(&xp[24]);

    /* Max absolute value over the block. */
    __m256 max_abs = _mm256_andnot_ps(sign_bit, vx0);
    max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, vx1));
    max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, vx2));
    max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, vx3));

    __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max_abs, 1),
                             _mm256_castps256_ps128(max_abs));
    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
    const float amax = _mm_cvtss_f32(max4);

    /* Scale so the largest magnitude maps to 127. */
    float d_x = amax / 127.0f;
    *out_d = d_x;

    if (amax == 0.0f) {
        *qa_lo = _mm_setzero_si128();
        *qa_hi = _mm_setzero_si128();
        return d_x;
    }

    const float id_x = 127.0f / amax;
    const __m256 vmul = _mm256_set1_ps(id_x);

    vx0 = _mm256_mul_ps(vx0, vmul);
    vx1 = _mm256_mul_ps(vx1, vmul);
    vx2 = _mm256_mul_ps(vx2, vmul);
    vx3 = _mm256_mul_ps(vx3, vmul);

    /* Clamp to [-127, 127]. */
    vx0 = _mm256_min_ps(_mm256_max_ps(vx0, v_min), v_max);
    vx1 = _mm256_min_ps(_mm256_max_ps(vx1, v_min), v_max);
    vx2 = _mm256_min_ps(_mm256_max_ps(vx2, v_min), v_max);
    vx3 = _mm256_min_ps(_mm256_max_ps(vx3, v_min), v_max);

    /* Round half away from zero: add +/-0.5 carrying the operand's sign, then truncate. */
    vx0 = _mm256_add_ps(vx0, _mm256_or_ps(_mm256_and_ps(vx0, sign_bit), v_half));
    vx1 = _mm256_add_ps(vx1, _mm256_or_ps(_mm256_and_ps(vx1, sign_bit), v_half));
    vx2 = _mm256_add_ps(vx2, _mm256_or_ps(_mm256_and_ps(vx2, sign_bit), v_half));
    vx3 = _mm256_add_ps(vx3, _mm256_or_ps(_mm256_and_ps(vx3, sign_bit), v_half));

    __m256i i0 = _mm256_cvttps_epi32(vx0);
    __m256i i1 = _mm256_cvttps_epi32(vx1);
    __m256i i2 = _mm256_cvttps_epi32(vx2);
    __m256i i3 = _mm256_cvttps_epi32(vx3);

#if defined(__AVX2__)
    /* Pack 32-bit lanes down to int8 and restore the source ordering. */
    i0 = _mm256_packs_epi32(i0, i1);
    i2 = _mm256_packs_epi32(i2, i3);
    i0 = _mm256_packs_epi16(i0, i2);
    const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
    i0 = _mm256_permutevar8x32_epi32(i0, perm);
    *qa_lo = _mm256_castsi256_si128(i0);
    *qa_hi = _mm256_extractf128_si256(i0, 1);
#else
    /* AVX-only path: split into 128-bit halves and pack with SSE. */
    __m128i ni0 = _mm256_castsi256_si128(i0);
    __m128i ni1 = _mm256_extractf128_si256(i0, 1);
    __m128i ni2 = _mm256_castsi256_si128(i1);
    __m128i ni3 = _mm256_extractf128_si256(i1, 1);
    __m128i ni4 = _mm256_castsi256_si128(i2);
    __m128i ni5 = _mm256_extractf128_si256(i2, 1);
    __m128i ni6 = _mm256_castsi256_si128(i3);
    __m128i ni7 = _mm256_extractf128_si256(i3, 1);

    ni0 = _mm_packs_epi32(ni0, ni1);
    ni2 = _mm_packs_epi32(ni2, ni3);
    ni4 = _mm_packs_epi32(ni4, ni5);
    ni6 = _mm_packs_epi32(ni6, ni7);

    *qa_lo = _mm_packs_epi16(ni0, ni2);
    *qa_hi = _mm_packs_epi16(ni4, ni6);
#endif

    return d_x;
}
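/* Illustrative scalar sketch of the same block quantization (added for clarity, not used by
 * the kernels): with amax = max_j |x[j]| over the 32-value block, the scale is
 * d_x = amax / 127 and q[j] = clamp(round_half_away_from_zero(x[j] * 127 / amax), -127, 127),
 * so that x[j] is approximately d_x * q[j]. */
static inline float fused_quantize_block_ref(const float *xp, int8_t q[32])
{
    float amax = 0.0f;
    for (int j = 0; j < 32; j++) {
        float ax = xp[j] >= 0 ? xp[j] : -xp[j];
        if (ax > amax) amax = ax;
    }
    const float d_x  = amax / 127.0f;
    const float id_x = (amax != 0.0f) ? 127.0f / amax : 0.0f;
    for (int j = 0; j < 32; j++) {
        int v = ck_round_nearest(xp[j] * id_x);
        if (v >  127) v =  127;
        if (v < -127) v = -127;
        q[j] = (int8_t)v;
    }
    return d_x;
}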
/* AVX kernel: y = W*x (+ bias) for a row-major Q5_0 weight matrix. */
static void gemv_fused_q5_0_bias_avx(float *y, const void *W, const float *x,
                                     const float *bias, int M, int K)
{
    const block_q5_0 *blocks = (const block_q5_0 *)W;
    const int blocks_per_row = K / QK5_0;

    float x_scales[blocks_per_row];
    int8_t x_qs[K];

    /* Quantize the activation vector once; every row reuses the same int8 data. */
    for (int b = 0; b < blocks_per_row; b++) {
        __m128i qa_lo, qa_hi;
        float d_x;
        fused_quantize_block_avx(&x[b * QK5_0], &qa_lo, &qa_hi, &d_x);
        x_scales[b] = d_x;
        _mm_storeu_si128((__m128i *)&x_qs[b * 32], qa_lo);
        _mm_storeu_si128((__m128i *)&x_qs[b * 32 + 16], qa_hi);
    }

    const __m128i mask_0f = _mm_set1_epi8(0x0F);
    const __m128i mask_f0 = _mm_set1_epi8((char)0xF0);

    for (int row = 0; row < M; row++) {
        float sum = 0.0f;

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q5_0 *block = &blocks[row * blocks_per_row + b];
            const float d_w = CK_FP16_TO_FP32(block->d);
            const float d_x = x_scales[b];
            if (d_x == 0.0f) continue;

            const float d = d_w * d_x;

            /* Pre-quantized activations for this block. */
            __m128i qa_lo = _mm_loadu_si128((const __m128i *)&x_qs[b * 32]);
            __m128i qa_hi = _mm_loadu_si128((const __m128i *)&x_qs[b * 32 + 16]);

            /* Low 4 bits of each weight come from the packed nibbles. */
            __m128i qs = _mm_loadu_si128((const __m128i *)block->qs);
            __m128i bx_lo = _mm_and_si128(qs, mask_0f);
            __m128i bx_hi = _mm_and_si128(_mm_srli_epi16(qs, 4), mask_0f);

            /* The fifth bit comes from qh, expanded to 0xFF/0x00 per byte. */
            __m256i bxhi256 = fused_bytes_from_bits_32(block->qh);
            __m128i bxhi_lo = _mm256_castsi256_si128(bxhi256);
            __m128i bxhi_hi = _mm256_extractf128_si256(bxhi256, 1);

            /* Where the fifth bit is clear, OR in 0xF0 so the int8 value becomes nibble - 16;
             * where it is set, the nibble is kept as-is (0..15). */
            bxhi_lo = _mm_andnot_si128(bxhi_lo, mask_f0);
            bxhi_hi = _mm_andnot_si128(bxhi_hi, mask_f0);

            bx_lo = _mm_or_si128(bx_lo, bxhi_lo);
            bx_hi = _mm_or_si128(bx_hi, bxhi_hi);

            __m128i p_lo = fused_mul_sum_i8_pairs(bx_lo, qa_lo);
            __m128i p_hi = fused_mul_sum_i8_pairs(bx_hi, qa_hi);
            __m128i psum = _mm_add_epi32(p_lo, p_hi);

            int32_t sumi = fused_hsum_i32_sse(psum);
            sum += d * (float)sumi;
        }

        if (bias) sum += bias[row];
        y[row] = sum;
    }
}
/* AVX kernel: y = W*x (+ bias) for a row-major Q8_0 weight matrix. The weights are already
 * int8, so no nibble/high-bit reconstruction is needed. */
static void gemv_fused_q8_0_bias_avx(float *y, const void *W, const float *x,
                                     const float *bias, int M, int K)
{
    const block_q8_0 *blocks = (const block_q8_0 *)W;
    const int blocks_per_row = K / QK8_0;

    float x_scales[blocks_per_row];
    int8_t x_qs[K];

    /* Quantize the activation vector once; every row reuses the same int8 data. */
    for (int b = 0; b < blocks_per_row; b++) {
        __m128i qa_lo, qa_hi;
        float d_x;
        fused_quantize_block_avx(&x[b * QK8_0], &qa_lo, &qa_hi, &d_x);
        x_scales[b] = d_x;
        _mm_storeu_si128((__m128i *)&x_qs[b * 32], qa_lo);
        _mm_storeu_si128((__m128i *)&x_qs[b * 32 + 16], qa_hi);
    }

    for (int row = 0; row < M; row++) {
        float sum = 0.0f;

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q8_0 *block = &blocks[row * blocks_per_row + b];
            const float d_w = CK_FP16_TO_FP32(block->d);
            const float d_x = x_scales[b];
            if (d_x == 0.0f) continue;

            const float d = d_w * d_x;

            /* Pre-quantized activations and int8 weights for this block. */
            __m128i qa_lo = _mm_loadu_si128((const __m128i *)&x_qs[b * 32]);
            __m128i qa_hi = _mm_loadu_si128((const __m128i *)&x_qs[b * 32 + 16]);

            __m128i qw_lo = _mm_loadu_si128((const __m128i *)block->qs);
            __m128i qw_hi = _mm_loadu_si128((const __m128i *)(block->qs + 16));

            __m128i p_lo = fused_mul_sum_i8_pairs(qa_lo, qw_lo);
            __m128i p_hi = fused_mul_sum_i8_pairs(qa_hi, qw_hi);
            __m128i psum = _mm_add_epi32(p_lo, p_hi);

            int32_t sumi = fused_hsum_i32_sse(psum);
            sum += d * (float)sumi;
        }

        if (bias) sum += bias[row];
        y[row] = sum;
    }
}
/* Compute dot product of FP32 input with a Q5_0 weight block, with online Q8 quantization. */
static float dot_fp32_q5_0_block(const float *x, const block_q5_0 *block)
{
    const float d_w = CK_FP16_TO_FP32(block->d);

    float amax = 0.0f;
    for (int j = 0; j < 32; j++) {
        float ax = x[j] >= 0 ? x[j] : -x[j];
        if (ax > amax) amax = ax;
    }

    float d_x = amax / 127.0f;
    const float id_x = (amax != 0.0f) ? 127.0f / amax : 0.0f;
    const float d = d_w * d_x;

    uint32_t qh;
    memcpy(&qh, block->qh, sizeof(qh));

    int32_t sumi = 0;
    for (int j = 0; j < 16; j++) {
        const uint8_t packed = block->qs[j];
        const int lo = (packed & 0x0F);
        const int hi = (packed >> 4);
        const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
        const int xh_1 = ((qh >> (j + 12))) & 0x10;
        const int w0 = (lo | xh_0) - 16;
        const int w1 = (hi | xh_1) - 16;

        float v0 = x[j] * id_x;
        float v1 = x[j + 16] * id_x;

        int q0 = ck_round_nearest(v0);
        int q1 = ck_round_nearest(v1);

        if (q0 > 127) q0 = 127;
        if (q0 < -127) q0 = -127;
        if (q1 > 127) q1 = 127;
        if (q1 < -127) q1 = -127;

        sumi += q0 * w0 + q1 * w1;
    }

    return d * (float)sumi;
}
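/* Worked example of the Q5_0 decode in dot_fp32_q5_0_block above (illustrative values):
 * for j = 3 with qs[3] = 0xA7, qh bit 3 = 0 and qh bit 19 = 1, we get lo = 0x7, hi = 0xA,
 * xh_0 = 0x00, xh_1 = 0x10, so w0 = (0x07 | 0x00) - 16 = -9 and w1 = (0x0A | 0x10) - 16 = 10;
 * decoded weights always lie in [-16, 15]. */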
/* Compute dot product of FP32 input with a Q8_0 weight block, with online Q8 quantization. */
static float dot_fp32_q8_0_block(const float *x, const block_q8_0 *block)
{
    const float d_w = CK_FP16_TO_FP32(block->d);

    float amax = 0.0f;
    for (int j = 0; j < 32; j++) {
        float ax = x[j] >= 0 ? x[j] : -x[j];
        if (ax > amax) amax = ax;
    }

    float d_x = amax / 127.0f;
    const float id_x = (amax != 0.0f) ? 127.0f / amax : 0.0f;
    const float d = d_w * d_x;

    int32_t sumi = 0;
    for (int j = 0; j < 32; j++) {
        float v = x[j] * id_x;
        int q = ck_round_nearest(v);
        if (q > 127) q = 127;
        if (q < -127) q = -127;
        sumi += q * (int32_t)block->qs[j];
    }

    return d * (float)sumi;
}
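/* Illustrative arithmetic for the combined scale d = d_w * d_x used above: if the weight
 * scale is d_w = 0.01 and the block's amax is 2.54, then d_x = 2.54 / 127 = 0.02 and each
 * integer product q*w contributes d_w * d_x = 0.0002 to the final dot product. */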
/* Scalar fallback: fused Q5_0 GEMV with optional bias. */
void gemv_fused_q5_0_bias(float *y, const void *W, const float *x, const float *bias, int M, int K)
{
    const block_q5_0 *blocks = (const block_q5_0 *)W;
    const int blocks_per_row = K / QK5_0;

    for (int row = 0; row < M; row++) {
        float sum = 0.0f;
        for (int b = 0; b < blocks_per_row; b++) {
            const block_q5_0 *block = &blocks[row * blocks_per_row + b];
            const float *xp = &x[b * QK5_0];
            sum += dot_fp32_q5_0_block(xp, block);
        }
        if (bias) sum += bias[row];
        y[row] = sum;
    }
}
/* Scalar fallback: fused Q8_0 GEMV with optional bias. */
void gemv_fused_q8_0_bias(float *y, const void *W, const float *x, const float *bias, int M, int K)
{
    const block_q8_0 *blocks = (const block_q8_0 *)W;
    const int blocks_per_row = K / QK8_0;

    for (int row = 0; row < M; row++) {
        float sum = 0.0f;
        for (int b = 0; b < blocks_per_row; b++) {
            const block_q8_0 *block = &blocks[row * blocks_per_row + b];
            const float *xp = &x[b * QK8_0];
            sum += dot_fp32_q8_0_block(xp, block);
        }
        if (bias) sum += bias[row];
        y[row] = sum;
    }
}
/* gemv_fused_q5_0_bias_dispatch selects the AVX path: */
    gemv_fused_q5_0_bias_avx(y, W, x, bias, M, K);

/* gemv_fused_q8_0_bias_dispatch selects the AVX path: */
    gemv_fused_q8_0_bias_avx(y, W, x, bias, M, K);
/* Quantization block structures for weight-only quantization. */

/* FP16 <-> FP32 conversion helpers. */
#define CK_FP16_TO_FP32(x)
#define CK_FP32_TO_FP16(x)

/* Round to nearest int, half away from zero (matches quantize_row_q8_0). */
static int ck_round_nearest(float v);

/* Compute dot product of FP32 input with Q5_0 weight block, with online Q8 quantization. */
static float dot_fp32_q5_0_block(const float *x, const block_q5_0 *block);

/* Compute dot product of FP32 input with Q8_0 weight block, with online Q8 quantization. */
static float dot_fp32_q8_0_block(const float *x, const block_q8_0 *block);

void gemv_fused_q5_0_bias(float *y, const void *W, const float *x, const float *bias, int M, int K);
void gemv_fused_q8_0_bias(float *y, const void *W, const float *x, const float *bias, int M, int K);
void gemv_fused_q5_0_bias_dispatch(float *y, const void *W, const float *x, const float *bias, int M, int K);
void gemv_fused_q8_0_bias_dispatch(float *y, const void *W, const float *x, const float *bias, int M, int K);
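/* Usage sketch (illustrative; the buffer names and sizes are assumptions, and q8_weights
 * must already hold M * (K / QK8_0) row-major block_q8_0 blocks quantized elsewhere):
 *
 *     float y[M];
 *     gemv_fused_q8_0_bias_dispatch(y, q8_weights, x, bias, M, K);   // y = W*x + bias
 *
 * Passing bias == NULL skips the bias addition. */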