← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemv_fused_quant_bias.c
Go to the documentation of this file.
1 /**
2  * @file gemv_fused_quant_bias.c
3  * @brief Fused GEMV kernels with online quantization and bias
4  *
5  * These kernels fuse:
6  * 1. Quantize FP32 input to Q8_0/Q8_K (no memory write)
7  * 2. GEMV with quantized weights
8  * 3. Bias add
9  *
10  * Benefits:
11  * - Eliminates memory traffic for quantized activations
12  * - Better cache utilization
13  * - Reduces total ops in IR from 3 to 1
14  *
15  * Kernel signature:
16  * gemv_fused_q5_0_bias(y, W, x, bias, M, K)
17  * - x: FP32 input [K]
18  * - W: Q5_0 weights [M, K]
19  * - bias: FP32 bias [M] (can be NULL)
20  * - y: FP32 output [M]
21  */
22 
23 #include <stdint.h>
24 #include <stddef.h>
25 #include <string.h>
26 #include <math.h>
27 #include "ckernel_quant.h"
28 
29 #if defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
30 #include <immintrin.h>
31 #endif
32 
33 /* ============================================================================
34  * Scalar Helpers
35  * ============================================================================ */
36 
/**
 * @brief Round a float to the nearest integer, ties away from zero.
 *
 * Matches the rounding used by quantize_row_q8_0 so the fused kernels
 * produce bit-identical quantized activations.
 */
static inline int ck_round_nearest(float v) {
    if (v < 0.0f) {
        return (int)(v - 0.5f);
    }
    return (int)(v + 0.5f);
}
43 
44 /* ============================================================================
45  * AVX + SSE SIMD Helpers
46  * ============================================================================
47  *
48  * These are local copies of helpers from gemm_kernels_q5_0.c, needed for
49  * the fused SIMD kernels. Kept local to avoid cross-file dependencies.
50  */
51 
52 #if defined(__AVX__)
53 
/* Combine two __m128i into __m256i (AVX without AVX2) */
#ifndef MM256_SET_M128I_DEFINED
#define MM256_SET_M128I_DEFINED
/* _mm256_set_m128i is unavailable on some older compilers; build the 256-bit
 * vector by widening `lo` (upper lane undefined) and inserting `hi` on top. */
#define MM256_SET_M128I(hi, lo) \
    _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1)
#endif
60 
/**
 * @brief Spread 32 bits to 32 bytes using AVX
 * Returns __m256i with 0xFF where bit was set, 0x00 where not
 *
 * @param qh Pointer to 4 bytes holding the 32 source bits
 */
static inline __m256i fused_bytes_from_bits_32(const uint8_t *qh)
{
    /* memcpy: unaligned-safe, strict-aliasing-safe load of the 32 bits. */
    uint32_t x32;
    memcpy(&x32, qh, sizeof(uint32_t));

    /* Shuffle the broadcast word so each source byte is replicated across
     * 8 consecutive result bytes (byte i of the output holds source byte i/8). */
    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101LL, 0x0000000000000000LL);
    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303LL, 0x0202020202020202LL);

    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);

    /* Each byte of bit_mask has exactly one bit CLEAR (bit i%8 in byte i of
     * each 8-byte group). OR-ing therefore yields 0xFF in a byte iff that
     * byte's designated source bit was set. */
    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfeLL);

    bytesl = _mm_or_si128(bytesl, bit_mask);
    bytesh = _mm_or_si128(bytesh, bit_mask);

    /* Compare against all-ones to normalize: 0xFF where bit set, 0x00 where not. */
    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1LL));
    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1LL));

    return MM256_SET_M128I(bytesh, bytesl);
}
86 
/**
 * @brief Multiply signed int8 pairs using sign trick (SSSE3)
 * Returns 4 int32 partial sums from 16 int8 pairs
 *
 * _mm_maddubs_epi16 treats its first operand as UNSIGNED bytes, so we feed
 * it abs(x) and move x's sign onto y: abs(x) * (y * sign(x)) == x * y.
 */
