← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q4k_avx.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_q4k_avx.c
3  * @brief AVX Q4_K x Q8_K matvec kernel for Sandy/Ivy Bridge
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Uses _mm_maddubs_epi16 (SSSE3) for efficient u8*s8 multiply-add while
15  * maintaining our scale format from unpack_q4_k_scales.
16  *
17  * Key improvement over SSE: _mm_maddubs_epi16 processes 16 pairs per
18  * instruction vs SSE's _mm_cvtepu8_epi16 + _mm_madd_epi16 (8 pairs).
19  */
20 
21 #include <stddef.h>
22 #include <stdint.h>
23 #include <string.h>
24 
25 #include "ckernel_quant.h"
26 
27 #if defined(__AVX__)
28 #include <immintrin.h>
29 
/* Horizontal sum of the four int32 lanes of v, using SSE2 only. */
static inline int32_t hsum_epi32_sse(__m128i v) {
    /* Fold the upper 64 bits onto the lower: lane0 += lane2, lane1 += lane3. */
    __m128i folded = _mm_add_epi32(v, _mm_unpackhi_epi64(v, v));
    /* Fold lane1 into lane0; upper lanes become don't-cares. */
    folded = _mm_add_epi32(folded, _mm_shuffle_epi32(folded, _MM_SHUFFLE(0, 0, 0, 1)));
    return _mm_cvtsi128_si32(folded);
}
37 
/**
 * @brief Matrix-vector product y = W * x with Q4_K weights and a
 *        Q8_K-quantized input vector (single-threaded SSSE3/AVX path).
 *
 * @param y    [out] M floats; y[row] receives dot(W_row, x).
 * @param W    [in]  M rows of (K/QK_K) block_q4_K super-blocks, rows contiguous.
 * @param x_q8 [in]  (K/QK_K) block_q8_K super-blocks (the quantized input).
 * @param M    number of output rows.
 * @param K    number of columns; assumed a multiple of QK_K (256) — any
 *             remainder is silently dropped by the integer division below.
 *
 * Pure computation: no allocation, writes only y[0..M-1].
 */
