gemm_kernels_q4k.c File Reference

GEMM/GEMV kernels with Q4_K quantized weights. More...

#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include "ckernel_quant.h"


Functions

float dot_q4_k (const void *w_q4k, const float *x, int K)
 Compute dot product of Q4_K row with FP32 vector. More...
 
void gemm_nt_q4_k (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 GEMM with transposed Q4_K weights: C = A · Bᵀ + optional bias. More...
 
void gemm_q4_k (float *Y, const void *W, const float *X, int M, int N, int K)
 Auto-dispatch GEMM based on available SIMD. More...
 
void gemm_q4_k_backward (float *dX, const void *W, const float *dY, int M, int N, int K)
 Batched backward pass. More...
 
void gemm_q4_k_ref (float *Y, const void *W, const float *X, int M, int N, int K)
 Matrix-matrix multiply with Q4_K weights (scalar reference) More...
 
void gemv_q4_k (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV based on available SIMD. More...
 
void gemv_q4_k_backward (float *dX, const void *W, const float *dY, int M, int K)
 Auto-dispatch backward. More...
 
void gemv_q4_k_backward_ref (float *dX, const void *W, const float *dY, int M, int K)
 Backward pass: compute input gradient (scalar reference) More...
 
void gemv_q4_k_ref (float *y, const void *W, const float *x, int M, int K)
 Matrix-vector multiply with Q4_K weights (scalar reference) More...
 

Detailed Description

GEMM/GEMV kernels with Q4_K quantized weights.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Implements matrix multiplication where:

  • Activations (input): FP32
  • Weights: Q4_K (4.5 bits/weight, nested scales)
  • Output: FP32

Key optimization: Fused dequantization - weights are dequantized in registers and immediately used in FMA, never written to memory.

Operations:

  • gemv_q4_k: Matrix-vector multiply (batch=1, token generation)
  • gemm_q4_k: Matrix-matrix multiply (batch>1, prefill)
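
A minimal usage sketch of the two entry points (the wrapper function and buffer names are illustrative assumptions, not part of this file; the prototypes are restated for self-containment):

    /* Prototypes as documented in this file. */
    void gemv_q4_k(float *y, const void *W, const float *x, int M, int K);
    void gemm_q4_k(float *Y, const void *W, const float *X, int M, int N, int K);

    /* W_q4k holds M rows of K Q4_K-quantized weights; K must be a multiple of 256. */
    void example_forward(const void *W_q4k,
                         const float *x,        /* one token input: [K]              */
                         const float *X_batch,  /* n_tokens columns of K floats each */
                         float *y,              /* one token output: [M]             */
                         float *Y_batch,        /* n_tokens columns of M floats each */
                         int M, int K, int n_tokens)
    {
        /* Decode path (batch = 1): one matrix-vector product per generated token. */
        gemv_q4_k(y, W_q4k, x, M, K);

        /* Prefill path (batch > 1): one matrix-matrix product over all tokens. */
        gemm_q4_k(Y_batch, W_q4k, X_batch, M, n_tokens, K);
    }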

Definition in file gemm_kernels_q4k.c.

Function Documentation

◆ dot_q4_k()

float dot_q4_k ( const void *  w_q4k,
const float *  x,
int  K 
)

Compute dot product of Q4_K row with FP32 vector.

Parameters
w_q4k    Q4_K blocks for one row
x        FP32 input vector
K        Vector length (must be a multiple of 256)
Returns
Dot product result

Definition at line 484 of file gemm_kernels_q4k.c.

485 {
486  float result;
487  gemv_q4_k(&result, w_q4k, x, 1, K);
488  return result;
489 }
void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV based on available SIMD.

References gemv_q4_k().
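
A hedged usage sketch (row_q4k, x, and the dimension are illustrative; the prototype is restated for self-containment):

    float dot_q4_k(const void *w_q4k, const float *x, int K);

    /* row_q4k points to one K-length row packed as Q4_K blocks;
     * K must be a multiple of 256 (4096 is an illustrative value). */
    static float score_one_row(const void *row_q4k, const float *x)
    {
        return dot_q4_k(row_q4k, x, 4096);
    }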

◆ gemm_nt_q4_k()

void gemm_nt_q4_k ( const float *  A,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

GEMM with transposed Q4_K weights: computes C = A · Bᵀ (plus an optional bias), where A is FP32 activations [M x K] row-major, B is a Q4_K weight matrix [N x K] row-major, and C is FP32 output [M x N]. Returns immediately if A, B, or C is NULL or any dimension is non-positive.

Definition at line 683 of file gemm_kernels_q4k.c.

688 {
689  if (!A || !B || !C) {
690  return;
691  }
692  if (M <= 0 || N <= 0 || K <= 0) {
693  return;
694  }
695 
696  /* gemm_q4_k produces Y as [batch x M_out]. Here:
697  * batch = M (tokens)
698  * M_out = N (output channels) */
699  gemm_q4_k(C, B, A, /*M_out=*/N, /*N_batch=*/M, K);
700 
701  if (!bias) {
702  return;
703  }
704 
705  for (int i = 0; i < M; ++i) {
706  float *row = C + (size_t)i * (size_t)N;
707  for (int j = 0; j < N; ++j) {
708  row[j] += bias[j];
709  }
710  }
711 }
void gemm_q4_k(float *Y, const void *W, const float *X, int M, int N, int K)
Auto-dispatch GEMM based on available SIMD.

References gemm_q4_k().

Referenced by ck_attention_project_head_major_q4_k(), ck_gemm_nt_quant(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_mlp_swiglu_forward_q4_k(), ck_qkv_project_head_major_q4_k(), ck_qkv_project_head_major_token_q4_k(), model_decode_token(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().
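
A sketch of a typical call for a linear layer (function and buffer names are illustrative assumptions, not part of this file):

    void gemm_nt_q4_k(const float *A, const void *B, const float *bias,
                      float *C, int M, int N, int K);

    /* Computes out = acts * W^T + bias for one linear layer.
     * d_in must be a multiple of 256 (Q4_K block size). */
    void linear_forward(const float *acts,  /* [n_tokens x d_in], FP32 row-major  */
                        const void *w_q4k,  /* [n_out x d_in], Q4_K row-major     */
                        const float *bias,  /* [n_out], or NULL for no bias       */
                        float *out,         /* [n_tokens x n_out], FP32 row-major */
                        int n_tokens, int n_out, int d_in)
    {
        gemm_nt_q4_k(acts, w_q4k, bias, out, n_tokens, n_out, d_in);
    }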

◆ gemm_q4_k()

void gemm_q4_k ( float *  Y,
const void *  W,
const float *  X,
int  M,
int  N,
int  K 
)

Auto-dispatch GEMM based on available SIMD.

Definition at line 461 of file gemm_kernels_q4k.c.

465 {
466  /* Use reference implementation for correctness
467  * TODO: Fix AVX-512 version to match llama.cpp layout */
468  gemm_q4_k_ref(Y, W, X, M, N, K);
469 }
void gemm_q4_k_ref(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q4_K weights (scalar reference)

References gemm_q4_k_ref().

Referenced by gemm_nt_q4_k().

◆ gemm_q4_k_backward()

void gemm_q4_k_backward ( float *  dX,
const void *  W,
const float *  dY,
int  M,
int  N,
int  K 
)

Batched backward pass.

Definition at line 656 of file gemm_kernels_q4k.c.

660 {
661  for (int n = 0; n < N; n++) {
662  gemv_q4_k_backward(&dX[n * K], W, &dY[n * M], M, K);
663  }
664 }
void gemv_q4_k_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.

References gemv_q4_k_backward().
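
The buffer layouts mirror the forward pass: column n of dY (length M) starts at dY + n*M and column n of dX (length K) starts at dX + n*K. A minimal call sketch (names are illustrative assumptions):

    void gemm_q4_k_backward(float *dX, const void *W, const float *dY, int M, int N, int K);

    /* dY: gradients of the gemm_q4_k output, N columns of M floats.
     * dX: gradient w.r.t. the input, N columns of K floats (overwritten). */
    void backward_batch(float *dX, const void *W_q4k, const float *dY,
                        int M, int n_tokens, int K)
    {
        gemm_q4_k_backward(dX, W_q4k, dY, M, n_tokens, K);
    }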

◆ gemm_q4_k_ref()

void gemm_q4_k_ref ( float *  Y,
const void *  W,
const float *  X,
int  M,
int  N,
int  K 
)

Matrix-matrix multiply with Q4_K weights (scalar reference)

Parameters
Y    Output matrix [M x N]
W    Weight matrix in Q4_K format [M x K]
X    Input matrix [K x N] (column-major for cache efficiency)
M    Number of output rows
N    Batch size (number of columns)
K    Hidden dimension

Definition at line 316 of file gemm_kernels_q4k.c.

320 {
321  /* For each column in batch, use the dispatching gemv_q4_k
322  * which automatically selects AVX/AVX-512/scalar based on CPU */
323  for (int n = 0; n < N; n++) {
324  gemv_q4_k(&Y[n * M], W, &X[n * K], M, K);
325  }
326 }

References gemv_q4_k().

Referenced by gemm_q4_k().
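
Ignoring quantization, the result is equivalent to the dense triple loop below, with Wd standing for the fully dequantized weight matrix (a sketch of the semantics, not of the actual kernel):

    /* Dense-equivalent semantics of gemm_q4_k_ref.
     * Wd: dequantized weights, [M x K] row-major.
     * X:  column n is contiguous at X + n*K.
     * Y:  column n is contiguous at Y + n*M. */
    void gemm_dense_equiv(float *Y, const float *Wd, const float *X,
                          int M, int N, int K)
    {
        for (int n = 0; n < N; n++) {
            for (int m = 0; m < M; m++) {
                float sum = 0.0f;
                for (int k = 0; k < K; k++) {
                    sum += Wd[m * K + k] * X[n * K + k];
                }
                Y[n * M + m] = sum;
            }
        }
    }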

◆ gemv_q4_k()

void gemv_q4_k ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Auto-dispatch GEMV based on available SIMD.

Definition at line 285 of file gemm_kernels_q4k.c.

289 {
290 #ifdef __AVX512F__
291  gemv_q4_k_avx512(y, W, x, M, K);
292 #elif defined(__AVX__)
293  gemv_q4_k_avx(y, W, x, M, K);
294 #else
295  gemv_q4_k_ref(y, W, x, M, K);
296 #endif
297 }
void gemv_q4_k_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_K weights (scalar reference)

References gemv_q4_k_ref().

Referenced by attention_mlp_fused_q4k(), dot_q4_k(), gemm_q4_k_ref(), layer_fused_attn_mlp_qkv_q4k(), and rmsnorm_qkv_q4k_fused().

◆ gemv_q4_k_backward()

void gemv_q4_k_backward ( float *  dX,
const void *  W,
const float *  dY,
int  M,
int  K 
)

Auto-dispatch backward.

Definition at line 641 of file gemm_kernels_q4k.c.

645 {
646 #ifdef __AVX512F__
647  gemv_q4_k_backward_avx512(dX, W, dY, M, K);
648 #else
649  gemv_q4_k_backward_ref(dX, W, dY, M, K);
650 #endif
651 }
void gemv_q4_k_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient (scalar reference)

References gemv_q4_k_backward_ref().

Referenced by gemm_q4_k_backward().

◆ gemv_q4_k_backward_ref()

void gemv_q4_k_backward_ref ( float *  dX,
const void *  W,
const float *  dY,
int  M,
int  K 
)

Backward pass: compute input gradient (scalar reference)

Parameters
dX    Output gradient w.r.t. input [K]
W     Weight matrix in Q4_K format [M x K]
dY    Gradient w.r.t. output [M]
M     Number of output rows
K     Number of columns (input dimension)

Definition at line 511 of file gemm_kernels_q4k.c.

515 {
516  const block_q4_K *blocks = (const block_q4_K *)W;
517  const int blocks_per_row = K / QK_K;
518 
519  /* Zero output gradient */
520  memset(dX, 0, K * sizeof(float));
521 
522  /* Accumulate: dX += W^T @ dY
523  * Uses llama.cpp layout: 4 iterations of 64 weights each */
524  for (int row = 0; row < M; row++) {
525  const float dy = dY[row];
526 
527  for (int b = 0; b < blocks_per_row; b++) {
528  const block_q4_K *block = &blocks[row * blocks_per_row + b];
529  const float d = CK_FP16_TO_FP32(block->d);
530  const float dmin = CK_FP16_TO_FP32(block->dmin);
531 
532  uint8_t sc[8], m[8];
533  unpack_q4_k_scales(block->scales, sc, m);
534 
535  /* llama.cpp layout: 4 iterations of 64 weights each */
536  for (int iter = 0; iter < 4; iter++) {
537  const float d1 = d * (float)sc[2 * iter];
538  const float m1 = dmin * (float)m[2 * iter];
539  const float d2 = d * (float)sc[2 * iter + 1];
540  const float m2 = dmin * (float)m[2 * iter + 1];
541 
542  const uint8_t *qs = &block->qs[iter * 32];
543  float *dxp = &dX[b * QK_K + iter * 64];
544 
545  /* First 32 weights: low nibbles */
546  for (int l = 0; l < 32; l++) {
547  const int q = (qs[l] & 0x0F);
548  const float w = d1 * (float)q - m1;
549  dxp[l] += w * dy;
550  }
551 
552  /* Next 32 weights: high nibbles */
553  for (int l = 0; l < 32; l++) {
554  const int q = (qs[l] >> 4);
555  const float w = d2 * (float)q - m2;
556  dxp[32 + l] += w * dy;
557  }
558  }
559  }
560  }
561 }
#define CK_FP16_TO_FP32(x)
static void unpack_q4_k_scales(const uint8_t *scales, uint8_t *sc, uint8_t *m)
Unpack Q4_K sub-block scales and mins.
#define QK_K
uint8_t scales[12]
uint8_t qs[256/2]
ck_half dmin

References CK_FP16_TO_FP32, block_q4_K::d, block_q4_K::dmin, QK_K, block_q4_K::qs, block_q4_K::scales, and unpack_q4_k_scales().

Referenced by gemv_q4_k_backward().
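
Mathematically this computes dX = Wᵀ · dY over the dequantized weights; the dense-equivalent sketch below shows the semantics (Wd is the dequantized [M x K] matrix, an illustrative stand-in, not part of this file):

    #include <string.h>

    void gemv_backward_dense_equiv(float *dX, const float *Wd, const float *dY,
                                   int M, int K)
    {
        memset(dX, 0, (size_t)K * sizeof(float));
        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                dX[k] += Wd[m * K + k] * dY[m];  /* dX = W^T * dY */
            }
        }
    }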

◆ gemv_q4_k_ref()

void gemv_q4_k_ref ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Matrix-vector multiply with Q4_K weights (scalar reference)

Parameters
y    Output vector [M]
W    Weight matrix in Q4_K format [M x K], stored row-major
x    Input vector [K]
M    Number of output rows
K    Number of columns (must be a multiple of 256)

Definition at line 53 of file gemm_kernels_q4k.c.

57 {
58  const block_q4_K *blocks = (const block_q4_K *)W;
59  const int blocks_per_row = K / QK_K; /* QK_K = 256 */
60 
61  for (int row = 0; row < M; row++) {
62  float sum = 0.0f;
63 
64  for (int b = 0; b < blocks_per_row; b++) {
65  const block_q4_K *block = &blocks[row * blocks_per_row + b];
66  const float d = GGML_FP16_TO_FP32(block->d);
67  const float dmin = GGML_FP16_TO_FP32(block->dmin);
68 
69  /* Unpack sub-block scales */
70  uint8_t sc[8], m[8];
71  unpack_q4_k_scales(block->scales, sc, m);
72 
73  /* llama.cpp Q4_K layout: 4 iterations of 64 weights each
74  * Each iteration uses 32 bytes of qs and 2 scales:
75  * - First 32 weights (indices 0-31): low nibbles with scale[2*iter]
76  * - Next 32 weights (indices 32-63): high nibbles with scale[2*iter+1]
77  */
78  for (int iter = 0; iter < 4; iter++) {
79  const float d1 = d * (float)sc[2*iter];
80  const float m1 = dmin * (float)m[2*iter];
81  const float d2 = d * (float)sc[2*iter + 1];
82  const float m2 = dmin * (float)m[2*iter + 1];
83  const uint8_t *qs = &block->qs[iter * 32];
84  const float *xp = &x[b * QK_K + iter * 64];
85 
86  /* First 32 weights: low nibbles of qs[0..31] */
87  for (int l = 0; l < 32; l++) {
88  const int8_t q = (qs[l] & 0x0F);
89  sum += (d1 * (float)q - m1) * xp[l];
90  }
91  /* Next 32 weights: high nibbles of qs[0..31] */
92  for (int l = 0; l < 32; l++) {
93  const int8_t q = (qs[l] >> 4);
94  sum += (d2 * (float)q - m2) * xp[l + 32];
95  }
96  }
97  }
98 
99  y[row] = sum;
100  }
101 }
#define GGML_FP16_TO_FP32

References block_q4_K::d, block_q4_K::dmin, GGML_FP16_TO_FP32, QK_K, block_q4_K::qs, block_q4_K::scales, and unpack_q4_k_scales().

Referenced by gemv_q4_k().
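
For reference, each 4-bit value q in sub-block j of a block is dequantized as w = d·sc[j]·q − dmin·m[j], exactly as in the two inner loops above; a minimal restatement of that step (the helper name is illustrative):

    #include <stdint.h>

    /* d and dmin are the block's FP32-converted super-scales; sc_j and m_j are the
     * unpacked 6-bit sub-block scale and min; q is a 4-bit nibble in [0, 15]. */
    static float dequant_q4_k_weight(float d, float dmin, uint8_t sc_j, uint8_t m_j, int q)
    {
        return d * (float)sc_j * (float)q - dmin * (float)m_j;
    }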