← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_batch_int8.c
Go to the documentation of this file.
1 /**
2  * @file gemm_batch_int8.c
3  * @brief Batch GEMM kernels for quantized weights with INT8 activations
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Implements batch matrix multiplication where:
15  * - Activations (A): Q8_0 quantized (INT8 + scale)
16  * - Weights (B): Q5_0 or Q8_0 quantized
17  * - Output (C): FP32
18  *
19  * Operation: C[M,N] = A[M,K] @ B[N,K]^T (B is transposed/row-major weights)
20  *
21  * Instruction Set Implementations:
22  * - Scalar: Reference implementation for correctness verification
23  * - AVX: 256-bit SIMD (8 floats, or 32 int8s)
24  * - AVX-512: 512-bit SIMD (16 floats, or 64 int8s)
25  * - AMX: Intel Advanced Matrix Extensions (tile-based, requires Sapphire Rapids+)
26  *
27  * Design Philosophy:
28  * - Every kernel MUST produce bit-identical results to scalar reference
29  * - Comprehensive testing against llama.cpp ensures correctness
30  * - Performance optimizations never compromise accuracy
31  *
32  * @author C-Kernel-Engine Team
33  * @date 2024
34  */
35 
36 #include <stdint.h>
37 #include <stddef.h>
38 #include <string.h>
39 #include <math.h>
40 #include "ckernel_quant.h"
41 
42 /* SIMD headers */
43 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
44 #include <immintrin.h>
45 #endif
46 
47 /* AMX headers (requires specific compiler support) */
48 #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
49 #include <immintrin.h>
50 #define HAS_AMX 1
51 #else
52 #define HAS_AMX 0
53 #endif
54 
55 /* ============================================================================
56  * Constants and Block Sizes
57  * ============================================================================ */
58 
59 #define QK8_0 32 /* Q8_0: 32 weights per block */
60 #define QK5_0 32 /* Q5_0: 32 weights per block */
61 
62 /* AMX tile dimensions */
63 #define AMX_TILE_M 16
64 #define AMX_TILE_N 16
65 #define AMX_TILE_K 64
66 
67 /* ============================================================================
68  * SECTION 1: GEMM Q8_0 x Q8_0 -> FP32
69  *
70  * Both weights and activations are Q8_0 quantized.
71  * This is the simplest case - direct INT8 x INT8 -> INT32 accumulation.
72  * ============================================================================ */
73 
74 /**
75  * @brief Scalar reference: gemm_nt_q8_0_q8_0
76  *
77  * C[m,n] = sum_k( dequant(A[m,k]) * dequant(B[n,k]) )
78  * = sum_blocks( d_a * d_b * sum_j(a_qs[j] * b_qs[j]) )
79  *
80  * @param A Input activations [M, K] in Q8_0 format
81  * @param B Weight matrix [N, K] in Q8_0 format (row-major, each row is one output)
82  * @param C Output matrix [M, N] in FP32
83  * @param M Number of tokens (batch size)
84  * @param N Number of output features (rows in B)
85  * @param K Number of input features (must be multiple of 32)
86  */
88  const void *A,
89  const void *B,
90  float *C,
91  int M, int N, int K)
92 {
93  const int nb = K / QK8_0; /* Number of blocks per row */
94  const block_q8_0 *a_blocks = (const block_q8_0 *)A;
95  const block_q8_0 *b_blocks = (const block_q8_0 *)B;
96 
97  for (int m = 0; m < M; m++) {
98  const block_q8_0 *a_row = a_blocks + (size_t)m * nb;
99 
100  for (int n = 0; n < N; n++) {
101  const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
102  float sum = 0.0f;
103 
104  for (int ib = 0; ib < nb; ib++) {
105  const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
106  const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
107  const float d = d_a * d_b;
108 
109  int32_t sumi = 0;
110  for (int j = 0; j < QK8_0; j++) {
111  sumi += (int32_t)a_row[ib].qs[j] * (int32_t)b_row[ib].qs[j];
112  }
113 
114  sum += d * (float)sumi;
115  }
116 
117  C[(size_t)m * N + n] = sum;
118  }
119  }
120 }
121 
#if defined(__AVX2__)
/**
 * @brief AVX2 implementation: gemm_nt_q8_0_q8_0
 *
 * Processes one Q8_0 block (32 int8 element pairs) per inner iteration
 * using 256-bit vectors. Each 16-byte half is sign-extended to int16,
 * multiplied (an int8*int8 product fits in int16: |p| <= 16384), then
 * pair-summed into int32 lanes via madd against a vector of ones.
 * Integer arithmetic is identical to the scalar reference, so results
 * match it exactly; only the FP accumulation order of sum is shared too
 * (one float add per block).
 */