void gemv_q4_k_q8_k_avx(float *y,
                        const void *W,
                        const void *x_q8,
                        int M, int K)
{
    const block_q4_K *blocks = (const block_q4_K *)W;
    const block_q8_K *bx = (const block_q8_K *)x_q8;
    const int nb = K / QK_K;  /* super-blocks per row */

    const __m128i mask_low = _mm_set1_epi8(0x0F);  /* low-nibble selector */
    const __m128i ones = _mm_set1_epi16(1);        /* for int16 -> int32 widening via madd */

    for (int row = 0; row < M; ++row) {
        const block_q4_K *x = blocks + row * nb;
        float sumf = 0.0f;

        for (int i = 0; i < nb; ++i) {
            /* Prefetch next block of both operands to hide memory latency. */
            if (i + 1 < nb) {
                _mm_prefetch((const char *)&x[i + 1], _MM_HINT_T0);
                _mm_prefetch((const char *)&bx[i + 1], _MM_HINT_T0);
            }

            /* Decode the 8 per-sub-block (scale, min) pairs from the packed
             * scales field into byte arrays. */
            uint8_t sc[8], m_val[8];
            unpack_q4_k_scales(x[i].scales, sc, m_val);

            /* Combined dequant factors: weight super-block scale/min times
             * the input block's scale. */
            const float d = CK_FP16_TO_FP32(x[i].d) * bx[i].d;
            const float dmin = CK_FP16_TO_FP32(x[i].dmin) * bx[i].d;

            const uint8_t *q4 = x[i].qs;
            const int8_t *q8 = bx[i].qs;

            int is = 0;  /* sub-block scale index; advances by 2 per group */

            /* Process 4 groups of 64 elements (QK_K = 256). Each group is
             * 32 packed bytes: low nibbles = elements j..j+31 (scale sc[is]),
             * high nibbles = elements j+32..j+63 (scale sc[is+1]). */
            for (int j = 0; j < QK_K; j += 64) {
                __m128i acc_lo = _mm_setzero_si128();
                __m128i acc_hi = _mm_setzero_si128();

                /* Process 32 bytes of Q4 (64 elements via nibbles), 16 at a time. */
                for (int l = 0; l < 32; l += 16) {
                    /* Load 16 bytes of Q4 */
                    __m128i q4_vec = _mm_loadu_si128((const __m128i *)(q4 + l));

                    /* Extract low and high nibbles. srli_epi16 shifts across
                     * byte boundaries, so the high half must be re-masked. */
                    __m128i q4_lo = _mm_and_si128(q4_vec, mask_low);
                    __m128i q4_hi = _mm_and_si128(_mm_srli_epi16(q4_vec, 4), mask_low);

                    /* Load Q8 values - low nibbles correspond to j+l, high to j+32+l */
                    __m128i q8_lo_vec = _mm_loadu_si128((const __m128i *)(q8 + j + l));
                    __m128i q8_hi_vec = _mm_loadu_si128((const __m128i *)(q8 + j + 32 + l));

                    /* _mm_maddubs_epi16: unsigned*signed multiply-add.
                     * Multiplies 16 byte pairs, returns 8 int16 pair-sums:
                     * Result[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1].
                     * Saturation cannot trigger here: |pair sum| <= 2*15*127
                     * = 3810, well inside int16 range. */
                    __m128i prod_lo = _mm_maddubs_epi16(q4_lo, q8_lo_vec);
                    __m128i prod_hi = _mm_maddubs_epi16(q4_hi, q8_hi_vec);

                    /* Widen int16 -> int32 by madd against 1 and accumulate. */
                    acc_lo = _mm_add_epi32(acc_lo, _mm_madd_epi16(prod_lo, ones));
                    acc_hi = _mm_add_epi32(acc_hi, _mm_madd_epi16(prod_hi, ones));
                }

                int32_t sum_q4q8_lo = hsum_epi32_sse(acc_lo);
                int32_t sum_q4q8_hi = hsum_epi32_sse(acc_hi);

                /* bsums holds per-16-element sums of q8, so each 32-element
                 * sub-block spans two consecutive bsums entries. */
                int32_t bsum_lo = (int32_t)bx[i].bsums[j / 16] +
                                  (int32_t)bx[i].bsums[j / 16 + 1];
                int32_t bsum_hi = (int32_t)bx[i].bsums[(j + 32) / 16] +
                                  (int32_t)bx[i].bsums[(j + 32) / 16 + 1];

                /* Q4_K stores offset values, so the dot product reconstructs
                 * as d*sc*sum(q4*q8) minus the min correction dmin*m*sum(q8). */
                sumf += d * (float)sc[is] * (float)sum_q4q8_lo;
                sumf -= dmin * (float)m_val[is] * (float)bsum_lo;
                sumf += d * (float)sc[is + 1] * (float)sum_q4q8_hi;
                sumf -= dmin * (float)m_val[is + 1] * (float)bsum_hi;

                q4 += 32;
                is += 2;
            }
        }
        y[row] = sumf;
    }
}
123 
124 /* ============================================================================
125  * PARALLEL SIMD VERSION
126  *
127  * Combines AVX SIMD with parallel row splitting for maximum throughput.
128  * OpenMP lives in orchestration layer - this kernel receives ith/nth.
129  *
130  * Prefetch strategy:
131  * - Prefetch 2-4 rows ahead (hide memory latency ~50-70ns)
132  * - Each row = (K/256) * 144 bytes = ~576 bytes for K=1024
133  * - Computation per row ~ 100ns, so prefetch 1-2 rows ahead
134  * ============================================================================ */
135 
/**
 * @brief Thread-sliced variant of gemv_q4_k_q8_k_avx.
 *
 * Computes y[row] = dot(W_row, x) for the contiguous row range assigned to
 * thread ith of nth. Threads get ceil(M/nth) rows each; the last slice may
 * be short. Caller (orchestration layer) owns the threading — this kernel
 * only receives its slice indices.
 *
 * @param y    [out] M floats; only rows [r0, r1) of this thread are written.
 * @param W    [in]  M rows of (K/QK_K) block_q4_K super-blocks.
 * @param x_q8 [in]  (K/QK_K) block_q8_K super-blocks.
 * @param M    number of rows.
 * @param K    number of columns; assumed a multiple of QK_K (256).
 * @param ith  this thread's index, 0 <= ith < nth.
 * @param nth  total thread count, > 0.
 *
 * Returns silently (writes nothing) on NULL pointers, non-positive sizes,
 * or an out-of-range thread index.
 */