static inline __m128i fused_mul_sum_i8_pairs(__m128i x, __m128i y)
{
    const __m128i ax = _mm_sign_epi8(x, x); /* abs(x) */
    const __m128i sy = _mm_sign_epi8(y, x); /* y * sign(x) */
    /* Unsigned x signed byte multiply; adjacent products summed into int16. */
    const __m128i dot = _mm_maddubs_epi16(ax, sy);
    /* Widen: pairwise-add the 8 int16 lanes into 4 int32 partial sums. */
    return _mm_madd_epi16(dot, _mm_set1_epi16(1));
}
98 
/**
 * @brief Horizontal sum of 4 int32 in __m128i
 *
 * @return Scalar sum of the four 32-bit lanes.
 */
static inline int32_t fused_hsum_i32_sse(__m128i v)
{
    /* Fold the upper 64 bits onto the lower 64 bits. */
    __m128i hi64 = _mm_unpackhi_epi64(v, v);
    __m128i sum64 = _mm_add_epi32(hi64, v);
    /* Swap the remaining two lanes and add; lane 0 now holds the total. */
    __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
109 
/**
 * @brief Quantize 32 FP32 values to int8 in SSE registers (no memory write)
 *
 * Uses the same algorithm as quantize_row_q8_0 (AVX path) to ensure
 * numerical parity, but keeps results in registers instead of writing to memory.
 *
 * @param xp Input: 32 FP32 values
 * @param qa_lo Output: 16 int8 quantized values [0..15]
 * @param qa_hi Output: 16 int8 quantized values [16..31]
 * @param out_d_x Output: quantization scale (after FP16 round-trip);
 *                written even for all-zero input (scale is then 0)
 * @return amax (0 means all-zero input)
 */
static inline float fused_quantize_block_avx(
    const float *xp,
    __m128i *qa_lo,
    __m128i *qa_hi,
    float *out_d_x)
{
    const __m256 sign_bit = _mm256_set1_ps(-0.0f); /* only the sign bit set */
    const __m256 v_half = _mm256_set1_ps(0.5f);
    const __m256 v_min = _mm256_set1_ps(-127.0f);
    const __m256 v_max = _mm256_set1_ps(127.0f);

    /* Load 32 FP32 values */
    __m256 vx0 = _mm256_loadu_ps(&xp[0]);
    __m256 vx1 = _mm256_loadu_ps(&xp[8]);
    __m256 vx2 = _mm256_loadu_ps(&xp[16]);
    __m256 vx3 = _mm256_loadu_ps(&xp[24]);

    /* Find max absolute value: andnot with the sign mask clears the sign. */
    __m256 max_abs = _mm256_andnot_ps(sign_bit, vx0);
    max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, vx1));
    max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, vx2));
    max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, vx3));

    /* Horizontal max: 256 -> 128 -> scalar */
    __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max_abs, 1),
                             _mm256_castps256_ps128(max_abs));
    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
    const float amax = _mm_cvtss_f32(max4);

    /* Compute scales */
    float d_x = amax / 127.0f;
    d_x = CK_FP16_TO_FP32(CK_FP32_TO_FP16(d_x)); /* FP16 round-trip for parity */
    *out_d_x = d_x;

    /* All-zero block: emit zero quants; callers skip it via the zero scale. */
    if (amax == 0.0f) {
        *qa_lo = _mm_setzero_si128();
        *qa_hi = _mm_setzero_si128();
        return 0.0f;
    }

    const float id_x = 127.0f / amax;
    const __m256 vmul = _mm256_set1_ps(id_x);

    /* Scale */
    vx0 = _mm256_mul_ps(vx0, vmul);
    vx1 = _mm256_mul_ps(vx1, vmul);
    vx2 = _mm256_mul_ps(vx2, vmul);
    vx3 = _mm256_mul_ps(vx3, vmul);

    /* Clamp to [-127, 127] */
    vx0 = _mm256_min_ps(_mm256_max_ps(vx0, v_min), v_max);
    vx1 = _mm256_min_ps(_mm256_max_ps(vx1, v_min), v_max);
    vx2 = _mm256_min_ps(_mm256_max_ps(vx2, v_min), v_max);
    vx3 = _mm256_min_ps(_mm256_max_ps(vx3, v_min), v_max);

    /* Round half away from zero: v + sign(v) * 0.5 (matches ck_round_nearest) */
    vx0 = _mm256_add_ps(vx0, _mm256_or_ps(_mm256_and_ps(vx0, sign_bit), v_half));
    vx1 = _mm256_add_ps(vx1, _mm256_or_ps(_mm256_and_ps(vx1, sign_bit), v_half));
    vx2 = _mm256_add_ps(vx2, _mm256_or_ps(_mm256_and_ps(vx2, sign_bit), v_half));
    vx3 = _mm256_add_ps(vx3, _mm256_or_ps(_mm256_and_ps(vx3, sign_bit), v_half));

    /* Convert to int32 (truncation after rounding) */
    __m256i i0 = _mm256_cvttps_epi32(vx0);
    __m256i i1 = _mm256_cvttps_epi32(vx1);
    __m256i i2 = _mm256_cvttps_epi32(vx2);
    __m256i i3 = _mm256_cvttps_epi32(vx3);

    /* Pack int32 -> int16 -> int8 (SSE, no AVX2 needed) */