void gemm_nt_q8_0_q8_0_avx2(
    const void *A,
    const void *B,
    float *C,
    int M, int N, int K)
{
    const int nb = K / QK8_0; /* Q8_0 blocks per row */
    const block_q8_0 *a_blocks = (const block_q8_0 *)A;
    const block_q8_0 *b_blocks = (const block_q8_0 *)B;

    for (int m = 0; m < M; m++) {
        const block_q8_0 *a_row = a_blocks + (size_t)m * nb;

        for (int n = 0; n < N; n++) {
            const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
            float sum = 0.0f;

            for (int ib = 0; ib < nb; ib++) {
                /* Combined per-block scale (both stored as FP16) */
                const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
                const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
                const float d = d_a * d_b;

                /* Load 32 int8 values from A and B */
                __m256i va = _mm256_loadu_si256((const __m256i *)a_row[ib].qs);
                __m256i vb = _mm256_loadu_si256((const __m256i *)b_row[ib].qs);

                /* Split into 16-bit for multiplication without overflow
                 * Process low 16 bytes and high 16 bytes separately */
                __m128i va_lo = _mm256_castsi256_si128(va);
                __m128i va_hi = _mm256_extracti128_si256(va, 1);
                __m128i vb_lo = _mm256_castsi256_si128(vb);
                __m128i vb_hi = _mm256_extracti128_si256(vb, 1);

                /* Extend to 16-bit and multiply */
                __m256i va_lo_16 = _mm256_cvtepi8_epi16(va_lo);
                __m256i vb_lo_16 = _mm256_cvtepi8_epi16(vb_lo);
                __m256i va_hi_16 = _mm256_cvtepi8_epi16(va_hi);
                __m256i vb_hi_16 = _mm256_cvtepi8_epi16(vb_hi);

                __m256i prod_lo = _mm256_mullo_epi16(va_lo_16, vb_lo_16);
                __m256i prod_hi = _mm256_mullo_epi16(va_hi_16, vb_hi_16);

                /* madd against all-ones widens adjacent int16 pairs into
                 * int32 lanes: sum_32[i] = prod[2i] + prod[2i+1] */
                __m256i sum_lo = _mm256_madd_epi16(prod_lo, _mm256_set1_epi16(1));
                __m256i sum_hi = _mm256_madd_epi16(prod_hi, _mm256_set1_epi16(1));
                __m256i sum_32 = _mm256_add_epi32(sum_lo, sum_hi);

                /* Reduce 8 x int32 to a single int32: fold 256->128 bits,
                 * then shift-add the 8-byte and 4-byte halves */
                __m128i sum_128 = _mm_add_epi32(
                    _mm256_castsi256_si128(sum_32),
                    _mm256_extracti128_si256(sum_32, 1)
                );
                sum_128 = _mm_add_epi32(sum_128, _mm_srli_si128(sum_128, 8));
                sum_128 = _mm_add_epi32(sum_128, _mm_srli_si128(sum_128, 4));
                int32_t sumi = _mm_cvtsi128_si32(sum_128);

                sum += d * (float)sumi;
            }

            C[(size_t)m * N + n] = sum;
        }
    }
}
#endif /* __AVX2__ */
194 
#if defined(__AVX__) && !defined(__AVX2__)
/**
 * @brief AVX (SSE4.1) implementation: gemm_nt_q8_0_q8_0
 *
 * Processes each 32-element Q8_0 block in 4 chunks of 8 int8 values
 * using 128-bit SSE4.1 intrinsics (available on all AVX-capable CPUs,
 * Sandy Bridge+). Fills the gap between the AVX2 path and the scalar
 * fallback; integer math matches the scalar reference exactly.
 */