void gemv_q4_k_q8_k_parallel_simd(float *y,
                                  const void *W,
                                  const void *x_q8,
                                  int M, int K,
                                  int ith, int nth)
{
    if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
    if (ith < 0 || nth <= 0 || ith >= nth) return;

    /* Compute row range [r0, r1) for this thread: dr = ceil(M / nth). */
    const int dr = (M + nth - 1) / nth;
    const int r0 = dr * ith;
    const int r1 = (r0 + dr < M) ? (r0 + dr) : M;

    if (r0 >= M) return;  /* trailing threads may have no rows */

    const block_q4_K *blocks = (const block_q4_K *)W;
    const block_q8_K *bx = (const block_q8_K *)x_q8;
    const int nb = K / QK_K;  /* super-blocks per row */
    /* NOTE(review): computed but unused in this function — presumably kept
     * for symmetry with prefetch sizing; confirm before removing. */
    const size_t bytes_per_row = (size_t)nb * sizeof(block_q4_K);

    const __m128i mask_low = _mm_set1_epi8(0x0F);  /* low-nibble selector */
    const __m128i ones = _mm_set1_epi16(1);        /* for int16 -> int32 madd widening */

    /* Warm the cache with the first few rows of this thread's slice. */
    const int PREFETCH_ROWS = 4;
    for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
        const char *row_ptr = (const char *)(blocks + (r0 + p) * nb);
        _mm_prefetch(row_ptr, _MM_HINT_T0);
        _mm_prefetch(row_ptr + 64, _MM_HINT_T0); /* Second cache line */
    }

    for (int row = r0; row < r1; ++row) {
        /* Prefetch rows ahead to hide memory latency */
        if (row + PREFETCH_ROWS < r1) {
            const char *prefetch_ptr = (const char *)(blocks + (row + PREFETCH_ROWS) * nb);
            _mm_prefetch(prefetch_ptr, _MM_HINT_T0);
            _mm_prefetch(prefetch_ptr + 64, _MM_HINT_T0);
            _mm_prefetch(prefetch_ptr + 128, _MM_HINT_T0);
        }

        const block_q4_K *x = blocks + row * nb;
        float sumf = 0.0f;

        /* Inner loop is identical to gemv_q4_k_q8_k_avx: see that function
         * for the detailed nibble/scale/bsums arithmetic. */
        for (int i = 0; i < nb; ++i) {
            /* Prefetch next block within row */
            if (i + 1 < nb) {
                _mm_prefetch((const char *)&x[i + 1], _MM_HINT_T0);
            }

            /* Decode the 8 per-sub-block (scale, min) pairs. */
            uint8_t sc[8], m_val[8];
            unpack_q4_k_scales(x[i].scales, sc, m_val);

            /* Combined dequant factors (weight scale/min x input scale). */
            const float d = CK_FP16_TO_FP32(x[i].d) * bx[i].d;
            const float dmin = CK_FP16_TO_FP32(x[i].dmin) * bx[i].d;

            const uint8_t *q4 = x[i].qs;
            const int8_t *q8 = bx[i].qs;

            int is = 0;  /* sub-block scale index; +2 per 64-element group */

            /* Process 4 groups of 64 elements */
            for (int j = 0; j < QK_K; j += 64) {
                __m128i acc_lo = _mm_setzero_si128();
                __m128i acc_hi = _mm_setzero_si128();

                /* Process 32 bytes of Q4 (64 elements via nibbles) */
                for (int l = 0; l < 32; l += 16) {
                    /* Load 16 bytes of Q4 */
                    __m128i q4_vec = _mm_loadu_si128((const __m128i *)(q4 + l));

                    /* Extract low and high nibbles (re-mask after the
                     * cross-byte srli_epi16 shift). */
                    __m128i q4_lo = _mm_and_si128(q4_vec, mask_low);
                    __m128i q4_hi = _mm_and_si128(_mm_srli_epi16(q4_vec, 4), mask_low);

                    /* Load Q8 values: low nibbles pair with q8[j+l..],
                     * high nibbles with q8[j+32+l..]. */
                    __m128i q8_lo_vec = _mm_loadu_si128((const __m128i *)(q8 + j + l));
                    __m128i q8_hi_vec = _mm_loadu_si128((const __m128i *)(q8 + j + 32 + l));

                    /* _mm_maddubs_epi16: unsigned*signed multiply-add;
                     * pair sums bounded by 2*15*127 = 3810, no saturation. */
                    __m128i prod_lo = _mm_maddubs_epi16(q4_lo, q8_lo_vec);
                    __m128i prod_hi = _mm_maddubs_epi16(q4_hi, q8_hi_vec);

                    /* Widen to int32 and accumulate. */
                    acc_lo = _mm_add_epi32(acc_lo, _mm_madd_epi16(prod_lo, ones));
                    acc_hi = _mm_add_epi32(acc_hi, _mm_madd_epi16(prod_hi, ones));
                }

                int32_t sum_q4q8_lo = hsum_epi32_sse(acc_lo);
                int32_t sum_q4q8_hi = hsum_epi32_sse(acc_hi);

                /* bsums: per-16-element q8 sums, two entries per sub-block. */
                int32_t bsum_lo = (int32_t)bx[i].bsums[j / 16] +
                                  (int32_t)bx[i].bsums[j / 16 + 1];
                int32_t bsum_hi = (int32_t)bx[i].bsums[(j + 32) / 16] +
                                  (int32_t)bx[i].bsums[(j + 32) / 16 + 1];

                /* d*sc*sum(q4*q8) minus the min correction dmin*m*sum(q8). */
                sumf += d * (float)sc[is] * (float)sum_q4q8_lo;
                sumf -= dmin * (float)m_val[is] * (float)bsum_lo;
                sumf += d * (float)sc[is + 1] * (float)sum_q4q8_hi;
                sumf -= dmin * (float)m_val[is + 1] * (float)bsum_hi;

                q4 += 32;
                is += 2;
            }
        }
        y[row] = sumf;
    }
}
246 
247 #else
248 /* Fallback for non-AVX builds */
249 void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K);
250 
251 void gemv_q4_k_q8_k_avx(float *y,
252  const void *W,
253  const void *x_q8,
254  int M, int K)
255 {
256  gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
257 }
258 
259 /* Parallel fallback when no AVX */
260 void gemv_q4_k_q8_k_parallel(float *y, const void *W, const void *x_q8,
261  int M, int K, int ith, int nth);
262 
264  const void *W,
265  const void *x_q8,
266  int M, int K,
267  int ith, int nth)
268 {
269  /* Fall back to reference parallel version */
270  gemv_q4_k_q8_k_parallel(y, W, x_q8, M, K, ith, nth);
271 }
272 #endif
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
static void unpack_q4_k_scales(const uint8_t *scales, uint8_t *sc, uint8_t *m)
Unpack Q4_K sub-block scales and mins.
#define QK_K
void gemv_q4_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
static int32_t hsum_epi32_sse(__m128i v)
uint8_t qs[256/2]
int8_t qs[256]
int16_t bsums[256/16]