← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q4_0.c File Reference

GEMM/GEMV kernels with Q4_0 quantized weights. More...

#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include "ckernel_quant.h"

Go to the source code of this file.

Functions

float dot_q4_0 (const void *w_q4_0, const float *x, int K)
 
void gemm_nt_q4_0 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias. More...
 
void gemm_q4_0 (float *Y, const void *W, const float *X, int M, int N, int K)
 Matrix-matrix multiply with Q4_0 weights. More...
 
void gemm_q4_0_backward (float *dX, const void *W, const float *dY, int M, int N, int K)
 Batched backward pass. More...
 
void gemv_q4_0 (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV. More...
 
void gemv_q4_0_backward (float *dX, const void *W, const float *dY, int M, int K)
 Auto-dispatch backward. More...
 
void gemv_q4_0_backward_ref (float *dX, const void *W, const float *dY, int M, int K)
 Backward pass: compute input gradient. More...
 
void gemv_q4_0_ref (float *y, const void *W, const float *x, int M, int K)
 Matrix-vector multiply with Q4_0 weights (scalar reference) More...
 

Detailed Description

GEMM/GEMV kernels with Q4_0 quantized weights.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Q4_0 Format:

  • 32 weights per block
  • 1 FP16 scale per block
  • 18 bytes per 32 weights = 4.5 bits/weight

Operations:

  • Forward: Y = W @ X (W is Q4_0; X and Y are FP32)
  • Backward: dX = W^T @ dY (gradient w.r.t. input)

Note: Weight gradients are not computed for quantized weights. For fine-tuning, use LoRA adapters, which maintain FP32 gradients separately.

Definition in file gemm_kernels_q4_0.c.

Function Documentation

◆ dot_q4_0()

float dot_q4_0 ( const void *  w_q4_0,
const float *  x,
int  K 
)

Definition at line 347 of file gemm_kernels_q4_0.c.

348 {
349  float result;
350  gemv_q4_0(&result, w_q4_0, x, 1, K);
351  return result;
352 }
void gemv_q4_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.

References gemv_q4_0().

◆ gemm_nt_q4_0()

void gemm_nt_q4_0 ( const float *  A,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.

Parameters
A: Input matrix [M x K], row-major FP32
B: Weight matrix in Q4_0 format, [N x K] stored row-major
bias: Optional bias [N], NULL if not used
C: Output [M x N], row-major FP32
M: Batch size (number of tokens)
N: Output dimension (number of rows in B)
K: Input dimension

Definition at line 176 of file gemm_kernels_q4_0.c.

181 {
182  const block_q4_0 *blocks = (const block_q4_0 *)B;
183  const int blocks_per_row = K / QK4_0;
184 
185  for (int m = 0; m < M; m++) {
186  const float *a_row = &A[m * K];
187 
188  for (int n = 0; n < N; n++) {
189  float sum = 0.0f;
190 
191  for (int b = 0; b < blocks_per_row; b++) {
192  const block_q4_0 *block = &blocks[n * blocks_per_row + b];
193  const float d = CK_FP16_TO_FP32(block->d);
194  const float *ap = &a_row[b * QK4_0];
195 
196  for (int i = 0; i < QK4_0 / 2; i++) {
197  const uint8_t packed = block->qs[i];
198  const int q0 = (packed & 0x0F) - 8;
199  const int q1 = (packed >> 4) - 8;
200 
201  sum += d * (float)q0 * ap[2 * i + 0];
202  sum += d * (float)q1 * ap[2 * i + 1];
203  }
204  }
205 
206  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
207  }
208  }
209 }
#define CK_FP16_TO_FP32(x)
#define QK4_0
Definition: ckernel_quant.h:35
#define C(color)
Definition: show_config.c:39
ck_half d
Definition: ckernel_quant.h:38
uint8_t qs[32/2]
Definition: ckernel_quant.h:39

References C, CK_FP16_TO_FP32, block_q4_0::d, QK4_0, and block_q4_0::qs.

Referenced by ck_gemm_nt_quant().

◆ gemm_q4_0()

void gemm_q4_0 ( float *  Y,
const void *  W,
const float *  X,
int  M,
int  N,
int  K 
)

Matrix-matrix multiply with Q4_0 weights.

Definition at line 151 of file gemm_kernels_q4_0.c.

155 {
156  for (int n = 0; n < N; n++) {
157  gemv_q4_0(&Y[n * M], W, &X[n * K], M, K);
158  }
159 }

References gemv_q4_0().

◆ gemm_q4_0_backward()

void gemm_q4_0_backward ( float *  dX,
const void *  W,
const float *  dY,
int  M,
int  N,
int  K 
)

Batched backward pass.

Definition at line 333 of file gemm_kernels_q4_0.c.

337 {
338  for (int n = 0; n < N; n++) {
339  gemv_q4_0_backward(&dX[n * K], W, &dY[n * M], M, K);
340  }
341 }
void gemv_q4_0_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.

References gemv_q4_0_backward().

◆ gemv_q4_0()

void gemv_q4_0 ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Auto-dispatch GEMV.

Definition at line 132 of file gemm_kernels_q4_0.c.

136 {
137 #ifdef __AVX512F__
138  gemv_q4_0_avx512(y, W, x, M, K);
139 #else
140  gemv_q4_0_ref(y, W, x, M, K);
141 #endif
142 }
void gemv_q4_0_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_0 weights (scalar reference)

References gemv_q4_0_ref().

Referenced by dot_q4_0(), and gemm_q4_0().

◆ gemv_q4_0_backward()

void gemv_q4_0_backward ( float *  dX,
const void *  W,
const float *  dY,
int  M,
int  K 
)

Auto-dispatch backward.

Definition at line 318 of file gemm_kernels_q4_0.c.

322 {
323 #ifdef __AVX512F__
324  gemv_q4_0_backward_avx512(dX, W, dY, M, K);
325 #else
326  gemv_q4_0_backward_ref(dX, W, dY, M, K);
327 #endif
328 }
void gemv_q4_0_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient.

References gemv_q4_0_backward_ref().

Referenced by gemm_q4_0_backward().

◆ gemv_q4_0_backward_ref()

void gemv_q4_0_backward_ref ( float *  dX,
const void *  W,
const float *  dY,
int  M,
int  K 
)

Backward pass: compute input gradient.

Parameters
dX: Output gradient w.r.t. input [K]
W: Weight matrix in Q4_0 format [M x K]
dY: Gradient w.r.t. output [M]
M: Number of output rows
K: Number of columns (input dimension)

Definition at line 230 of file gemm_kernels_q4_0.c.

234 {
235  const block_q4_0 *blocks = (const block_q4_0 *)W;
236  const int blocks_per_row = K / QK4_0;
237 
238  /* Zero output gradient */
239  memset(dX, 0, K * sizeof(float));
240 
241  /* Accumulate: dX += W^T @ dY */
242  for (int row = 0; row < M; row++) {
243  const float dy = dY[row];
244 
245  for (int b = 0; b < blocks_per_row; b++) {
246  const block_q4_0 *block = &blocks[row * blocks_per_row + b];
247  const float d = CK_FP16_TO_FP32(block->d);
248  float *dxp = &dX[b * QK4_0];
249 
250  for (int i = 0; i < QK4_0 / 2; i++) {
251  const uint8_t packed = block->qs[i];
252  const int8_t q0 = (packed & 0x0F) - 8;
253  const int8_t q1 = (packed >> 4) - 8;
254 
255  dxp[2*i + 0] += d * (float)q0 * dy;
256  dxp[2*i + 1] += d * (float)q1 * dy;
257  }
258  }
259  }
260 }

References CK_FP16_TO_FP32, block_q4_0::d, QK4_0, and block_q4_0::qs.

Referenced by gemv_q4_0_backward().

◆ gemv_q4_0_ref()

void gemv_q4_0_ref ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Matrix-vector multiply with Q4_0 weights (scalar reference)

Parameters
y: Output vector [M]
W: Weight matrix in Q4_0 format [M x K]
x: Input vector [K]
M: Number of output rows
K: Number of columns (must be multiple of 32)

Definition at line 49 of file gemm_kernels_q4_0.c.

53 {
54  const block_q4_0 *blocks = (const block_q4_0 *)W;
55  const int blocks_per_row = K / QK4_0;
56 
57  for (int row = 0; row < M; row++) {
58  float sum = 0.0f;
59 
60  for (int b = 0; b < blocks_per_row; b++) {
61  const block_q4_0 *block = &blocks[row * blocks_per_row + b];
62  const float d = CK_FP16_TO_FP32(block->d);
63  const float *xp = &x[b * QK4_0];
64 
65  for (int i = 0; i < QK4_0 / 2; i++) {
66  const uint8_t packed = block->qs[i];
67  const int8_t q0 = (packed & 0x0F) - 8;
68  const int8_t q1 = (packed >> 4) - 8;
69 
70  sum += d * (float)q0 * xp[2*i + 0];
71  sum += d * (float)q1 * xp[2*i + 1];
72  }
73  }
74 
75  y[row] = sum;
76  }
77 }

References CK_FP16_TO_FP32, block_q4_0::d, QK4_0, and block_q4_0::qs.

Referenced by gemv_q4_0().