void gemm_nt_q8_0_q8_0_avx(
    const void *A,
    const void *B,
    float *C,
    int M, int N, int K)
{
    const int nb = K / QK8_0; /* Q8_0 blocks per row */
    const block_q8_0 *a_blocks = (const block_q8_0 *)A;
    const block_q8_0 *b_blocks = (const block_q8_0 *)B;

    for (int m = 0; m < M; m++) {
        const block_q8_0 *a_row = a_blocks + (size_t)m * nb;
        for (int n = 0; n < N; n++) {
            const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
            float sum = 0.0f;

            for (int ib = 0; ib < nb; ib++) {
                /* Combined per-block scale (both stored as FP16) */
                const float d = CK_FP16_TO_FP32(a_row[ib].d)
                              * CK_FP16_TO_FP32(b_row[ib].d);
                const int8_t *a_qs = a_row[ib].qs;
                const int8_t *b_qs = b_row[ib].qs;

                /* 4 chunks of 8 int8 values: 8-byte load, sign-extend to
                 * int16, madd pairs of products into 4 int32 lanes */
                __m128i d0 = _mm_madd_epi16(
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(a_qs + 0))),
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(b_qs + 0))));
                __m128i d1 = _mm_madd_epi16(
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(a_qs + 8))),
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(b_qs + 8))));
                __m128i d2 = _mm_madd_epi16(
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(a_qs + 16))),
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(b_qs + 16))));
                __m128i d3 = _mm_madd_epi16(
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(a_qs + 24))),
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(b_qs + 24))));

                /* Reduce 4x4 int32 lanes to single int32 via shift-adds */
                __m128i s4 = _mm_add_epi32(_mm_add_epi32(d0, d1),
                                           _mm_add_epi32(d2, d3));
                s4 = _mm_add_epi32(s4, _mm_srli_si128(s4, 8));
                s4 = _mm_add_epi32(s4, _mm_srli_si128(s4, 4));
                sum += d * (float)_mm_cvtsi128_si32(s4);
            }
            C[(size_t)m * N + n] = sum;
        }
    }
}
#endif /* __AVX__ && !__AVX2__ */
251 
252 #if defined(__AVX512F__)
/**
 * @brief AVX-512 implementation: gemm_nt_q8_0_q8_0
 *
 * Loads one full Q8_0 block (32 int8 pairs) as a 256-bit vector, widens
 * it to int16 in a single 512-bit register, multiplies, pair-sums into
 * int32 lanes with madd, and reduces with _mm512_reduce_add_epi32.
 * Integer arithmetic is identical to the scalar reference.
 * NOTE(review): the VNNI-named variant below uses this same sequence;
 * a true _mm512_dpbusd path is future work.
 */
void gemm_nt_q8_0_q8_0_avx512(
    const void *A,
    const void *B,
    float *C,
    int M, int N, int K)
{
    const int nb = K / QK8_0; /* Q8_0 blocks per row */
    const block_q8_0 *a_blocks = (const block_q8_0 *)A;
    const block_q8_0 *b_blocks = (const block_q8_0 *)B;

    for (int m = 0; m < M; m++) {
        const block_q8_0 *a_row = a_blocks + (size_t)m * nb;

        for (int n = 0; n < N; n++) {
            const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
            float sum = 0.0f;

            for (int ib = 0; ib < nb; ib++) {
                /* Combined per-block scale (both stored as FP16) */
                const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
                const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
                const float d = d_a * d_b;

                /* Load 32 int8 values - use 256-bit load, extend to 512-bit for processing */
                __m256i va_256 = _mm256_loadu_si256((const __m256i *)a_row[ib].qs);
                __m256i vb_256 = _mm256_loadu_si256((const __m256i *)b_row[ib].qs);

                /* Extend int8 to int16 for multiplication */
                __m512i va_16 = _mm512_cvtepi8_epi16(va_256);
                __m512i vb_16 = _mm512_cvtepi8_epi16(vb_256);

                /* Multiply 32 pairs of int16 -> int16 (no overflow for int8*int8) */
                __m512i prod = _mm512_mullo_epi16(va_16, vb_16);

                /* Sum adjacent pairs to int32: madd adds pairs of int16 products */
                __m512i sum_32 = _mm512_madd_epi16(prod, _mm512_set1_epi16(1));

                /* Reduce all 16 int32 lanes to single int32 */
                int32_t sumi = _mm512_reduce_add_epi32(sum_32);

                sum += d * (float)sumi;
            }

            C[(size_t)m * N + n] = sum;
        }
    }
}
305 
306 #if defined(__AVX512VNNI__)
/**
 * @brief AVX-512 "VNNI" dispatch slot: gemm_nt_q8_0_q8_0
 *
 * NOTE(review): this variant currently does NOT execute any VNNI
 * instruction. _mm512_dpbusd_epi32 computes unsigned*signed products,
 * and the signed*signed form (dpbssd) requires AVX512_VNNI_INT8, so the
 * body falls back to the same widen-to-int16 + madd sequence as the
 * plain AVX-512 kernel above (results are identical). It is kept as a
 * separate entry point so a true VNNI implementation can slot in
 * without touching the dispatcher.
 */