#if defined(__AVX2__)
    /* AVX2: use 256-bit packing + permute.
     * packs operates per 128-bit lane, so a final cross-lane permute is
     * needed to restore source order. */
    i0 = _mm256_packs_epi32(i0, i1);
    i2 = _mm256_packs_epi32(i2, i3);
    i0 = _mm256_packs_epi16(i0, i2);
    const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
    i0 = _mm256_permutevar8x32_epi32(i0, perm);
    *qa_lo = _mm256_castsi256_si128(i0);
    *qa_hi = _mm256_extractf128_si256(i0, 1);
#else
    /* AVX (no AVX2): extract 128-bit halves and pack manually */
    __m128i ni0 = _mm256_castsi256_si128(i0);
    __m128i ni1 = _mm256_extractf128_si256(i0, 1);
    __m128i ni2 = _mm256_castsi256_si128(i1);
    __m128i ni3 = _mm256_extractf128_si256(i1, 1);
    __m128i ni4 = _mm256_castsi256_si128(i2);
    __m128i ni5 = _mm256_extractf128_si256(i2, 1);
    __m128i ni6 = _mm256_castsi256_si128(i3);
    __m128i ni7 = _mm256_extractf128_si256(i3, 1);

    ni0 = _mm_packs_epi32(ni0, ni1);
    ni2 = _mm_packs_epi32(ni2, ni3);
    ni4 = _mm_packs_epi32(ni4, ni5);
    ni6 = _mm_packs_epi32(ni6, ni7);

    *qa_lo = _mm_packs_epi16(ni0, ni2);
    *qa_hi = _mm_packs_epi16(ni4, ni6);
#endif

    return amax;
}
222 
223 /* ============================================================================
224  * AVX Fused GEMV Kernels (works on Ivy Bridge and newer)
225  * ============================================================================ */
226 
/**
 * @brief AVX fused GEMV: FP32 → online Q8 → Q5_0 weights → FP32 + bias
 *
 * Uses AVX for float ops and SSE/SSSE3 for integer dot products.
 *
 * For each row: y[row] = bias[row] + sum over blocks of
 * d_w * d_x * <q8(x_block), q5_weights_block>.
 *
 * Assumes K % QK5_0 == 0 (only K / QK5_0 full blocks are processed).
 *
 * @param y    FP32 output [M]
 * @param W    Q5_0 weight blocks, row-major [M, K/QK5_0]
 * @param x    FP32 input [K]
 * @param bias FP32 bias [M], may be NULL
 * @param M    number of output rows
 * @param K    input length
 */
