gemm_kernels_q8_0.c File Reference

GEMM/GEMV kernels with Q8_0 quantized weights. More...

#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include "ckernel_quant.h"
#include "ck_features.h"


Functions

float dot_q8_0 (const void *w_q8_0, const float *x, int K)
 
void gemm_nt_q8_0 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias. More...
 
void gemm_q8_0 (float *Y, const void *W, const float *X, int M, int N, int K)
 Matrix-matrix multiply with Q8_0 weights. More...
 
void gemm_q8_0_backward (float *dX, const void *W, const float *dY, int M, int N, int K)
 Batched backward pass. More...
 
void gemv_q8_0 (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV for Q8_0 weights based on CPU features. More...
 
void gemv_q8_0_backward (float *dX, const void *W, const float *dY, int M, int K)
 Auto-dispatch backward. More...
 
void gemv_q8_0_backward_ref (float *dX, const void *W, const float *dY, int M, int K)
 Backward pass: compute input gradient (scalar reference) More...
 
void gemv_q8_0_parallel_simd (float *y, const void *W, const float *x, int M, int K, int ith, int nth)
 Parallel SIMD GEMV for Q8_0 weights x FP32 input with prefetching. More...
 
void gemv_q8_0_q8_0 (float *y, const void *W, const void *x_q8, int M, int K)
 Matrix-vector multiply with Q8_0 weights and Q8_0 input. More...
 
void gemv_q8_0_q8_0_parallel (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
 Parallel reference GEMV for Q8_0 x Q8_0. More...
 
void gemv_q8_0_q8_0_parallel_simd (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
 Parallel SIMD GEMV for Q8_0 x Q8_0 with prefetching. More...
 
void gemv_q8_0_ref (float *y, const void *W, const float *x, int M, int K)
 Matrix-vector multiply with Q8_0 weights (scalar reference) More...
 
void quantize_batch_q8_0 (const float *x, void *vy, int num_rows, int k)
 Batch quantize FP32 to Q8_0 format (row-major output) More...
 
void quantize_batch_q8_k (const float *x, void *vy, int num_rows, int k)
 Batch quantize FP32 to Q8_K format (row-major output) More...
 
void quantize_row_q8_0 (const float *x, void *vy, int k)
 Quantize FP32 to Q8_0 format (scalar reference) More...
 
void quantize_row_q8_k (const float *x, void *vy, int k)
 
void vec_dot_q8_0_q8_0 (int n, float *s, const void *vx, const void *vy)
 Auto-dispatch quantized dot product Q8_0 x Q8_0. More...
 
void vec_dot_q8_0_q8_0_ref (int n, float *s, const void *vx, const void *vy)
 Quantized dot product: Q8_0 weights x Q8_0 input (scalar reference) More...
 

Detailed Description

GEMM/GEMV kernels with Q8_0 quantized weights.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Q8_0 Format:

  • 32 weights per block
  • 1 FP16 scale per block
  • 34 bytes per 32 weights = 8.5 bits/weight
  • Weights stored as signed 8-bit integers

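The numbers above imply the following block layout. This is a sketch for orientation only; the authoritative definition (including the exact FP16 type) lives in ckernel_quant.h:

  /* Sketch of the Q8_0 block implied by the format notes above.
   * Field names (d, qs) match their usage in this file; the real
   * definition is in ckernel_quant.h. */
  typedef struct {
      uint16_t d;       /* per-block scale, FP16 stored as raw bits */
      int8_t   qs[32];  /* 32 signed 8-bit quantized weights */
  } block_q8_0;         /* 2 + 32 = 34 bytes => 8.5 bits/weight */
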
Operations:

  • Forward:  Y = W @ X    (W is Q8_0, X and Y are FP32)
  • Backward: dX = W^T @ dY  (gradient w.r.t. input)

Note: Q8_0 is often used for activation quantization or as an intermediate format; it offers higher precision than Q4_0/Q4_K.

Definition in file gemm_kernels_q8_0.c.

Function Documentation

◆ dot_q8_0()

float dot_q8_0 ( const void *  w_q8_0,
const float *  x,
int  K 
)

Definition at line 834 of file gemm_kernels_q8_0.c.

835 {
836  float result;
837  gemv_q8_0(&result, w_q8_0, x, 1, K);
838  return result;
839 }

References gemv_q8_0().

◆ gemm_nt_q8_0()

void gemm_nt_q8_0 ( const float *  A,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.

Parameters
A     Input matrix [M x K], row-major FP32
B     Weight matrix in Q8_0 format, [N x K] stored row-major
bias  Optional bias [N], NULL if not used
C     Output [M x N], row-major FP32
M     Batch size (number of tokens)
N     Output dimension (number of rows in B)
K     Input dimension

Definition at line 681 of file gemm_kernels_q8_0.c.

686 {
687  /* Use GEMV dispatch which selects AVX/SSE/scalar based on CPU */
688  for (int m = 0; m < M; m++) {
689  gemv_q8_0(&C[m * N], B, &A[m * K], N, K);
690  if (bias) {
691  for (int n = 0; n < N; n++) C[m * N + n] += bias[n];
692  }
693  }
694  return;  /* NOTE: the scalar reference path below is unreachable; kept for reference */
695 
696  const block_q8_0 *blocks = (const block_q8_0 *)B;
697  const int blocks_per_row = K / QK8_0;
698 
699  for (int m = 0; m < M; m++) {
700  const float *a_row = &A[m * K];
701 
702  for (int n = 0; n < N; n++) {
703  float sum = 0.0f;
704 
705  for (int b = 0; b < blocks_per_row; b++) {
706  const block_q8_0 *block = &blocks[n * blocks_per_row + b];
707  const float d = CK_FP16_TO_FP32(block->d);
708  const float *ap = &a_row[b * QK8_0];
709 
710  for (int i = 0; i < QK8_0; i++) {
711  sum += d * (float)block->qs[i] * ap[i];
712  }
713  }
714 
715  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
716  }
717  }
718 }

References CK_FP16_TO_FP32, block_q8_0::d, gemv_q8_0(), QK8_0, and block_q8_0::qs.

Referenced by ck_gemm_nt_quant(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_forward_prefill_impl(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().
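
A hypothetical call for a tiny layer, with all sizes illustrative; b_q8 is assumed to have been filled with quantize_batch_q8_0, and sizeof(block_q8_0) is taken as 34 bytes per the format notes:

  /* Illustrative: [N=4, K=64] Q8_0 weights applied to M=2 tokens. */
  float A[2 * 64];                        /* input, row-major FP32 */
  float C_out[2 * 4];                     /* output, row-major FP32 */
  unsigned char b_q8[4 * (64 / 32) * 34]; /* 2 blocks per row, 34 B each */

  /* quantize_batch_q8_0(b_f32, b_q8, 4, 64);  -- fill weights first */
  gemm_nt_q8_0(A, b_q8, /*bias=*/NULL, C_out, /*M=*/2, /*N=*/4, /*K=*/64);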

◆ gemm_q8_0()

void gemm_q8_0 ( float *  Y,
const void *  W,
const float *  X,
int  M,
int  N,
int  K 
)

Matrix-matrix multiply with Q8_0 weights.

Computes Y[n] = W @ X[n] for each of the N input vectors: X is stored as N rows of length K and Y as N rows of length M.

Definition at line 656 of file gemm_kernels_q8_0.c.

660 {
661  for (int n = 0; n < N; n++) {
662  gemv_q8_0(&Y[n * M], W, &X[n * K], M, K);
663  }
664 }

References gemv_q8_0().

◆ gemm_q8_0_backward()

void gemm_q8_0_backward ( float *  dX,
const void *  W,
const float *  dY,
int  M,
int  N,
int  K 
)

Batched backward pass: applies gemv_q8_0_backward to each of the N gradient rows (dX[n] = W^T @ dY[n]).

Definition at line 820 of file gemm_kernels_q8_0.c.

824 {
825  for (int n = 0; n < N; n++) {
826  gemv_q8_0_backward(&dX[n * K], W, &dY[n * M], M, K);
827  }
828 }

References gemv_q8_0_backward().

◆ gemv_q8_0()

void gemv_q8_0 ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Auto-dispatch GEMV for Q8_0 weights based on CPU features.

Dispatch priority (best available):

  1. AVX-512 (512-bit vectors) - Intel Skylake-X+
  2. AVX2+FMA (256-bit vectors) - Intel Haswell+
  3. AVX (256-bit vectors) - Intel Sandy Bridge+
  4. SSE4.1 (128-bit vectors) - Intel Nehalem+
  5. Reference (scalar) - Fallback

Uses ck_features.h for standardized feature detection.

Parameters
y  Output vector [M]
W  Weight matrix in Q8_0 format [M x K]
x  Input vector [K]
M  Number of output rows
K  Number of input columns (hidden dimension)

Definition at line 630 of file gemm_kernels_q8_0.c.

634 {
635 // Dispatch order: AVX512 > AVX2 > AVX > SSE > ref
636 #if defined(__AVX512F__)
637  gemv_q8_0_avx512(y, W, x, M, K);
638 #elif defined(__AVX2__)
639  gemv_q8_0_avx2(y, W, x, M, K);
640 #elif defined(__AVX__)
641  gemv_q8_0_avx(y, W, x, M, K);
642 #elif defined(__SSE4_1__)
643  gemv_q8_0_sse(y, W, x, M, K);
644 #else
645  gemv_q8_0_ref(y, W, x, M, K);
646 #endif
647 }

References gemv_q8_0_ref().

Referenced by dot_q8_0(), gemm_nt_q8_0(), and gemm_q8_0().
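
A minimal end-to-end sketch, assuming sizeof(block_q8_0) == 34 as described in the format notes:

  /* Quantize an [M x K] FP32 weight matrix row-by-row, then run the
   * auto-dispatched GEMV; the same call works whichever SIMD path
   * was selected at compile time. */
  enum { M_ROWS = 8, K_COLS = 128 };
  float w_f32[M_ROWS * K_COLS], x_in[K_COLS], y_out[M_ROWS];
  unsigned char w_q8[M_ROWS * (K_COLS / 32) * 34];

  quantize_batch_q8_0(w_f32, w_q8, M_ROWS, K_COLS);
  gemv_q8_0(y_out, w_q8, x_in, M_ROWS, K_COLS);   /* y = W @ x */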

◆ gemv_q8_0_backward()

void gemv_q8_0_backward ( float *  dX,
const void *  W,
const float *  dY,
int  M,
int  K 
)

Auto-dispatch backward.

Definition at line 805 of file gemm_kernels_q8_0.c.

809 {
810 #ifdef __AVX512F__
811  gemv_q8_0_backward_avx512(dX, W, dY, M, K);
812 #else
813  gemv_q8_0_backward_ref(dX, W, dY, M, K);
814 #endif
815 }

References gemv_q8_0_backward_ref().

Referenced by gemm_q8_0_backward().

◆ gemv_q8_0_backward_ref()

void gemv_q8_0_backward_ref ( float *  dX,
const void *  W,
const float *  dY,
int  M,
int  K 
)

Backward pass: compute input gradient (scalar reference)

Parameters
dX  Output gradient w.r.t. input [K]
W   Weight matrix in Q8_0 format [M x K]
dY  Gradient w.r.t. output [M]
M   Number of output rows
K   Number of columns (input dimension)

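In index form, writing d_{m,b} for the dequantized scale of block b in row m and q_{m,k} for the int8 weight, the reference computes

  dX_k = \sum_{m=0}^{M-1} d_{m, \lfloor k/32 \rfloor} \, q_{m,k} \, dY_m,   for k = 0, ..., K-1
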
Definition at line 733 of file gemm_kernels_q8_0.c.

737 {
738  const block_q8_0 *blocks = (const block_q8_0 *)W;
739  const int blocks_per_row = K / QK8_0;
740 
741  /* Zero output gradient */
742  memset(dX, 0, K * sizeof(float));
743 
744  /* Accumulate: dX += W^T @ dY */
745  for (int row = 0; row < M; row++) {
746  const float dy = dY[row];
747 
748  for (int b = 0; b < blocks_per_row; b++) {
749  const block_q8_0 *block = &blocks[row * blocks_per_row + b];
750  const float d = CK_FP16_TO_FP32(block->d);
751  float *dxp = &dX[b * QK8_0];
752 
753  for (int i = 0; i < QK8_0; i++) {
754  dxp[i] += d * (float)block->qs[i] * dy;
755  }
756  }
757  }
758 }

References CK_FP16_TO_FP32, block_q8_0::d, QK8_0, and block_q8_0::qs.

Referenced by gemv_q8_0_backward().

◆ gemv_q8_0_parallel_simd()

void gemv_q8_0_parallel_simd ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K,
int  ith,
int  nth 
)

Parallel SIMD GEMV for Q8_0 weights x FP32 input with prefetching.

Definition at line 1153 of file gemm_kernels_q8_0.c.

1158 {
1159  if (!y || !W || !x || M <= 0 || K <= 0) return;
1160  if (ith < 0 || nth <= 0 || ith >= nth) return;
1161 
1162  const int dr = (M + nth - 1) / nth;
1163  const int r0 = dr * ith;
1164  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1165 
1166  if (r0 >= M) return;
1167 
1168  const block_q8_0 *blocks = (const block_q8_0 *)W;
1169  const int blocks_per_row = K / QK8_0;
1170 
1171 #if defined(__AVX__) || defined(__SSE4_1__)
1172  const int PREFETCH_ROWS = 4;
1173  for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
1174  const char *row_ptr = (const char *)(blocks + (r0 + p) * blocks_per_row);
1175  _mm_prefetch(row_ptr, _MM_HINT_T0);
1176  _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
1177  }
1178 
1179  for (int row = r0; row < r1; ++row) {
1180  if (row + PREFETCH_ROWS < r1) {
1181  const char *pf = (const char *)(blocks + (row + PREFETCH_ROWS) * blocks_per_row);
1182  _mm_prefetch(pf, _MM_HINT_T0);
1183  _mm_prefetch(pf + 64, _MM_HINT_T0);
1184  }
1185 
1186  /* Dispatch to best available SIMD for single row */
1187 #if defined(__AVX512F__)
1188  gemv_q8_0_avx512(&y[row],
1189  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1190  x, 1, K);
1191 #elif defined(__AVX2__)
1192  gemv_q8_0_avx2(&y[row],
1193  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1194  x, 1, K);
1195 #elif defined(__AVX__)
1196  gemv_q8_0_avx(&y[row],
1197  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1198  x, 1, K);
1199 #elif defined(__SSE4_1__)
1200  gemv_q8_0_sse(&y[row],
1201  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1202  x, 1, K);
1203 #else
1204  gemv_q8_0_ref(&y[row],
1205  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1206  x, 1, K);
1207 #endif
1208  }
1209 #else
1210  for (int row = r0; row < r1; row++) {
1211  gemv_q8_0_ref(&y[row],
1212  (const char *)blocks + row * blocks_per_row * sizeof(block_q8_0),
1213  x, 1, K);
1214  }
1215 #endif
1216 }

References gemv_q8_0_ref(), and QK8_0.

◆ gemv_q8_0_q8_0()

void gemv_q8_0_q8_0 ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Matrix-vector multiply with Q8_0 weights and Q8_0 input.

Parameters
y     Output vector [M]
W     Weight matrix in Q8_0 format [M x K]
x_q8  Input vector in Q8_0 format [K]
M     Number of output rows
K     Number of columns (must be multiple of 32)

Definition at line 1042 of file gemm_kernels_q8_0.c.

1046 {
1047  const block_q8_0 *w_blocks = (const block_q8_0 *)W;
1048  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1049  const int blocks_per_row = K / QK8_0;
1050 
1051  for (int row = 0; row < M; row++) {
1052  vec_dot_q8_0_q8_0(K, &y[row],
1053  &w_blocks[row * blocks_per_row],
1054  x_blocks);
1055  }
1056 }

References QK8_0, and vec_dot_q8_0_q8_0().

Referenced by ck_test_gemv_q8_0(), and ck_test_gemv_q8_0_q8_0().
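
A sketch of the quantize-then-multiply flow this enables, with sizes illustrative and w_q8 assumed pre-filled (e.g. via quantize_batch_q8_0):

  /* Quantize the activation once, then reuse it against Q8_0 weights
   * so the inner loop is an int8 x int8 dot product. */
  enum { MQ = 8, KQ = 128 };
  float x_f32[KQ], y_out[MQ];
  unsigned char w_q8[MQ * (KQ / 32) * 34];
  unsigned char x_q8[(KQ / 32) * 34];

  quantize_row_q8_0(x_f32, x_q8, KQ);        /* FP32 -> Q8_0 input */
  gemv_q8_0_q8_0(y_out, w_q8, x_q8, MQ, KQ);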

◆ gemv_q8_0_q8_0_parallel()

void gemv_q8_0_q8_0_parallel ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K,
int  ith,
int  nth 
)

Parallel reference GEMV for Q8_0 x Q8_0.

Definition at line 1068 of file gemm_kernels_q8_0.c.

1073 {
1074  if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
1075  if (ith < 0 || nth <= 0 || ith >= nth) return;
1076 
1077  const int dr = (M + nth - 1) / nth;
1078  const int r0 = dr * ith;
1079  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1080 
1081  if (r0 >= M) return;
1082 
1083  const block_q8_0 *w_blocks = (const block_q8_0 *)W;
1084  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1085  const int blocks_per_row = K / QK8_0;
1086 
1087  for (int row = r0; row < r1; row++) {
1088  vec_dot_q8_0_q8_0(K, &y[row],
1089  &w_blocks[row * blocks_per_row],
1090  x_blocks);
1091  }
1092 }

References QK8_0, and vec_dot_q8_0_q8_0().

◆ gemv_q8_0_q8_0_parallel_simd()

void gemv_q8_0_q8_0_parallel_simd ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K,
int  ith,
int  nth 
)

Parallel SIMD GEMV for Q8_0 x Q8_0 with prefetching.

Each thread processes rows [r0, r1) where r0 = ith * ceil(M/nth). Prefetches upcoming weight rows to hide memory latency.

Definition at line 1100 of file gemm_kernels_q8_0.c.

1105 {
1106  if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
1107  if (ith < 0 || nth <= 0 || ith >= nth) return;
1108 
1109  const int dr = (M + nth - 1) / nth;
1110  const int r0 = dr * ith;
1111  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1112 
1113  if (r0 >= M) return;
1114 
1115  const block_q8_0 *w_blocks = (const block_q8_0 *)W;
1116  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1117  const int blocks_per_row = K / QK8_0;
1118 
1119 #if defined(__AVX__) || defined(__SSE4_1__)
1120  /* Prefetch first few rows */
1121  const int PREFETCH_ROWS = 4;
1122  for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
1123  const char *row_ptr = (const char *)(w_blocks + (r0 + p) * blocks_per_row);
1124  _mm_prefetch(row_ptr, _MM_HINT_T0);
1125  _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
1126  }
1127 
1128  for (int row = r0; row < r1; ++row) {
1129  /* Prefetch upcoming rows */
1130  if (row + PREFETCH_ROWS < r1) {
1131  const char *pf = (const char *)(w_blocks + (row + PREFETCH_ROWS) * blocks_per_row);
1132  _mm_prefetch(pf, _MM_HINT_T0);
1133  _mm_prefetch(pf + 64, _MM_HINT_T0);
1134  }
1135 
1136  vec_dot_q8_0_q8_0(K, &y[row],
1137  &w_blocks[row * blocks_per_row],
1138  x_blocks);
1139  }
1140 #else
1141  /* Fallback: no prefetching */
1142  for (int row = r0; row < r1; row++) {
1143  vec_dot_q8_0_q8_0(K, &y[row],
1144  &w_blocks[row * blocks_per_row],
1145  x_blocks);
1146  }
1147 #endif
1148 }

References QK8_0, and vec_dot_q8_0_q8_0().
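
Per kernel rule 2 the kernel spawns no threads itself; the caller shards rows by passing (ith, nth). A minimal pthreads sketch of that orchestration, purely illustrative (names like q8_task are not part of the API, and nth is assumed <= 16):

  #include <pthread.h>

  typedef struct {
      float *y; const void *W; const void *x_q8;
      int M, K, ith, nth;
  } q8_task;

  static void *q8_worker(void *arg)
  {
      q8_task *t = (q8_task *)arg;
      /* Each worker computes its own [r0, r1) row slice. */
      gemv_q8_0_q8_0_parallel_simd(t->y, t->W, t->x_q8,
                                   t->M, t->K, t->ith, t->nth);
      return NULL;
  }

  static void run_parallel(float *y, const void *W, const void *x_q8,
                           int M, int K, int nth)
  {
      pthread_t tid[16];
      q8_task task[16];
      for (int i = 0; i < nth; i++) {
          task[i] = (q8_task){ y, W, x_q8, M, K, i, nth };
          pthread_create(&tid[i], NULL, q8_worker, &task[i]);
      }
      for (int i = 0; i < nth; i++) pthread_join(tid[i], NULL);
  }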

◆ gemv_q8_0_ref()

void gemv_q8_0_ref ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Matrix-vector multiply with Q8_0 weights (scalar reference)

Parameters
y  Output vector [M]
W  Weight matrix in Q8_0 format [M x K]
x  Input vector [K]
M  Number of output rows
K  Number of columns (must be multiple of 32)

Definition at line 252 of file gemm_kernels_q8_0.c.

256 {
257  const block_q8_0 *blocks = (const block_q8_0 *)W;
258  const int blocks_per_row = K / QK8_0;
259 
260  for (int row = 0; row < M; row++) {
261  float sum = 0.0f;
262 
263  for (int b = 0; b < blocks_per_row; b++) {
264  const block_q8_0 *block = &blocks[row * blocks_per_row + b];
265  const float d = CK_FP16_TO_FP32(block->d);
266  const float *xp = &x[b * QK8_0];
267 
268  for (int i = 0; i < QK8_0; i++) {
269  sum += d * (float)block->qs[i] * xp[i];
270  }
271  }
272 
273  y[row] = sum;
274  }
275 }

References CK_FP16_TO_FP32, block_q8_0::d, QK8_0, and block_q8_0::qs.

Referenced by gemv_q8_0(), and gemv_q8_0_parallel_simd().

◆ quantize_batch_q8_0()

void quantize_batch_q8_0 ( const float *  x,
void *  vy,
int  num_rows,
int  k 
)

Batch quantize FP32 to Q8_0 format (row-major output)

Quantizes multiple rows of FP32 data to Q8_0 format, placing each row's Q8_0 output at the correct byte offset for GEMM compatibility.

Memory layout:

  Input:  [num_rows, k] FP32, row-major (stride = k * sizeof(float))
  Output: [num_rows, q8_row_bytes] Q8_0, row-major (stride = q8_row_bytes)

where q8_row_bytes = (k / 32) * sizeof(block_q8_0) = (k / 32) * 34

Parameters
x         Input FP32 values [num_rows * k]
vy        Output Q8_0 blocks [num_rows * (k/32) blocks]
num_rows  Number of rows (batch size / tokens)
k         Elements per row (must be multiple of 32)

Definition at line 192 of file gemm_kernels_q8_0.c.

193 {
194  const size_t row_bytes_in = (size_t)k * sizeof(float);
195  const size_t row_bytes_out = (size_t)(k / QK8_0) * sizeof(block_q8_0);
196 
197  uint8_t *out = (uint8_t *)vy;
198  const uint8_t *in = (const uint8_t *)x;
199 
200  for (int row = 0; row < num_rows; ++row) {
201  quantize_row_q8_0(
202  (const float *)(in + row * row_bytes_in),
203  (void *)(out + row * row_bytes_out),
204  k
205  );
206  }
207 }

References QK8_0, and quantize_row_q8_0().
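
Worked sizing example (a sketch; sizeof(block_q8_0) is taken as 34 per the format notes): for k = 4096 there are 4096 / 32 = 128 blocks per row, so q8_row_bytes = 128 * 34 = 4352, and a bump allocator would reserve num_rows * 4352 bytes up front:

  /* Bytes per Q8_0 row for a given k (must be a multiple of 32). */
  size_t q8_row_bytes(int k) { return (size_t)(k / 32) * 34; }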

◆ quantize_batch_q8_k()

void quantize_batch_q8_k ( const float *  x,
void *  vy,
int  num_rows,
int  k 
)

Batch quantize FP32 to Q8_K format (row-major output)

Same as quantize_batch_q8_0 but for Q8_K format (super-blocks).

Parameters
x         Input FP32 values [num_rows * k]
vy        Output Q8_K blocks
num_rows  Number of rows (batch size / tokens)
k         Elements per row (must be multiple of 256)

Definition at line 219 of file gemm_kernels_q8_0.c.

220 {
221  /* Q8_K: 256 elements per super-block, each block is larger */
222  const size_t row_bytes_in = (size_t)k * sizeof(float);
223  /* Q8_K block: FP16 scale (d) + 256 int8 qs + block sums (bsums) */
224  /* Exact size: sizeof(block_q8_K) from ckernel_quant.h */
225  const size_t row_bytes_out = (size_t)(k / 256) * sizeof(block_q8_K);
226 
227  uint8_t *out = (uint8_t *)vy;
228  const uint8_t *in = (const uint8_t *)x;
229 
230  for (int row = 0; row < num_rows; ++row) {
231  quantize_row_q8_k(
232  (const float *)(in + row * row_bytes_in),
233  (void *)(out + row * row_bytes_out),
234  k
235  );
236  }
237 }

References quantize_row_q8_k().

◆ quantize_row_q8_0()

void quantize_row_q8_0 ( const float *  x,
void *  vy,
int  k 
)

Quantize FP32 to Q8_0 format (scalar reference)

Parameters
x   Input FP32 values
vy  Output Q8_0 blocks
k   Number of elements (must be multiple of 32)

Definition at line 59 of file gemm_kernels_q8_0.c.

60 {
61  block_q8_0 *y = (block_q8_0 *)vy;
62  const int nb = k / QK8_0; /* QK8_0 = 32 */
63 
64 #if defined(__AVX__)
65  const __m256 sign_bit = _mm256_set1_ps(-0.0f);
66  const __m256 v_half = _mm256_set1_ps(0.5f);
67  const __m256 v_min = _mm256_set1_ps(-127.0f);
68  const __m256 v_max = _mm256_set1_ps(127.0f);
69 
70  for (int i = 0; i < nb; i++) {
71  __m256 v0 = _mm256_loadu_ps(x + 0);
72  __m256 v1 = _mm256_loadu_ps(x + 8);
73  __m256 v2 = _mm256_loadu_ps(x + 16);
74  __m256 v3 = _mm256_loadu_ps(x + 24);
75  x += QK8_0;
76 
77  __m256 max_abs = _mm256_andnot_ps(sign_bit, v0);
78  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v1));
79  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v2));
80  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v3));
81 
82  __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max_abs, 1),
83  _mm256_castps256_ps128(max_abs));
84  max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
85  max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
86  const float max_scalar = _mm_cvtss_f32(max4);
87 
88  const float d = max_scalar / 127.0f;
89  const float id = max_scalar != 0.0f ? 127.0f / max_scalar : 0.0f;
90  y[i].d = CK_FP32_TO_FP16(d);
91 
92  const __m256 mul = _mm256_set1_ps(id);
93  v0 = _mm256_mul_ps(v0, mul);
94  v1 = _mm256_mul_ps(v1, mul);
95  v2 = _mm256_mul_ps(v2, mul);
96  v3 = _mm256_mul_ps(v3, mul);
97 
98  v0 = _mm256_min_ps(_mm256_max_ps(v0, v_min), v_max);
99  v1 = _mm256_min_ps(_mm256_max_ps(v1, v_min), v_max);
100  v2 = _mm256_min_ps(_mm256_max_ps(v2, v_min), v_max);
101  v3 = _mm256_min_ps(_mm256_max_ps(v3, v_min), v_max);
102 
103  /* Round half away from zero to match the scalar path */
104  v0 = _mm256_add_ps(v0, _mm256_or_ps(_mm256_and_ps(v0, sign_bit), v_half));
105  v1 = _mm256_add_ps(v1, _mm256_or_ps(_mm256_and_ps(v1, sign_bit), v_half));
106  v2 = _mm256_add_ps(v2, _mm256_or_ps(_mm256_and_ps(v2, sign_bit), v_half));
107  v3 = _mm256_add_ps(v3, _mm256_or_ps(_mm256_and_ps(v3, sign_bit), v_half));
108 
109  __m256i i0 = _mm256_cvttps_epi32(v0);
110  __m256i i1 = _mm256_cvttps_epi32(v1);
111  __m256i i2 = _mm256_cvttps_epi32(v2);
112  __m256i i3 = _mm256_cvttps_epi32(v3);
113 
114 #if defined(__AVX2__)
115  i0 = _mm256_packs_epi32(i0, i1);
116  i2 = _mm256_packs_epi32(i2, i3);
117  i0 = _mm256_packs_epi16(i0, i2);
118 
119  const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
120  i0 = _mm256_permutevar8x32_epi32(i0, perm);
121  _mm256_storeu_si256((__m256i *)y[i].qs, i0);
122 #else
123  __m128i ni0 = _mm256_castsi256_si128(i0);
124  __m128i ni1 = _mm256_extractf128_si256(i0, 1);
125  __m128i ni2 = _mm256_castsi256_si128(i1);
126  __m128i ni3 = _mm256_extractf128_si256(i1, 1);
127  __m128i ni4 = _mm256_castsi256_si128(i2);
128  __m128i ni5 = _mm256_extractf128_si256(i2, 1);
129  __m128i ni6 = _mm256_castsi256_si128(i3);
130  __m128i ni7 = _mm256_extractf128_si256(i3, 1);
131 
132  ni0 = _mm_packs_epi32(ni0, ni1);
133  ni2 = _mm_packs_epi32(ni2, ni3);
134  ni4 = _mm_packs_epi32(ni4, ni5);
135  ni6 = _mm_packs_epi32(ni6, ni7);
136 
137  ni0 = _mm_packs_epi16(ni0, ni2);
138  ni4 = _mm_packs_epi16(ni4, ni6);
139 
140  _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
141  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
142 #endif
143  }
144 #else
145  for (int i = 0; i < nb; i++) {
146  const float *xb = x + i * QK8_0;
147 
148  /* Find max absolute value in block */
149  float amax = 0.0f;
150  for (int j = 0; j < QK8_0; j++) {
151  float av = xb[j] >= 0 ? xb[j] : -xb[j];
152  if (av > amax) amax = av;
153  }
154 
155  /* Compute scale: d = max / 127 */
156  float d = amax / 127.0f;
157  float id = d != 0.0f ? 127.0f / amax : 0.0f;
158 
159  /* Store scale as FP16 */
160  y[i].d = CK_FP32_TO_FP16(d);
161 
162  /* Quantize values */
163  for (int j = 0; j < QK8_0; j++) {
164  float v = xb[j] * id;
165  /* Round to nearest int and clamp to [-127, 127] */
166  int q = (int)(v + (v >= 0 ? 0.5f : -0.5f));
167  if (q > 127) q = 127;
168  if (q < -127) q = -127;
169  y[i].qs[j] = (int8_t)q;
170  }
171  }
172 #endif
173 }

References CK_FP32_TO_FP16, block_q8_0::d, QK8_0, and block_q8_0::qs.

Referenced by ck_test_gemm_q5_0(), ck_test_gemm_q8_0(), ck_test_gemv_q5_0(), ck_test_gemv_q5_0_q8_0(), ck_test_gemv_q8_0(), ck_test_gemv_q8_0_q8_0(), fused_mlp_swiglu_prefill_w1w2_quant(), fused_rmsnorm_qkv_prefill_head_major_quant(), gemv_fused_q5_0_bias_parallel_omp(), gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), mega_fused_attention_decode_q5_0(), mega_fused_attention_decode_q5_0_parallel_simd(), quantize_attn_out_head_major_q8_0(), and quantize_batch_q8_0().
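
A worked example of the scalar path, for a block whose largest magnitude is 3.81:

  amax = 3.81   =>  d = 3.81 / 127 = 0.03,  id = 127 / 3.81 ≈ 33.33
  x = 1.905     =>  v = 1.905 * 33.33 = 63.5  ->  q = 64  (round half away from zero)
  dequantized:  d * q = 0.03 * 64 = 1.92  (absolute error 0.015)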

◆ quantize_row_q8_k()

void quantize_row_q8_k ( const float *  x,
void *  vy,
int  k 
)

Definition at line 107 of file gemm_kernels_q4k_q8k.c.

107  {
108 #if defined(__SSE4_1__)
109  quantize_row_q8_k_sse(x, vy, k);
110 #else
111  quantize_row_q8_k_ref(x, vy, k);
112 #endif
113 }

Referenced by quantize_batch_q8_k().

◆ vec_dot_q8_0_q8_0()

void vec_dot_q8_0_q8_0 ( int  n,
float *  s,
const void *  vx,
const void *  vy 
)

Auto-dispatch quantized dot product Q8_0 x Q8_0.

Definition at line 1013 of file gemm_kernels_q8_0.c.

1014 {
1015 #ifdef __AVX512F__
1016  vec_dot_q8_0_q8_0_avx512(n, s, vx, vy);
1017 #elif defined(__AVX__)
1018  vec_dot_q8_0_q8_0_avx(n, s, vx, vy);
1019 #elif defined(__SSE4_1__)
1020  vec_dot_q8_0_q8_0_sse(n, s, vx, vy);
1021 #else
1022  vec_dot_q8_0_q8_0_ref(n, s, vx, vy);
1023 #endif
1024 }

References vec_dot_q8_0_q8_0_ref().

Referenced by ck_test_vec_dot_q8_0_q8_0(), gemv_q8_0_from_fp32(), gemv_q8_0_q8_0(), gemv_q8_0_q8_0_parallel(), gemv_q8_0_q8_0_parallel_omp(), gemv_q8_0_q8_0_parallel_simd(), and out_proj_head_major_q8_0_q8_0().

◆ vec_dot_q8_0_q8_0_ref()

void vec_dot_q8_0_q8_0_ref ( int  n,
float *  s,
const void *  vx,
const void *  vy 
)

Quantized dot product: Q8_0 weights x Q8_0 input (scalar reference)

Parameters
n   Number of elements (must be multiple of 32)
s   Output: scalar dot product result
vx  Q8_0 quantized weights
vy  Q8_0 quantized input

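In block form, with d_x^{(b)} and d_y^{(b)} the per-block scales, the reference computes

  s = \sum_{b=0}^{n/32 - 1} d_x^{(b)} d_y^{(b)} \sum_{j=0}^{31} q_x^{(b)}[j] \, q_y^{(b)}[j]

The inner sum is accumulated in integer arithmetic, so each 32-element block costs a single float multiply per scale pair.
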
Definition at line 863 of file gemm_kernels_q8_0.c.

864 {
865  const int qk = QK8_0; /* 32 */
866  const int nb = n / qk;
867 
868  const block_q8_0 *x = (const block_q8_0 *)vx;
869  const block_q8_0 *y = (const block_q8_0 *)vy;
870 
871  float sumf = 0.0f;
872 
873  for (int ib = 0; ib < nb; ib++) {
874  int sumi = 0;
875 
876  for (int j = 0; j < qk; j++) {
877  sumi += x[ib].qs[j] * y[ib].qs[j];
878  }
879 
880  sumf += sumi * (CK_FP16_TO_FP32(x[ib].d) * CK_FP16_TO_FP32(y[ib].d));
881  }
882 
883  *s = sumf;
884 }

References CK_FP16_TO_FP32, QK8_0, and block_q8_0::qs.

Referenced by vec_dot_q8_0_q8_0().