void gemm_nt_q8_0_q8_0_vnni(
    const void *A,
    const void *B,
    float *C,
    int M, int N, int K)
{
    const int nb = K / QK8_0; /* Q8_0 blocks per row */
    const block_q8_0 *a_blocks = (const block_q8_0 *)A;
    const block_q8_0 *b_blocks = (const block_q8_0 *)B;

    for (int m = 0; m < M; m++) {
        const block_q8_0 *a_row = a_blocks + (size_t)m * nb;

        for (int n = 0; n < N; n++) {
            const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
            float sum = 0.0f;

            for (int ib = 0; ib < nb; ib++) {
                /* Combined per-block scale (both stored as FP16) */
                const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
                const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
                const float d = d_a * d_b;

                /* Load 32 int8 values */
                __m256i va = _mm256_loadu_si256((const __m256i *)a_row[ib].qs);
                __m256i vb = _mm256_loadu_si256((const __m256i *)b_row[ib].qs);

                /* For signed*signed, use extend to 16-bit approach
                 * (dpbusd is unsigned*signed, dpbssd requires AVX512_VNNI_INT8) */
                __m512i va_16 = _mm512_cvtepi8_epi16(va);
                __m512i vb_16 = _mm512_cvtepi8_epi16(vb);
                __m512i prod = _mm512_mullo_epi16(va_16, vb_16);
                __m512i sum_32 = _mm512_madd_epi16(prod, _mm512_set1_epi16(1));
                int32_t sumi = _mm512_reduce_add_epi32(sum_32);

                sum += d * (float)sumi;
            }

            C[(size_t)m * N + n] = sum;
        }
    }
}
357 #endif /* __AVX512VNNI__ */
358 #endif /* __AVX512F__ */
359 
360 /**
361  * @brief Dispatcher for gemm_nt_q8_0_q8_0
362  *
363  * Selects the best available implementation at runtime.
364  */
/* Dispatcher is now gemm_nt_q8_0_q8_0 (with bias support) in Section 5 */
366 
367 
368 /* ============================================================================
369  * SECTION 2: GEMM Q5_0 x Q8_0 -> FP32
370  *
371  * Weights are Q5_0 (5-bit), activations are Q8_0 (8-bit).
372  * Q5_0 requires unpacking: 4 bits from qs[] + 1 bit from qh[].
373  * ============================================================================ */
374 
375 /**
376  * @brief Scalar reference: gemm_nt_q5_0_q8_0
377  *
378  * Q5_0 weight reconstruction:
379  * weight[j] = d * ((qs_nibble | (qh_bit << 4)) - 16)
380  *
381  * For j in 0..15: use low nibble + qh bit j
 * For j in 16..31: use high nibble of qs[j-16] + qh bit j
383  *
384  * @param A Input activations [M, K] in Q8_0 format
385  * @param B Weight matrix [N, K] in Q5_0 format
386  * @param C Output matrix [M, N] in FP32
387  * @param M Number of tokens (batch size)
388  * @param N Number of output features
389  * @param K Number of input features (must be multiple of 32)
390  */
392  const void *A,
393  const void *B,
394  float *C,
395  int M, int N, int K)
396 {
397  const int nb = K / QK5_0;
398  const block_q8_0 *a_blocks = (const block_q8_0 *)A;
399  const block_q5_0 *b_blocks = (const block_q5_0 *)B;
400 
401  for (int m = 0; m < M; m++) {
402  const block_q8_0 *a_row = a_blocks + (size_t)m * nb;
403 
404  for (int n = 0; n < N; n++) {
405  const block_q5_0 *b_row = b_blocks + (size_t)n * nb;
406  float sum = 0.0f;
407 
408  for (int ib = 0; ib < nb; ib++) {
409  const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
410  const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
411  const float d = d_a * d_b;
412 
413  /* Load high bits as 32-bit value */
414  uint32_t qh;
415  memcpy(&qh, b_row[ib].qh, sizeof(qh));
416 
417  int32_t sumi = 0;
418 
419  /* Process 32 weights: j=0..15 uses low nibble, j=16..31 uses high nibble */
420  for (int j = 0; j < 16; j++) {
421  /* First 16 weights: low nibble + qh bit j */
422  const uint8_t xh_0 = ((qh >> j) & 1) << 4;
423  const int8_t w0 = (int8_t)(((b_row[ib].qs[j] & 0x0F) | xh_0) - 16);
424 
425  /* Second 16 weights: high nibble + qh bit (j+16) */
426  const uint8_t xh_1 = ((qh >> (j + 16)) & 1) << 4;
427  const int8_t w1 = (int8_t)(((b_row[ib].qs[j] >> 4) | xh_1) - 16);
428 
429  /* Accumulate with activation values */
430  sumi += (int32_t)w0 * (int32_t)a_row[ib].qs[j];
431  sumi += (int32_t)w1 * (int32_t)a_row[ib].qs[j + 16];
432  }
433 
434  sum += d * (float)sumi;
435  }
436 
437  C[(size_t)m * N + n] = sum;
438  }
439  }
440 }
441 
442 /* ============================================================================
443  * SECTION 3: AMX Implementation (Intel Advanced Matrix Extensions)
444  *
445  * AMX uses tile registers (TMM0-TMM7) for matrix operations.
446  * Each tile can hold up to 16 rows x 64 bytes (1KB).
447  *
448  * Key operations:
449  * _tile_loadd: Load tile from memory
450  * _tile_dpbssd: Signed int8 dot product accumulate (A signed, B signed)
451  * _tile_stored: Store tile to memory
452  *
453  * Requirements:
454  * - Sapphire Rapids or later CPU
455  * - __AMX_INT8__ defined
456  * - OS support (XSAVE/XRSTOR for tiles)
457  * ============================================================================ */
458 
459 #if HAS_AMX
460 
/* AMX tile configuration passed to LDTILECFG (_tile_loadconfig).
 *
 * The hardware mandates a fixed 64-byte layout (Intel SDM, LDTILECFG):
 *   byte  0      palette_id
 *   byte  1      start_row
 *   bytes 2-15   reserved (must be zero)
 *   bytes 16-47  colsb[16]  - bytes per row for each tile
 *   bytes 48-63  rows[16]   - row count for each tile
 *
 * BUG FIX: the previous definition used colsb[8]/rows[8], which made the
 * struct 40 bytes: rows[] landed at byte offset 32 (inside the colsb
 * region the hardware reads) and _tile_loadconfig read 24 bytes past the
 * object. The arrays must be 16 entries each even though only 8 tile
 * registers (TMM0-TMM7) exist; the extra entries stay zero.
 */