static void gemv_fused_q5_0_bias_avx(
    float *y,
    const void *W,
    const float *x,
    const float *bias,
    int M,
    int K)
{
    const block_q5_0 *blocks = (const block_q5_0 *)W;
    const int blocks_per_row = K / QK5_0;

    /* Pre-quantize input x ONCE (not per row) */
    /* NOTE(review): x_scales/x_qs are VLAs sized by K; for very large K this
     * is unbounded stack usage — TODO confirm K bounds or use a heap scratch. */
    float x_scales[blocks_per_row];
    int8_t x_qs[K]; /* 32 int8 values per block */

    for (int b = 0; b < blocks_per_row; b++) {
        __m128i qa_lo, qa_hi;
        float d_x;
        fused_quantize_block_avx(&x[b * QK5_0], &qa_lo, &qa_hi, &d_x);
        x_scales[b] = d_x;
        _mm_storeu_si128((__m128i *)&x_qs[b * 32], qa_lo);
        _mm_storeu_si128((__m128i *)&x_qs[b * 32 + 16], qa_hi);
    }

    const __m128i mask_0f = _mm_set1_epi8(0x0F);
    const __m128i mask_f0 = _mm_set1_epi8((char)0xF0);

    for (int row = 0; row < M; row++) {
        float sum = 0.0f;

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q5_0 *block = &blocks[row * blocks_per_row + b];
            const float d_w = CK_FP16_TO_FP32(block->d);
            const float d_x = x_scales[b];
            /* All-zero input block contributes nothing. */
            if (d_x == 0.0f) continue;

            const float d = d_w * d_x;

            /* Load pre-quantized input from buffer */
            __m128i qa_lo = _mm_loadu_si128((const __m128i *)&x_qs[b * 32]);
            __m128i qa_hi = _mm_loadu_si128((const __m128i *)&x_qs[b * 32 + 16]);

            /* Decode Q5_0 weights: extract nibbles */
            __m128i qs = _mm_loadu_si128((const __m128i *)block->qs);
            __m128i bx_lo = _mm_and_si128(qs, mask_0f);
            __m128i bx_hi = _mm_and_si128(_mm_srli_epi16(qs, 4), mask_0f);

            /* Spread 32 high bits to 32 bytes */
            __m256i bxhi256 = fused_bytes_from_bits_32(block->qh);
            __m128i bxhi_lo = _mm256_castsi256_si128(bxhi256);
            __m128i bxhi_hi = _mm256_extractf128_si256(bxhi256, 1);

            /* Apply encoding: (~bxhi) & 0xF0 — where the high bit is CLEAR,
             * OR-ing 0xF0 below turns the nibble into (nibble - 16) as int8;
             * where it is SET, the value stays nibble = (nibble|0x10) - 16. */
            bxhi_lo = _mm_andnot_si128(bxhi_lo, mask_f0);
            bxhi_hi = _mm_andnot_si128(bxhi_hi, mask_f0);

            /* Combine: nibble | high_bit_contribution -> signed Q5_0 weight bytes */
            bx_lo = _mm_or_si128(bx_lo, bxhi_lo);
            bx_hi = _mm_or_si128(bx_hi, bxhi_hi);

            /* Dot product using sign trick */
            __m128i p_lo = fused_mul_sum_i8_pairs(bx_lo, qa_lo);
            __m128i p_hi = fused_mul_sum_i8_pairs(bx_hi, qa_hi);
            __m128i psum = _mm_add_epi32(p_lo, p_hi);

            int32_t sumi = fused_hsum_i32_sse(psum);
            sum += d * (float)sumi;
        }

        if (bias) sum += bias[row];
        y[row] = sum;
    }
}
305 
/**
 * @brief AVX fused GEMV: FP32 → online Q8 → Q8_0 weights → FP32 + bias
 *
 * Same structure as the Q5_0 kernel, but weight bytes are already int8 so
 * no decode step is needed. Assumes K % QK8_0 == 0.
 *
 * @param y    FP32 output [M]
 * @param W    Q8_0 weight blocks, row-major [M, K/QK8_0]
 * @param x    FP32 input [K]
 * @param bias FP32 bias [M], may be NULL
 * @param M    number of output rows
 * @param K    input length
 */
