← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_batch_int8.c File Reference

Batch GEMM kernels for quantized weights with INT8 activations. More...

#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include <math.h>
#include "ckernel_quant.h"

Go to the source code of this file.

Macros

#define AMX_TILE_K   64
 
#define AMX_TILE_M   16
 
#define AMX_TILE_N   16
 
#define HAS_AMX   0
 
#define QK5_0   32 /* Q5_0: 32 weights per block */
 
#define QK8_0   32 /* Q8_0: 32 weights per block */
 

Functions

const char * gemm_batch_int8_impl_name (void)
 Get the best implementation name for logging/debugging. More...
 
void gemm_nt_q5_0_q8_0_ref (const void *A, const void *B, float *C, int M, int N, int K)
 Scalar reference: gemm_nt_q5_0_q8_0. More...
 
void gemm_nt_q8_0_q8_0 (const void *A, const void *B, const float *bias, float *C, int M, int N, int K)
 gemm_nt_q8_0_q8_0 with optional bias (matches header signature) More...
 
void gemm_nt_q8_0_q8_0_ref (const void *A, const void *B, float *C, int M, int N, int K)
 Scalar reference: gemm_nt_q8_0_q8_0. More...
 

Detailed Description

Batch GEMM kernels for quantized weights with INT8 activations.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Implements batch matrix multiplication where:

  • Activations (A): Q8_0 quantized (INT8 + scale)
  • Weights (B): Q5_0 or Q8_0 quantized
  • Output (C): FP32

Operation: C[M,N] = A[M,K] @ B[N,K]^T (B is transposed/row-major weights)

Instruction Set Implementations:

  • Scalar: Reference implementation for correctness verification
  • AVX: 256-bit SIMD (8 floats, or 32 int8s)
  • AVX-512: 512-bit SIMD (16 floats, or 64 int8s)
  • AMX: Intel Advanced Matrix Extensions (tile-based, requires Sapphire Rapids+)

Design Philosophy:

  • Every kernel MUST produce bit-identical results to scalar reference
  • Comprehensive testing against llama.cpp ensures correctness
  • Performance optimizations never compromise accuracy
Author
C-Kernel-Engine Team
Date
2024

Definition in file gemm_batch_int8.c.

Macro Definition Documentation

◆ AMX_TILE_K

#define AMX_TILE_K   64

Definition at line 65 of file gemm_batch_int8.c.

◆ AMX_TILE_M

#define AMX_TILE_M   16

Definition at line 63 of file gemm_batch_int8.c.

◆ AMX_TILE_N

#define AMX_TILE_N   16

Definition at line 64 of file gemm_batch_int8.c.

◆ HAS_AMX

#define HAS_AMX   0

Definition at line 52 of file gemm_batch_int8.c.

◆ QK5_0

#define QK5_0   32 /* Q5_0: 32 weights per block */

Definition at line 60 of file gemm_batch_int8.c.

◆ QK8_0

#define QK8_0   32 /* Q8_0: 32 weights per block */

Definition at line 59 of file gemm_batch_int8.c.

Function Documentation

◆ gemm_batch_int8_impl_name()

const char* gemm_batch_int8_impl_name ( void  )

Get the best implementation name for logging/debugging.

Definition at line 553 of file gemm_batch_int8.c.

/**
 * Report which kernel implementation this build dispatches to.
 *
 * The selection mirrors the compile-time macro checks used by the GEMM
 * dispatchers, so the returned string names the kernel that will actually
 * run. Intended for logging/debugging only.
 *
 * @return Pointer to a static string literal; never NULL.
 */
const char *gemm_batch_int8_impl_name(void)
{
#if HAS_AMX
    return "AMX";
#elif defined(__AVX512VNNI__)
    return "AVX-512 VNNI";
#elif defined(__AVX512F__)
    return "AVX-512";
#elif defined(__AVX2__)
    return "AVX2";
#elif defined(__AVX__)
    return "AVX";
#else
    return "Scalar";
#endif
}

◆ gemm_nt_q5_0_q8_0_ref()

void gemm_nt_q5_0_q8_0_ref ( const void *  A,
const void *  B,
float *  C,
int  M,
int  N,
int  K 
)

Scalar reference: gemm_nt_q5_0_q8_0.

Non-SIMD baseline implementation used for correctness verification of the optimized kernels.

Scalar reference: gemm_nt_q5_0_q8_0

Q5_0 weight reconstruction: weight[j] = d * ((qs_nibble | (qh_bit << 4)) - 16)

For each byte j in 0..15, the low nibble yields weight j (fifth bit = qh bit j) and the high nibble yields weight j+16 (fifth bit = qh bit j+16).

Parameters
AInput activations [M, K] in Q8_0 format
BWeight matrix [N, K] in Q5_0 format
COutput matrix [M, N] in FP32
MNumber of tokens (batch size)
NNumber of output features
KNumber of input features (must be multiple of 32)

Definition at line 391 of file gemm_batch_int8.c.

