gemm_kernels_q4k_q8k.c File Reference

Q4_K (weights) x Q8_K (activations) kernels for inference. More...

#include <assert.h>
#include <math.h>
#include <string.h>
#include "ckernel_quant.h"

Go to the source code of this file.

Functions

static int ck_nearest_int (float fval)
 
static float dot_q4_k_q8_k_ref (const block_q4_K *w, const block_q8_K *x, int k)
 
void gemm_nt_q4_k_q8_k (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_q4_k_q8_k (float *Y, const void *W, const void *X_q8, int M, int N, int K)
 
void gemm_q4_k_q8_k_ref (float *Y, const void *W, const void *X_q8, int M, int N, int K)
 
void gemv_q4_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K)
 
void gemv_q4_k_q8_k_avx (float *y, const void *W, const void *x_q8, int M, int K)
 
void gemv_q4_k_q8_k_avx2 (float *y, const void *W, const void *x_q8, int M, int K)
 
void gemv_q4_k_q8_k_parallel (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
 
void gemv_q4_k_q8_k_ref (float *y, const void *W, const void *x_q8, int M, int K)
 
void gemv_q4_k_q8_k_sse (float *y, const void *W, const void *x_q8, int M, int K)
 
void gemv_q4_k_q8_k_vnni (float *y, const void *W, const void *x_q8, int M, int K)
 
void quantize_row_q8_k (const float *x, void *vy, int k)
 
void quantize_row_q8_k_ref (const float *x, void *vy, int k)
 
void quantize_row_q8_k_sse (const float *x, void *vy, int k)
 

Detailed Description

Q4_K (weights) x Q8_K (activations) kernels for inference.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Implements decode-style matvec/matmul kernels where the weights are Q4_K and the activations are quantized on the fly to Q8_K. This is inference-only; no backward pass is provided here.

Definition in file gemm_kernels_q4k_q8k.c.
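
A typical decode step quantizes the activation row once and reuses it against every weight matrix in the layer. Minimal usage sketch (illustrative only: the wrapper and buffer names are assumptions, and the workspace sizing presumes block_q8_K is visible from ckernel_quant.h):

  #include "ckernel_quant.h"  /* assumed to declare the kernel prototypes */

  /* y = W · x with W in Q4_K; x is quantized on the fly.
   * Requires K % QK_K == 0; x_q8_ws is caller-provided workspace of
   * (K / QK_K) * sizeof(block_q8_K) bytes (bump-allocated, rule 1). */
  static void decode_matvec_sketch(float *y, const void *W_q4k,
                                   const float *x, void *x_q8_ws,
                                   int M, int K)
  {
      quantize_row_q8_k(x, x_q8_ws, K);        /* fp32 -> Q8_K blocks   */
      gemv_q4_k_q8_k(y, W_q4k, x_q8_ws, M, K); /* ISA-dispatched matvec */
  }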

Function Documentation

◆ ck_nearest_int()

static int ck_nearest_int ( float  fval)
inline, static

Definition at line 45 of file gemm_kernels_q4k_q8k.c.

45  {
46  /* Bit-level round-to-nearest from llama.cpp (fast + deterministic). */
47  float val = fval + 12582912.f;
48  int i;
49  memcpy(&i, &val, sizeof(int));
50  return (i & 0x007fffff) - 0x00400000;
51 }

Referenced by quantize_row_q8_k_ref().
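
Why the constant works: 12582912.0f is 1.5 · 2^23. For |fval| ≤ 2^22 − 1, the sum stays in [2^23, 2^24), where the float ulp is exactly 1, so the FP add itself performs round-to-nearest(-even) and the rounded integer lands in the low 23 mantissa bits, offset by 0x00400000. A self-check sketch (not part of this file; assumes the default round-to-nearest FP mode):

  #include <assert.h>
  #include <math.h>
  #include <string.h>

  static void check_nearest_int(float fval) {
      float val = fval + 12582912.0f;        /* 1.5 * 2^23 */
      int i;
      memcpy(&i, &val, sizeof(int));
      int fast = (i & 0x007fffff) - 0x00400000;
      assert(fast == (int)lrintf(fval));     /* holds for |fval| <= 2^22 - 1 */
  }

  int main(void) {
      check_nearest_int(0.5f);   /* tie: both round to 0  */
      check_nearest_int(-2.5f);  /* tie: both round to -2 */
      check_nearest_int(3.3f);
      return 0;
  }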

◆ dot_q4_k_q8_k_ref()

static float dot_q4_k_q8_k_ref ( const block_q4_K *  w,
const block_q8_K *  x,
int  k 
)
static

Definition at line 115 of file gemm_kernels_q4k_q8k.c.

118 {
119  const int nb = k / QK_K;
120  float sumf = 0.0f;
121 
122  for (int i = 0; i < nb; ++i) {
123  uint8_t sc[8], m_val[8];
124  unpack_q4_k_scales(w[i].scales, sc, m_val);
125 
126  const float d = CK_FP16_TO_FP32(w[i].d) * x[i].d;
127  const float dmin = CK_FP16_TO_FP32(w[i].dmin) * x[i].d;
128 
129  /* Q4_K layout: process 64 elements at a time
130  * - Low nibbles of qs[0..31] → elements 0..31 → uses sc[0], m[0]
131  * - High nibbles of qs[0..31] → elements 32..63 → uses sc[1], m[1]
132  * - Low nibbles of qs[32..63] → elements 64..95 → uses sc[2], m[2]
133  * - etc.
134  */
135  int is = 0;
136  int q_offset = 0;
137 
138  for (int j = 0; j < QK_K; j += 64) {
139  const uint8_t *qs = &w[i].qs[q_offset];
140  const int8_t *q8_lo = &x[i].qs[j]; /* Elements j to j+31 */
141  const int8_t *q8_hi = &x[i].qs[j + 32]; /* Elements j+32 to j+63 */
142 
143  /* Sum for low nibbles (elements j to j+31) */
144  int32_t sum_q4q8_lo = 0;
145  for (int l = 0; l < 32; ++l) {
146  int q4_val = qs[l] & 0x0F;
147  sum_q4q8_lo += q4_val * q8_lo[l];
148  }
149 
150  /* Sum for high nibbles (elements j+32 to j+63) */
151  int32_t sum_q4q8_hi = 0;
152  for (int l = 0; l < 32; ++l) {
153  int q4_val = qs[l] >> 4;
154  sum_q4q8_hi += q4_val * q8_hi[l];
155  }
156 
157  /* bsums: each bsum is 16 elements */
158  int32_t bsum_lo = (int32_t)x[i].bsums[j / 16] +
159  (int32_t)x[i].bsums[j / 16 + 1];
160  int32_t bsum_hi = (int32_t)x[i].bsums[(j + 32) / 16] +
161  (int32_t)x[i].bsums[(j + 32) / 16 + 1];
162 
163  /* Accumulate: d * sc * sum(q4*q8) - dmin * m * sum(q8) */
164  sumf += d * (float)sc[is] * (float)sum_q4q8_lo;
165  sumf -= dmin * (float)m_val[is] * (float)bsum_lo;
166  sumf += d * (float)sc[is + 1] * (float)sum_q4q8_hi;
167  sumf -= dmin * (float)m_val[is + 1] * (float)bsum_hi;
168 
169  q_offset += 32;
170  is += 2;
171  }
172  }
173 
174  return sumf;
175 }

References block_q8_K::bsums, CK_FP16_TO_FP32, block_q8_K::d, QK_K, block_q4_K::qs, block_q8_K::qs, and unpack_q4_k_scales().

Referenced by gemv_q4_k_q8_k_parallel(), and gemv_q4_k_q8_k_ref().
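
The two-term accumulation is the Q4_K dequantization formula factored out. Within sub-block s a weight dequantizes as w = d4·sc[s]·q4 − dmin4·m[s], and a Q8_K activation as x = d8·q8, so the 32-element partial dot product expands to

    sum(w · x) = (d4 · d8) · sc[s] · sum(q4 · q8) − (dmin4 · d8) · m[s] · sum(q8)

which matches the d·sc[is]·sum_q4q8 − dmin·m_val[is]·bsum terms above. The precomputed bsums supply sum(q8) per 16-element group, so the minimum-offset term needs no second pass over the activations.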

◆ gemm_nt_q4_k_q8_k()

void gemm_nt_q4_k_q8_k ( const void *  A_q8,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 295 of file gemm_kernels_q4k_q8k.c.

300 {
301  if (!A_q8 || !B || !C) {
302  return;
303  }
304  if (M <= 0 || N <= 0 || K <= 0) {
305  return;
306  }
307 
308  gemm_q4_k_q8_k(C, B, A_q8, /*M_out=*/N, /*N_batch=*/M, K);
309 
310  if (!bias) {
311  return;
312  }
313 
314  for (int i = 0; i < M; ++i) {
315  float *row = C + (size_t)i * (size_t)N;
316  for (int j = 0; j < N; ++j) {
317  row[j] += bias[j];
318  }
319  }
320 }

References gemm_q4_k_q8_k().

Referenced by ck_attention_project_head_major_q4_k_q8_k(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_qkv_project_head_major_token_q4_k_q8_k(), ck_test_gemm_q4_k(), gemm_nt_q8_k_mlp_dispatch(), gemm_nt_q8_k_qkv_dispatch(), model_forward_prefill_impl(), and qwen2_0_5b_decode_forward_prefill_impl().
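
In index form (C row-major, M×N): C[i][j] = sum_k A[i][k]·B[j][k] + bias[j], i.e. C = A·Bᵀ, where A_q8 holds M rows of Q8_K activations and B holds N rows of Q4_K weights, each of length K. The forwarding call is valid because gemm_q4_k_q8_k() writes its output batch-major: batch row i of the swapped problem (N outputs against activation row a_i) lands at C + i*N, which is exactly row i of C.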

◆ gemm_q4_k_q8_k()

void gemm_q4_k_q8_k ( float *  Y,
const void *  W,
const void *  X_q8,
int  M,
int  N,
int  K 
)

Definition at line 277 of file gemm_kernels_q4k_q8k.c.

281 {
282  if (!Y || !W || !X_q8 || M <= 0 || N <= 0 || K <= 0) {
283  return;
284  }
285 
286  const block_q8_K *X = (const block_q8_K *)X_q8;
287  const int blocks_per_vec = K / QK_K;
288 
289  for (int n = 0; n < N; ++n) {
290  const block_q8_K *x_row = X + (size_t)n * (size_t)blocks_per_vec;
291  gemv_q4_k_q8_k(&Y[n * M], W, x_row, M, K);
292  }
293 }

References gemv_q4_k_q8_k(), and QK_K.

Referenced by gemm_nt_q4_k_q8_k().
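
Output layout note: Y is batch-major, one M-length result vector per activation column, so column n occupies Y[n*M .. n*M + M − 1]. For example, with M = 3 and N = 2, Y holds [W·x_0 (3 floats), then W·x_1 (3 floats)]. Callers wanting a row-major M×N matrix should go through gemm_nt_q4_k_q8_k(), which maps this layout onto C.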

◆ gemm_q4_k_q8_k_ref()

void gemm_q4_k_q8_k_ref ( float *  Y,
const void *  W,
const void *  X_q8,
int  M,
int  N,
int  K 
)

Definition at line 259 of file gemm_kernels_q4k_q8k.c.

263 {
264  if (!Y || !W || !X_q8 || M <= 0 || N <= 0 || K <= 0) {
265  return;
266  }
267 
268  const block_q8_K *X = (const block_q8_K *)X_q8;
269  const int blocks_per_vec = K / QK_K;
270 
271  for (int n = 0; n < N; ++n) {
272  const block_q8_K *x_row = X + (size_t)n * (size_t)blocks_per_vec;
273  gemv_q4_k_q8_k_ref(&Y[n * M], W, x_row, M, K);
274  }
275 }

References gemv_q4_k_q8_k_ref(), and QK_K.

◆ gemv_q4_k_q8_k()

void gemv_q4_k_q8_k ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Definition at line 239 of file gemm_kernels_q4k_q8k.c.

243 {
244 #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
245  /* VNNI: Best for decode (single token) - INT8 dot product acceleration */
246  gemv_q4_k_q8_k_vnni(y, W, x_q8, M, K);
247 #elif defined(__AVX2__)
248  gemv_q4_k_q8_k_avx2(y, W, x_q8, M, K);
249 #elif defined(__AVX__)
250  /* AVX version uses maddubs_epi16 (more efficient than SSE) */
251  gemv_q4_k_q8_k_avx(y, W, x_q8, M, K);
252 #elif defined(__SSE4_1__)
253  gemv_q4_k_q8_k_sse(y, W, x_q8, M, K);
254 #else
255  gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
256 #endif
257 }

References gemv_q4_k_q8_k_avx(), gemv_q4_k_q8_k_avx2(), gemv_q4_k_q8_k_ref(), gemv_q4_k_q8_k_sse(), and gemv_q4_k_q8_k_vnni().

Referenced by ck_test_gemv_q4_k(), fused_rmsnorm_linear_q4k(), gemm_q4_k_q8_k(), model_decode_token(), model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_9_decode(), and unfused_rmsnorm_linear_q4k_ref().
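
Dispatch is resolved at compile time from predefined ISA macros, so each build binds exactly one kernel with no runtime branching. Illustrative compiler invocations (not taken from this repo's build files; the -m flags and macros are standard GCC/Clang):

  cc -O2 -msse4.1 ...                 /* defines __SSE4_1__                    -> _sse path  */
  cc -O2 -mavx2 ...                   /* defines __AVX2__ (and __AVX__, etc.)  -> _avx2 path */
  cc -O2 -mavx512vnni -mavx512vl ...  /* defines __AVX512VNNI__, __AVX512VL__  -> _vnni path */

The #elif chain orders these best-first, so a flag that implies lesser ISAs still selects the strongest available kernel.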

◆ gemv_q4_k_q8_k_avx()

void gemv_q4_k_q8_k_avx ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Definition at line 251 of file gemm_kernels_q4k_avx.c.

255 {
256  gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
257 }

Referenced by gemv_q4_k_q8_k().

◆ gemv_q4_k_q8_k_avx2()

void gemv_q4_k_q8_k_avx2 ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Definition at line 89 of file gemm_kernels_q4k_q8k_avx2.c.

93 {
94  /* TODO: Implement AVX2 version with correct Q4_K memory layout.
95  * For now, fall back to reference implementation which has been
96  * fixed to use the correct layout.
97  */
98  gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
99 }

Referenced by gemv_q4_k_q8_k().

◆ gemv_q4_k_q8_k_parallel()

void gemv_q4_k_q8_k_parallel ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K,
int  ith,
int  nth 
)

Definition at line 206 of file gemm_kernels_q4k_q8k.c.

211 {
212  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
213  return;
214  }
215  if (ith < 0 || nth <= 0 || ith >= nth) {
216  return;
217  }
218 
219  /* Compute row range for this thread */
220  const int dr = (M + nth - 1) / nth;
221  const int r0 = dr * ith;
222  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
223 
224  if (r0 >= M) {
225  return; /* This thread has no work */
226  }
227 
228  const block_q4_K *blocks = (const block_q4_K *)W;
229  const block_q8_K *x = (const block_q8_K *)x_q8;
230  const int blocks_per_row = K / QK_K;
231 
232  /* Only process rows [r0, r1) */
233  for (int row = r0; row < r1; ++row) {
234  const block_q4_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
235  y[row] = dot_q4_k_q8_k_ref(w_row, x, K);
236  }
237 }

References dot_q4_k_q8_k_ref(), and QK_K.

Referenced by gemv_q4_k_q8_k_parallel_simd().
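
Per kernel rule 2, threading lives in the orchestrator, not here. A hypothetical call-pattern sketch (the task struct and worker hookup are illustrative, not part of this API):

  /* Each of nth workers invokes the kernel with its own ith. Rows are
   * partitioned into contiguous chunks of ceil(M/nth), so writes to y
   * never overlap and the kernel needs no internal synchronization. */
  typedef struct {
      float *y; const void *W; const void *x_q8;
      int M, K, ith, nth;
  } ck_gemv_task;                        /* hypothetical name */

  static void ck_gemv_worker(void *arg) {
      ck_gemv_task *t = (ck_gemv_task *)arg;
      gemv_q4_k_q8_k_parallel(t->y, t->W, t->x_q8, t->M, t->K, t->ith, t->nth);
  }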

◆ gemv_q4_k_q8_k_ref()

void gemv_q4_k_q8_k_ref ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Definition at line 177 of file gemm_kernels_q4k_q8k.c.

181 {
182  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
183  return;
184  }
185 
186  const block_q4_K *blocks = (const block_q4_K *)W;
187  const block_q8_K *x = (const block_q8_K *)x_q8;
188  const int blocks_per_row = K / QK_K;
189 
190  for (int row = 0; row < M; ++row) {
191  const block_q4_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
192  y[row] = dot_q4_k_q8_k_ref(w_row, x, K);
193  }
194 }

References dot_q4_k_q8_k_ref(), and QK_K.

Referenced by gemm_q4_k_q8_k_ref(), gemv_q4_k_q8_k(), gemv_q4_k_q8_k_amx(), gemv_q4_k_q8_k_avx(), gemv_q4_k_q8_k_avx2(), and gemv_q4_k_q8_k_vnni().

◆ gemv_q4_k_q8_k_sse()

void gemv_q4_k_q8_k_sse ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Definition at line 33 of file gemm_kernels_q4k_sse.c.

37 {
38  const block_q4_K *blocks = (const block_q4_K *)W;
39  const block_q8_K *x = (const block_q8_K *)x_q8;
40  const int blocks_per_row = K / QK_K;
41 
42  const __m128i mask_low = _mm_set1_epi8(0x0F);
43 
44  for (int row = 0; row < M; ++row) {
45  float sumf = 0.0f;
46  const block_q4_K *w_row = blocks + row * blocks_per_row;
47 
48  for (int i = 0; i < blocks_per_row; ++i) {
49  const block_q4_K *b4 = &w_row[i];
50  const block_q8_K *b8 = &x[i];
51 
52  // Unpack scales (same as ref)
53  uint8_t sc[8], m_val[8];
54  unpack_q4_k_scales(b4->scales, sc, m_val);
55 
56  float d = CK_FP16_TO_FP32(b4->d) * b8->d;
57  float dmin = CK_FP16_TO_FP32(b4->dmin) * b8->d;
58 
59  int is = 0;
60  int q_offset = 0;
61 
62  // Process 4 chunks of 64 elements (256 total)
63  for (int j = 0; j < QK_K; j += 64) {
64  // We process 32 bytes of qs (covering 64 elements via low/high nibbles)
65  // We access qs[0..31] relative to q_offset
66 
67  // Accumulators for this 64-element chunk
68  __m128i acc_lo = _mm_setzero_si128();
69  __m128i acc_hi = _mm_setzero_si128();
70 
71  // Inner loop: 2 iters of 16 bytes (32 elements)
72  for (int l = 0; l < 32; l += 16) {
73  // Load 16 bytes of Q4
74  __m128i q4_vec = _mm_loadu_si128((const __m128i *)(b4->qs + q_offset + l));
75 
76  // Low nibbles -> correspond to q8_lo (elements j+l .. j+l+15)
77  __m128i q4_lo = _mm_and_si128(q4_vec, mask_low);
78 
79  // High nibbles -> correspond to q8_hi (elements j+32+l .. j+32+l+15)
80  __m128i q4_hi = _mm_and_si128(_mm_srli_epi16(q4_vec, 4), mask_low);
81 
82  // Load Q8
83  __m128i q8_lo_vec = _mm_loadu_si128((const __m128i *)(b8->qs + j + l));
84  __m128i q8_hi_vec = _mm_loadu_si128((const __m128i *)(b8->qs + j + 32 + l));
85 
86  // Expand and Multiply-Add: Q4(u8) * Q8(s8) -> i32
87  // Since Q4 is u8 and Q8 is s8, we use intermediate i16
88 
89  // LO PART
90  __m128i q4_lo_16_L = _mm_cvtepu8_epi16(q4_lo); // lower 8 -> 16
91  __m128i q8_lo_16_L = _mm_cvtepi8_epi16(q8_lo_vec);
92  __m128i prod_lo_L = _mm_madd_epi16(q4_lo_16_L, q8_lo_16_L); // i32
93  acc_lo = _mm_add_epi32(acc_lo, prod_lo_L);
94 
95  __m128i q4_lo_16_H = _mm_cvtepu8_epi16(_mm_srli_si128(q4_lo, 8)); // upper 8 -> 16
96  __m128i q8_lo_16_H = _mm_cvtepi8_epi16(_mm_srli_si128(q8_lo_vec, 8));
97  __m128i prod_lo_H = _mm_madd_epi16(q4_lo_16_H, q8_lo_16_H); // i32
98  acc_lo = _mm_add_epi32(acc_lo, prod_lo_H);
99 
100  // HI PART
101  __m128i q4_hi_16_L = _mm_cvtepu8_epi16(q4_hi);
102  __m128i q8_hi_16_L = _mm_cvtepi8_epi16(q8_hi_vec);
103  __m128i prod_hi_L = _mm_madd_epi16(q4_hi_16_L, q8_hi_16_L);
104  acc_hi = _mm_add_epi32(acc_hi, prod_hi_L);
105 
106  __m128i q4_hi_16_H = _mm_cvtepu8_epi16(_mm_srli_si128(q4_hi, 8));
107  __m128i q8_hi_16_H = _mm_cvtepi8_epi16(_mm_srli_si128(q8_hi_vec, 8));
108  __m128i prod_hi_H = _mm_madd_epi16(q4_hi_16_H, q8_hi_16_H);
109  acc_hi = _mm_add_epi32(acc_hi, prod_hi_H);
110  }
111 
112  int32_t sum_q4q8_lo = hsum_epi32_sse(acc_lo);
113  int32_t sum_q4q8_hi = hsum_epi32_sse(acc_hi);
114 
115  /* bsums: each bsum is 16 elements */
116  int32_t bsum_lo = (int32_t)b8->bsums[j / 16] +
117  (int32_t)b8->bsums[j / 16 + 1];
118  int32_t bsum_hi = (int32_t)b8->bsums[(j + 32) / 16] +
119  (int32_t)b8->bsums[(j + 32) / 16 + 1];
120 
121  sumf += d * (float)sc[is] * (float)sum_q4q8_lo;
122  sumf -= dmin * (float)m_val[is] * (float)bsum_lo;
123  sumf += d * (float)sc[is + 1] * (float)sum_q4q8_hi;
124  sumf -= dmin * (float)m_val[is + 1] * (float)bsum_hi;
125 
126  q_offset += 32;
127  is += 2;
128  }
129  }
130  y[row] = sumf;
131  }
132 }

Referenced by gemv_q4_k_q8_k().

◆ gemv_q4_k_q8_k_vnni()

void gemv_q4_k_q8_k_vnni ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Definition at line 95 of file gemm_kernels_q4k_q8k_vnni.c.

99 {
100  /* TODO: Implement VNNI version with correct Q4_K memory layout.
101  * For now, fall back to reference implementation which has been
102  * fixed to use the correct layout.
103  */
104  gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
105 }

Referenced by gemv_q4_k_q8_k().

◆ quantize_row_q8_k()

void quantize_row_q8_k ( const float *  x,
void *  vy,
int  k 
)

Definition at line 107 of file gemm_kernels_q4k_q8k.c.

107  {
108 #if defined(__SSE4_1__)
109  quantize_row_q8_k_sse(x, vy, k);
110 #else
111  quantize_row_q8_k_ref(x, vy, k);
112 #endif
113 }

References quantize_row_q8_k_ref(), and quantize_row_q8_k_sse().

Referenced by ck_attention_project_head_major_q4_k_q8_k(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_qkv_project_head_major_q4_k_q8_k(), ck_test_gemm_q4_k(), ck_test_gemm_q6_k(), ck_test_gemv_q4_k(), ck_test_quantize_q8_k(), decode_layer_parallel(), fused_mlp_swiglu_prefill_w1w2_quant(), fused_rmsnorm_qkv_prefill_head_major_quant(), gemm_nt_q5_0_sse_v2(), gemm_nt_q6_k_sse(), mlp_parallel(), model_decode_token(), model_forward_prefill_impl(), model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), quantize_batch_q8_k(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_forward_prefill_impl(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_9_decode(), and unfused_rmsnorm_linear_q4k_ref().

◆ quantize_row_q8_k_ref()

void quantize_row_q8_k_ref ( const float *  x,
void *  vy,
int  k 
)

Definition at line 53 of file gemm_kernels_q4k_q8k.c.

53  {
54  if (!x || !vy || k <= 0) {
55  return;
56  }
57  assert(k % QK_K == 0);
58  const int nb = k / QK_K;
59  block_q8_K *y = (block_q8_K *)vy;
60 
61  for (int i = 0; i < nb; ++i) {
62  float max = 0.0f;
63  float amax = 0.0f;
64  for (int j = 0; j < QK_K; ++j) {
65  float ax = fabsf(x[j]);
66  if (ax > amax) {
67  amax = ax;
68  max = x[j];
69  }
70  }
71  if (!amax) {
72  y[i].d = 0.0f;
73  memset(y[i].qs, 0, sizeof(y[i].qs));
74  memset(y[i].bsums, 0, sizeof(y[i].bsums));
75  x += QK_K;
76  continue;
77  }
78 
79  const float iscale = -127.0f / max;
80  for (int j = 0; j < QK_K; ++j) {
81  int v = ck_nearest_int(iscale * x[j]);
82  if (v > 127) {
83  v = 127;
84  }
85  if (v < -128) {
86  v = -128;
87  }
88  y[i].qs[j] = (int8_t)v;
89  }
90 
91  for (int j = 0; j < QK_K / 16; ++j) {
92  int sum = 0;
93  const int8_t *qs = &y[i].qs[j * 16];
94  for (int ii = 0; ii < 16; ++ii) {
95  sum += qs[ii];
96  }
97  y[i].bsums[j] = (int16_t)sum;
98  }
99 
100  y[i].d = 1.0f / iscale;
101  x += QK_K;
102  }
103 }

References block_q8_K::bsums, ck_nearest_int(), block_q8_K::d, QK_K, and block_q8_K::qs.

Referenced by quantize_row_q8_k().
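
The scale bookkeeping mirrors llama.cpp: iscale = −127/max maps the signed max-magnitude element to exactly −127, and the stored d = 1/iscale = −max/127 inverts it. Worked example: max = 2.0 gives iscale = −63.5; that element quantizes to ck_nearest_int(−63.5 · 2.0) = −127 and dequantizes as d·q = (−2.0/127)·(−127) = 2.0. Note that d carries the sign of −max, so it may be negative; only the product d·q is meaningful.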

◆ quantize_row_q8_k_sse()

void quantize_row_q8_k_sse ( const float *  x,
void *  vy,
int  k 
)

Definition at line 29 of file quantize_row_q8_k_sse.c.

29  {
30  if (!x || !vy || k <= 0) {
31  return;
32  }
33  assert(k % QK_K == 0);
34  const int nb = k / QK_K;
35  block_q8_K *y = (block_q8_K *)vy;
36 
37  for (int i = 0; i < nb; ++i) {
38  float max = 0.0f;
39 
40  // SSE max absolute value
41  __m128 v_max = _mm_setzero_ps();
42  for (int j = 0; j < QK_K; j += 4) {
43  __m128 v = _mm_loadu_ps(x + j);
44  __m128 v_abs = _mm_andnot_ps(_mm_set1_ps(-0.0f), v);
45  v_max = _mm_max_ps(v_max, v_abs);
46  }
47 
48  // Horizontal max
49  v_max = _mm_max_ps(v_max, _mm_shuffle_ps(v_max, v_max, _MM_SHUFFLE(1, 0, 3, 2)));
50  v_max = _mm_max_ps(v_max, _mm_shuffle_ps(v_max, v_max, _MM_SHUFFLE(0, 1, 0, 1)));
51  _mm_store_ss(&max, v_max);
52 
53  if (max == 0.0f) {
54  y[i].d = 0.0f;
55  memset(y[i].qs, 0, sizeof(y[i].qs));
56  memset(y[i].bsums, 0, sizeof(y[i].bsums));
57  x += QK_K;
58  continue;
59  }
60 
61  const float iscale = -127.0f / max;
62  __m128 v_iscale = _mm_set1_ps(iscale);
63 
64  // Quantize and compute bsums in SSE
65  for (int j = 0; j < QK_K; j += 16) {
66  __m128 x0 = _mm_loadu_ps(x + j + 0);
67  __m128 x1 = _mm_loadu_ps(x + j + 4);
68  __m128 x2 = _mm_loadu_ps(x + j + 8);
69  __m128 x3 = _mm_loadu_ps(x + j + 12);
70 
71  __m128i q0 = _mm_cvtps_epi32(_mm_mul_ps(x0, v_iscale));
72  __m128i q1 = _mm_cvtps_epi32(_mm_mul_ps(x1, v_iscale));
73  __m128i q2 = _mm_cvtps_epi32(_mm_mul_ps(x2, v_iscale));
74  __m128i q3 = _mm_cvtps_epi32(_mm_mul_ps(x3, v_iscale));
75 
76  // Pack i32 -> i16 -> i8
77  __m128i q01 = _mm_packs_epi32(q0, q1);
78  __m128i q23 = _mm_packs_epi32(q2, q3);
79  __m128i q0123 = _mm_packs_epi16(q01, q23);
80 
81  _mm_storeu_si128((__m128i *)(y[i].qs + j), q0123);
82 
83  // Compute bsum for these 16 elements
84  // Each bsum[j/16] covers 16 elements
85  __m128i p01 = _mm_add_epi16(q01, q23);
86  p01 = _mm_add_epi16(p01, _mm_shuffle_epi32(p01, _MM_SHUFFLE(1, 0, 3, 2)));
87  p01 = _mm_add_epi16(p01, _mm_shufflelo_epi16(p01, _MM_SHUFFLE(1, 0, 3, 2)));
88  int16_t bsum = (int16_t)_mm_extract_epi16(p01, 0) + (int16_t)_mm_extract_epi16(p01, 1);
89  y[i].bsums[j / 16] = bsum;
90  }
91 
92  y[i].d = 1.0f / iscale;
93  x += QK_K;
94  }
95 }

Referenced by quantize_row_q8_k().
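
Reduction detail for the bsum block above: after p01 = q01 + q23, i16 lane t holds e[t] + e[t+8]. The _mm_shuffle_epi32 add folds lanes 4..7 onto lanes 0..3, the _mm_shufflelo_epi16 add folds lanes 2,3 onto lanes 0,1, and summing the two extracted lanes completes the 16-element total, matching the scalar bsums computed by quantize_row_q8_k_ref().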