static void gemv_fused_q8_0_bias_avx(
    float *y,
    const void *W,
    const float *x,
    const float *bias,
    int M,
    int K)
{
    const block_q8_0 *blocks = (const block_q8_0 *)W;
    const int blocks_per_row = K / QK8_0;

    /* Pre-quantize input x ONCE (not per row) */
    /* NOTE(review): VLAs sized by K — unbounded stack usage for large K;
     * TODO confirm K bounds or use a heap scratch. */
    float x_scales[blocks_per_row];
    int8_t x_qs[K]; /* 32 int8 values per block */

    for (int b = 0; b < blocks_per_row; b++) {
        __m128i qa_lo, qa_hi;
        float d_x;
        fused_quantize_block_avx(&x[b * QK8_0], &qa_lo, &qa_hi, &d_x);
        x_scales[b] = d_x;
        _mm_storeu_si128((__m128i *)&x_qs[b * 32], qa_lo);
        _mm_storeu_si128((__m128i *)&x_qs[b * 32 + 16], qa_hi);
    }

    for (int row = 0; row < M; row++) {
        float sum = 0.0f;

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q8_0 *block = &blocks[row * blocks_per_row + b];
            const float d_w = CK_FP16_TO_FP32(block->d);
            const float d_x = x_scales[b];
            /* All-zero input block contributes nothing. */
            if (d_x == 0.0f) continue;

            const float d = d_w * d_x;

            /* Load pre-quantized input from buffer */
            __m128i qa_lo = _mm_loadu_si128((const __m128i *)&x_qs[b * 32]);
            __m128i qa_hi = _mm_loadu_si128((const __m128i *)&x_qs[b * 32 + 16]);

            /* Load Q8_0 weights directly */
            __m128i qw_lo = _mm_loadu_si128((const __m128i *)block->qs);
            __m128i qw_hi = _mm_loadu_si128((const __m128i *)(block->qs + 16));

            /* Dot product using sign trick (activation quants are in
             * [-127,127], so abs() fits the unsigned maddubs operand). */
            __m128i p_lo = fused_mul_sum_i8_pairs(qa_lo, qw_lo);
            __m128i p_hi = fused_mul_sum_i8_pairs(qa_hi, qw_hi);
            __m128i psum = _mm_add_epi32(p_lo, p_hi);

            int32_t sumi = fused_hsum_i32_sse(psum);
            sum += d * (float)sumi;
        }

        if (bias) sum += bias[row];
        y[row] = sum;
    }
}
365 
366 #endif /* __AVX__ */
367 
368 /* ============================================================================
369  * Scalar Reference Implementations
370  * ============================================================================ */
371 
372 /**
373  * @brief Compute dot product of FP32 input with Q5_0 weight block, with online Q8 quantization
374  */
375 static inline float dot_fp32_q5_0_block(const float *x, const block_q5_0 *block) {
376  const float d_w = CK_FP16_TO_FP32(block->d);
377 
378  float amax = 0.0f;
379  for (int j = 0; j < 32; j++) {
380  float ax = x[j] >= 0 ? x[j] : -x[j];
381  if (ax > amax) amax = ax;
382  }
383 
384  float d_x = amax / 127.0f;
385  d_x = CK_FP16_TO_FP32(CK_FP32_TO_FP16(d_x));
386  const float id_x = (amax != 0.0f) ? 127.0f / amax : 0.0f;
387  const float d = d_w * d_x;
388 
389  uint32_t qh;
390  memcpy(&qh, block->qh, sizeof(qh));
391 
392  int32_t sumi = 0;
393  for (int j = 0; j < 16; j++) {
394  const uint8_t packed = block->qs[j];
395  const int lo = (packed & 0x0F);
396  const int hi = (packed >> 4);
397  const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
398  const int xh_1 = ((qh >> (j + 12))) & 0x10;
399  const int w0 = (lo | xh_0) - 16;
400  const int w1 = (hi | xh_1) - 16;
401 
402  float v0 = x[j] * id_x;
403  float v1 = x[j + 16] * id_x;
404  int q0 = ck_round_nearest(v0);
405  int q1 = ck_round_nearest(v1);
406  if (q0 > 127) q0 = 127; if (q0 < -127) q0 = -127;
407  if (q1 > 127) q1 = 127; if (q1 < -127) q1 = -127;
408 
409  sumi += q0 * w0 + q1 * w1;
410  }
411 
412  return d * (float)sumi;
413 }
414 
415 /**
416  * @brief Compute dot product of FP32 input with Q8_0 weight block, with online Q8 quantization
417  */
418 static inline float dot_fp32_q8_0_block(const float *x, const block_q8_0 *block) {
419  const float d_w = CK_FP16_TO_FP32(block->d);
420 
421  float amax = 0.0f;
422  for (int j = 0; j < 32; j++) {
423  float ax = x[j] >= 0 ? x[j] : -x[j];
424  if (ax > amax) amax = ax;
425  }
426 
427  float d_x = amax / 127.0f;
428  d_x = CK_FP16_TO_FP32(CK_FP32_TO_FP16(d_x));
429  const float id_x = (amax != 0.0f) ? 127.0f / amax : 0.0f;
430  const float d = d_w * d_x;
431 
432  int32_t sumi = 0;
433  for (int j = 0; j < 32; j++) {
434  float v = x[j] * id_x;
435  int q = ck_round_nearest(v);
436  if (q > 127) q = 127;
437  if (q < -127) q = -127;
438  sumi += q * (int32_t)block->qs[j];
439  }
440 
441  return d * (float)sumi;
442 }
443 
444 /* ============================================================================
445  * Scalar Fused GEMV Kernels (fallback)
446  * ============================================================================ */
447 
449  float *y,
450  const void *W,
451  const float *x,
452  const float *bias,
453  int M,
454  int K)
455 {
456  const block_q5_0 *blocks = (const block_q5_0 *)W;
457  const int blocks_per_row = K / QK5_0;
458 
459  for (int row = 0; row < M; row++) {
460  float sum = 0.0f;
461 
462  for (int b = 0; b < blocks_per_row; b++) {
463  const block_q5_0 *block = &blocks[row * blocks_per_row + b];
464  const float *xp = &x[b * QK5_0];
465  sum += dot_fp32_q5_0_block(xp, block);
466  }
467 
468  if (bias) {
469  sum += bias[row];
470  }
471 
472  y[row] = sum;
473  }
474 }
475 
477  float *y,
478  const void *W,
479  const float *x,
480  const float *bias,
481  int M,
482  int K)
483 {
484  const block_q8_0 *blocks = (const block_q8_0 *)W;
485  const int blocks_per_row = K / QK8_0;
486 
487  for (int row = 0; row < M; row++) {
488  float sum = 0.0f;
489 
490  for (int b = 0; b < blocks_per_row; b++) {
491  const block_q8_0 *block = &blocks[row * blocks_per_row + b];
492  const float *xp = &x[b * QK8_0];
493  sum += dot_fp32_q8_0_block(xp, block);
494  }
495 
496  if (bias) {
497  sum += bias[row];
498  }
499 
500  y[row] = sum;
501  }
502 }
503 
504 /* ============================================================================
505  * Dispatch Functions
506  * ============================================================================ */
507 
/**
 * @brief Dispatch fused Q5_0 GEMV to the best available implementation.
 *
 * Compile-time dispatch: the AVX kernel when built with AVX support,
 * otherwise the scalar reference. Same contract as gemv_fused_q5_0_bias.
 */