typedef struct {
    uint8_t palette_id;
    uint8_t start_row;
    uint8_t reserved[14];
    uint16_t colsb[16];
    uint8_t rows[16];
} tile_config_t;
469 
static void amx_tile_config_init(void)
{
    /* Tile configuration is per-logical-CPU architectural state, so the
     * guard flag is thread-local: each thread must issue its own
     * LDTILECFG before using tile instructions. */
    static __thread int initialized = 0;
    if (initialized) return;

    tile_config_t tc = {0}; /* reserved bytes and unused tiles must be zero */
    tc.palette_id = 1;      /* palette 1: 8 tiles, up to 16 rows x 64 bytes each */

    /* Configure tiles for our GEMM pattern:
     * TMM0: accumulator C (16 rows x 16 cols of int32)
     * TMM1: A tile (16 rows x 64 bytes = 64 int8 per row)
     * TMM2: B tile (16 rows x 64 bytes)
     * colsb[] is the row width in BYTES (16 int32 = 64 bytes).
     */
    tc.rows[0] = 16; tc.colsb[0] = 64; /* TMM0: 16x16 int32 */
    tc.rows[1] = 16; tc.colsb[1] = 64; /* TMM1: A */
    tc.rows[2] = 16; tc.colsb[2] = 64; /* TMM2: B */
    tc.rows[3] = 16; tc.colsb[3] = 64; /* TMM3: spare */
    tc.rows[4] = 16; tc.colsb[4] = 64; /* TMM4: spare */
    tc.rows[5] = 16; tc.colsb[5] = 64; /* TMM5: spare */
    tc.rows[6] = 16; tc.colsb[6] = 64; /* TMM6: spare */
    tc.rows[7] = 16; tc.colsb[7] = 64; /* TMM7: spare */

    _tile_loadconfig(&tc);
    initialized = 1;
}
495 
/**
 * @brief AMX dispatch slot: gemm_nt_q8_0_q8_0
 *
 * Loads the per-thread tile configuration, then delegates the actual
 * computation to the AVX-512 kernel. A real tile-based path (data
 * repacking, 16x16 blocking, post-tile scale handling) is still TODO;
 * this stub exists so the dispatcher and tile setup are already wired
 * up when that lands. Results are therefore identical to the AVX-512
 * implementation.
 */