396 {
397  const int nb = K / QK5_0;
398  const block_q8_0 *a_blocks = (const block_q8_0 *)A;
399  const block_q5_0 *b_blocks = (const block_q5_0 *)B;
400 
401  for (int m = 0; m < M; m++) {
402  const block_q8_0 *a_row = a_blocks + (size_t)m * nb;
403 
404  for (int n = 0; n < N; n++) {
405  const block_q5_0 *b_row = b_blocks + (size_t)n * nb;
406  float sum = 0.0f;
407 
408  for (int ib = 0; ib < nb; ib++) {
409  const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
410  const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
411  const float d = d_a * d_b;
412 
413  /* Load high bits as 32-bit value */
414  uint32_t qh;
415  memcpy(&qh, b_row[ib].qh, sizeof(qh));
416 
417  int32_t sumi = 0;
418 
419  /* Process 32 weights: j=0..15 uses low nibble, j=16..31 uses high nibble */
420  for (int j = 0; j < 16; j++) {
421  /* First 16 weights: low nibble + qh bit j */
422  const uint8_t xh_0 = ((qh >> j) & 1) << 4;
423  const int8_t w0 = (int8_t)(((b_row[ib].qs[j] & 0x0F) | xh_0) - 16);
424 
425  /* Second 16 weights: high nibble + qh bit (j+16) */
426  const uint8_t xh_1 = ((qh >> (j + 16)) & 1) << 4;
427  const int8_t w1 = (int8_t)(((b_row[ib].qs[j] >> 4) | xh_1) - 16);
428 
429  /* Accumulate with activation values */
430  sumi += (int32_t)w0 * (int32_t)a_row[ib].qs[j];
431  sumi += (int32_t)w1 * (int32_t)a_row[ib].qs[j + 16];
432  }
433 
434  sum += d * (float)sumi;
435  }
436 
437  C[(size_t)m * N + n] = sum;
438  }
439  }
440 }
#define CK_FP16_TO_FP32(x)
#define QK5_0
#define C(color)
Definition: show_config.c:39
int8_t qs[32]

References C, CK_FP16_TO_FP32, QK5_0, and block_q8_0::qs.

◆ gemm_nt_q8_0_q8_0()

void gemm_nt_q8_0_q8_0 ( const void *  A,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

gemm_nt_q8_0_q8_0 with optional bias (matches header signature)

C[m,n] = A[m,K] @ B[n,K]^T + bias[n]

Definition at line 582 of file gemm_batch_int8.c.

/**
 * Q8_0 x Q8_0 batch GEMM with optional bias (matches header signature).
 *
 * C[m,n] = A[m,:] . B[n,:] + bias[n]
 *
 * @param A    Activations [M, K] in Q8_0 format
 * @param B    Weights [N, K] in Q8_0 format
 * @param bias Optional per-feature bias [N]; may be NULL
 * @param C    Output [M, N] in FP32
 */
void gemm_nt_q8_0_q8_0(const void *A, const void *B, const float *bias,
                       float *C, int M, int N, int K)
{
    /* Dispatch to the widest instruction set available in this build. */
#if defined(__AVX512VNNI__)
    gemm_nt_q8_0_q8_0_vnni(A, B, C, M, N, K);
#elif defined(__AVX512F__)
    gemm_nt_q8_0_q8_0_avx512(A, B, C, M, N, K);
#elif defined(__AVX2__)
    gemm_nt_q8_0_q8_0_avx2(A, B, C, M, N, K);
#elif defined(__AVX__)
    gemm_nt_q8_0_q8_0_avx(A, B, C, M, N, K);
#else
    gemm_nt_q8_0_q8_0_ref(A, B, C, M, N, K);
#endif

    /* Broadcast-add the per-feature bias across all M output rows. */
    if (bias != NULL) {
        for (int m = 0; m < M; m++) {
            float *c_row = C + (size_t)m * N;
            for (int n = 0; n < N; n++) {
                c_row[n] += bias[n];
            }
        }
    }
}
void gemm_nt_q8_0_q8_0_ref(const void *A, const void *B, float *C, int M, int N, int K)
Scalar reference: gemm_nt_q8_0_q8_0.

References C, and gemm_nt_q8_0_q8_0_ref().

Referenced by ck_test_gemm_q8_0(), gemm_nt_q8_0_dispatch(), and gemm_nt_q8_0_mlp_dispatch().

◆ gemm_nt_q8_0_q8_0_ref()

void gemm_nt_q8_0_q8_0_ref ( const void *  A,
const void *  B,
float *  C,
int  M,
int  N,
int  K 
)

Scalar reference: gemm_nt_q8_0_q8_0.

C[m,n] = sum_k( dequant(A[m,k]) * dequant(B[n,k]) ) = sum_blocks( d_a * d_b * sum_j(a_qs[j] * b_qs[j]) )

Parameters
AInput activations [M, K] in Q8_0 format
BWeight matrix [N, K] in Q8_0 format (row-major, each row is one output)
COutput matrix [M, N] in FP32
MNumber of tokens (batch size)
NNumber of output features (rows in B)
KNumber of input features (must be multiple of 32)

Definition at line 87 of file gemm_batch_int8.c.

92 {
93  const int nb = K / QK8_0; /* Number of blocks per row */
94  const block_q8_0 *a_blocks = (const block_q8_0 *)A;
95  const block_q8_0 *b_blocks = (const block_q8_0 *)B;
96 
97  for (int m = 0; m < M; m++) {
98  const block_q8_0 *a_row = a_blocks + (size_t)m * nb;
99 
100  for (int n = 0; n < N; n++) {
101  const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
102  float sum = 0.0f;
103 
104  for (int ib = 0; ib < nb; ib++) {
105  const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
106  const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
107  const float d = d_a * d_b;
108 
109  int32_t sumi = 0;
110  for (int j = 0; j < QK8_0; j++) {
111  sumi += (int32_t)a_row[ib].qs[j] * (int32_t)b_row[ib].qs[j];
112  }
113 
114  sum += d * (float)sumi;
115  }
116 
117  C[(size_t)m * N + n] = sum;
118  }
119  }
120 }
#define QK8_0

References C, CK_FP16_TO_FP32, QK8_0, and block_q8_0::qs.

Referenced by gemm_nt_q8_0_q8_0().