void gemv_fused_q5_0_bias_dispatch(
    float *y,
    const void *W,
    const float *x,
    const float *bias,
    int M,
    int K)
{
#if defined(__AVX__)
    gemv_fused_q5_0_bias_avx(y, W, x, bias, M, K);
#else
    gemv_fused_q5_0_bias(y, W, x, bias, M, K);
#endif
}
522 
/**
 * @brief Dispatch fused Q8_0 GEMV to the best available implementation.
 *
 * Compile-time dispatch: the AVX kernel when built with AVX support,
 * otherwise the scalar reference. Same contract as gemv_fused_q8_0_bias.
 */
void gemv_fused_q8_0_bias_dispatch(
    float *y,
    const void *W,
    const float *x,
    const float *bias,
    int M,
    int K)
{
#if defined(__AVX__)
    gemv_fused_q8_0_bias_avx(y, W, x, bias, M, K);
#else
    gemv_fused_q8_0_bias(y, W, x, bias, M, K);
#endif
}
Quantization block structures for weight-only quantization.
#define QK5_0
Definition: ckernel_quant.h:67
#define CK_FP16_TO_FP32(x)
#define CK_FP32_TO_FP16(x)
#define QK8_0
void gemv_fused_q8_0_bias(float *y, const void *W, const float *x, const float *bias, int M, int K)
static float dot_fp32_q8_0_block(const float *x, const block_q8_0 *block)
Compute dot product of FP32 input with Q8_0 weight block, with online Q8 quantization.
void gemv_fused_q8_0_bias_dispatch(float *y, const void *W, const float *x, const float *bias, int M, int K)
static int ck_round_nearest(float v)
Round to nearest int, half away from zero (matches quantize_row_q8_0)
static float dot_fp32_q5_0_block(const float *x, const block_q5_0 *block)
Compute dot product of FP32 input with Q5_0 weight block, with online Q8 quantization.
void gemv_fused_q5_0_bias(float *y, const void *W, const float *x, const float *bias, int M, int K)
void gemv_fused_q5_0_bias_dispatch(float *y, const void *W, const float *x, const float *bias, int M, int K)
ck_half d
Definition: ckernel_quant.h:70
uint8_t qh[4]
Definition: ckernel_quant.h:71
uint8_t qs[32/2]
Definition: ckernel_quant.h:72
int8_t qs[32]