void gemm_nt_q8_0_q8_0_amx(
    const void *A,
    const void *B,
    float *C,
    int M, int N, int K)
{
    /* Ensure LDTILECFG has run on this thread before any tile insn. */
    amx_tile_config_init();

    /* Temporary fallback until the full AMX path exists:
     *  - repack A/B into tile-friendly layout
     *  - block the problem into 16x16 tiles (_tile_dpbssd)
     *  - apply Q8_0 scales after the integer tile products
     */
    gemm_nt_q8_0_q8_0_avx512(A, B, C, M, N, K);
}
523 
void gemm_nt_q5_0_q8_0_amx(
    const void *A,
    const void *B,
    float *C,
    int M, int N, int K)
{
    /* Ensure LDTILECFG has run on this thread before any tile insn. */
    amx_tile_config_init();

    /* Q5_0 weights must be unpacked to int8 before AMX tiles can
     * consume them (unpack -> tile GEMM -> apply scales). Until that
     * pipeline exists, delegate to the scalar reference so results are
     * identical to gemm_nt_q5_0_q8_0_ref. */
    gemm_nt_q5_0_q8_0_ref(A, B, C, M, N, K);
}
542 
543 #endif /* HAS_AMX */
544 
545 
546 /* ============================================================================
547  * SECTION 4: API Functions with Full Dispatch
548  * ============================================================================ */
549 
/**
 * @brief Get the best implementation name for logging/debugging
 *
 * Resolved at compile time from the same feature macros the dispatcher
 * uses, so the string always matches the selected dispatch slot.
 * NOTE(review): "AMX" and "AVX-512 VNNI" name the dispatch slot, not
 * the instructions actually executed - both currently fall back to the
 * plain AVX-512 / scalar sequences (see Sections 1 and 3).
 *
 * @return Static string naming the compiled-in fastest kernel family.
 */
const char* gemm_batch_int8_impl_name(void)
{
#if HAS_AMX
    return "AMX";
#elif defined(__AVX512VNNI__)
    return "AVX-512 VNNI";
#elif defined(__AVX512F__)
    return "AVX-512";
#elif defined(__AVX2__)
    return "AVX2";
#elif defined(__AVX__)
    return "AVX";
#else
    return "Scalar";
#endif
}
569 
570 
571 /* ============================================================================
572  * SECTION 5: API Wrappers with Bias Support
573  *
574  * These match the existing API signature in ckernel_quant.h
575  * ============================================================================ */
576 
577 /**
578  * @brief gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
579  *
580  * C[m,n] = A[m,K] @ B[n,K]^T + bias[n]
581  */
void gemm_nt_q8_0_q8_0(
    const void *A,
    const void *B,
    const float *bias,
    float *C,
    int M, int N, int K)
{
    /* Public API entry point: dispatch to the best compiled-in kernel,
     * then apply the optional bias. (Signature line restored; it was
     * dropped in the extracted listing - see ckernel_quant.h.) */

    /* First compute GEMM - compile-time dispatch, best ISA first */
#if defined(__AVX512VNNI__)
    gemm_nt_q8_0_q8_0_vnni(A, B, C, M, N, K);
#elif defined(__AVX512F__)
    gemm_nt_q8_0_q8_0_avx512(A, B, C, M, N, K);
#elif defined(__AVX2__)
    gemm_nt_q8_0_q8_0_avx2(A, B, C, M, N, K);
#elif defined(__AVX__)
    gemm_nt_q8_0_q8_0_avx(A, B, C, M, N, K);
#else
    gemm_nt_q8_0_q8_0_ref(A, B, C, M, N, K);
#endif

    /* Add bias if provided: bias[n] is broadcast down each column */
    if (bias != NULL) {
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < N; n++) {
                C[(size_t)m * N + n] += bias[n];
            }
        }
    }
}
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
#define QK5_0
void gemm_nt_q5_0_q8_0_ref(const void *A, const void *B, float *C, int M, int N, int K)
Dispatcher for gemm_nt_q8_0_q8_0.
void gemm_nt_q8_0_q8_0_ref(const void *A, const void *B, float *C, int M, int N, int K)
Scalar reference: gemm_nt_q8_0_q8_0.
const char * gemm_batch_int8_impl_name(void)
Get the best implementation name for logging/debugging.
void gemm_nt_q8_0_q8_0(const void *A, const void *B, const float *bias, float *C, int M, int N, int K)
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
#define QK8_0
#define C(color)
Definition: show_config.c:39
int8_t qs[32]