← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ck_parity_api.c
Go to the documentation of this file.
1 /**
2  * @file ck_parity_api.c
3  * @brief C-Kernel-Engine Parity Testing API Implementation
4  *
5  * Wraps CK kernels for parity testing against llama.cpp/ggml.
6  */
7 
8 #include "ck_parity_api.h"
9 #include "ckernel_quant.h"
10 #include <math.h>
11 #include <stdlib.h>
12 #include <string.h>
13 
14 /* External kernel function declarations */
15 
16 /* Dequantization kernels (from dequant_kernels.c) */
17 extern void dequant_q4_k_row(const void *src, float *dst, size_t n_elements);
18 extern void dequant_q6_k_row(const void *src, float *dst, size_t n_elements);
19 extern void dequant_q4_0_row(const void *src, float *dst, size_t n_elements);
20 extern void dequant_q5_1_row(const void *src, float *dst, size_t n_elements);
21 
22 /* Quantization kernels (from gemm_kernels_q4k_q8k.c) */
23 extern void quantize_row_q8_k(const float *x, void *vy, int k);
24 
25 /* GEMV/GEMM kernels */
26 extern void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K);
27 extern void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias,
28  float *C, int M, int N, int K);
29 
30 /* Q6_K x Q8_K kernels (from gemm_kernels_q6k_q8k.c) */
31 extern void gemv_q6_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K);
32 extern void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias,
33  float *C, int M, int N, int K);
34 
35 /* Q8_0 x Q8_0 batch GEMM (from gemm_batch_int8.c) */
36 extern void gemm_nt_q8_0_q8_0(const void *A_q8, const void *B_q8, const float *bias,
37  float *C, int M, int N, int K);
38 
39 /* Q5_0 x Q8_0 batch GEMM (from gemm_kernels_q5_0.c) */
40 extern void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias,
41  float *C, int M, int N, int K);
42 
43 /* Q5_K kernels (from gemm_kernels_q5_k.c) */
44 extern void gemv_q5_k(float *y, const void *W, const float *x, int M, int K);
45 extern void gemm_nt_q5_k(const float *A, const void *B, const float *bias,
46  float *C, int M, int N, int K);
47 
48 /* Q5_1 kernels (from gemm_kernels_q5_1.c) */
49 extern void gemv_q5_1(float *y, const void *W, const float *x, int M, int K);
50 extern void gemm_nt_q5_1(const float *A, const void *B, const float *bias,
51  float *C, int M, int N, int K);
52 
53 /* Q5_0 and Q8_0 GEMV kernels (from gemm_kernels_q5_0.c, gemm_kernels_q8_0.c) */
54 extern void gemv_q5_0(float *y, const void *W, const float *x, int M, int K);
55 extern void gemv_q8_0(float *y, const void *W, const float *x, int M, int K);
56 
57 /* Quantized dot product kernels for parity with llama.cpp */
58 extern void gemv_q5_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K);
59 extern void gemv_q8_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K);
60 
61 /* Direct vec_dot kernels (single dot product, not GEMV) */
62 extern void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy);
63 extern void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy);
64 
65 /* Q8_0 quantization (for input) */
66 extern void quantize_row_q8_0(const float *x, void *vy, int k);
67 
68 /* RMSNorm kernel (from rmsnorm_kernels.c) */
69 extern void rmsnorm_forward(const float *input, const float *gamma,
70  float *output, float *rstd_cache,
71  int tokens, int d_model, int aligned_embed_dim, float eps);
72 
73 /* RoPE kernels (from rope_kernels.c) */
74 extern void rope_forward_qk(float *q, float *k,
75  const float *cos_cache, const float *sin_cache,
76  int num_heads, int num_kv_heads, int num_tokens,
77  int head_dim, int aligned_head_dim, int pos_offset);
78 extern void rope_precompute_cache(float *cos_cache, float *sin_cache,
79  int max_seq_len, int head_dim, float base);
80 
81 /* SwiGLU kernel (from swiglu_kernels.c) */
82 extern void swiglu_forward(const float *input, float *output, int tokens, int dim);
83 
84 /* Attention kernels (from attention_kernels.c / attention_flash_true.c) */
86  const float *q, const float *k, const float *v, float *output,
87  int num_heads, int num_kv_heads, int num_tokens,
88  int head_dim, int aligned_head_dim, int kv_stride_tokens);
89 
90 /* Sliding-window attention kernels (from attention_kernels.c) */
92  const float *q, const float *k, const float *v, float *output,
93  int num_heads, int num_kv_heads, int num_tokens,
94  int head_dim, int aligned_head_dim, int kv_stride_tokens,
95  int sliding_window);
96 
98  const float *q_token, const float *k_cache, const float *v_cache,
99  float *out_token, int num_heads, int num_kv_heads,
100  int kv_tokens, int cache_capacity, int head_dim,
101  int aligned_head_dim, int sliding_window);
102 
103 /* GeGLU kernels (from gelu_kernels.c) */
104 extern void geglu_forward_fp32(const float *x, float *out, int tokens, int dim);
105 extern void geglu_backward_fp32(const float *x, const float *d_out, float *d_x,
106  int n_tokens, int dim);
107 
108 /* ============================================================================
109  * Dequantization Tests
110  * ============================================================================ */
111 
112 void ck_test_dequant_q4_k(const void *src, float *dst, int n)
113 {
114  dequant_q4_k_row(src, dst, (size_t)n);
115 }
116 
117 void ck_test_dequant_q6_k(const void *src, float *dst, int n)
118 {
119  dequant_q6_k_row(src, dst, (size_t)n);
120 }
121 
122 void ck_test_dequant_q4_0(const void *src, float *dst, int n)
123 {
124  dequant_q4_0_row(src, dst, (size_t)n);
125 }
126 
127 void ck_test_dequant_q5_1(const void *src, float *dst, int n)
128 {
129  dequant_q5_1_row(src, dst, (size_t)n);
130 }
131 
132 /* ============================================================================
133  * Quantization Tests
134  * ============================================================================ */
135 
void ck_test_quantize_q8_k(const float *src, void *dst, int n)
{
    /* Direct pass-through to the Q8_K activation quantizer. */
    quantize_row_q8_k(src, dst, n);
}
140 
141 /* ============================================================================
142  * GEMV Tests
143  * ============================================================================ */
144 
145 void ck_test_gemv_q4_k(const void *weight_q4k,
146  const float *input_f32,
147  float *output,
148  int cols)
149 {
150  /* Allocate Q8_K buffer for quantized activations */
151  int n_blocks = cols / CK_QK_K;
152  block_q8_K *q8_data = (block_q8_K *)malloc(n_blocks * sizeof(block_q8_K));
153  if (!q8_data) {
154  *output = 0.0f;
155  return;
156  }
157 
158  /* Quantize input to Q8_K */
159  quantize_row_q8_k(input_f32, q8_data, cols);
160 
161  /* Compute dot product using GEMV with M=1 */
162  gemv_q4_k_q8_k(output, weight_q4k, q8_data, 1, cols);
163 
164  free(q8_data);
165 }
166 
void ck_test_gemv_q6_k(const void *weight_q6k,
                       const float *input_f32,
                       float *output,
                       int cols)
{
    /* No native Q6_K GEMV kernel exists in CK yet, so this is a
     * reference path: dequantize the weight row to FP32 and take the
     * dot product, accumulating in double for precision. */
    float *row = malloc((size_t)cols * sizeof *row);
    if (row == NULL) {
        *output = 0.0f;
        return;
    }

    dequant_q6_k_row(weight_q6k, row, (size_t)cols);

    double acc = 0.0;
    for (int i = 0; i < cols; i++) {
        acc += (double)row[i] * (double)input_f32[i];
    }
    *output = (float)acc;

    free(row);
}
191 
/* Declared in ck_parity_api.h; re-declared here so this delegation is
 * self-contained for readers of this function alone. */
void ck_test_gemv_q5_0_q8_0(const void *weight_q5_0, const float *input_f32,
                            float *output, int rows, int cols);

void ck_test_gemv_q5_0(const void *weight_q5_0,
                       const float *input_f32,
                       float *output,
                       int rows, int cols)
{
    /* Match llama.cpp's test_gemv_q5_0: quantize the input to Q8_0 and
     * use the quantized dot product (NOT the FP32 dequantization path).
     *
     * The body was a byte-for-byte duplicate of ck_test_gemv_q5_0_q8_0,
     * so delegate to it instead of maintaining two copies. */
    ck_test_gemv_q5_0_q8_0(weight_q5_0, input_f32, output, rows, cols);
}
219 
/* Declared in ck_parity_api.h; re-declared here so this delegation is
 * self-contained for readers of this function alone. */
void ck_test_gemv_q8_0_q8_0(const void *weight_q8_0, const float *input_f32,
                            float *output, int rows, int cols);

void ck_test_gemv_q8_0(const void *weight_q8_0,
                       const float *input_f32,
                       float *output,
                       int rows, int cols)
{
    /* Match llama.cpp's test_gemv_q8_0: quantize the input to Q8_0 and
     * use the quantized dot product (NOT the FP32 dequantization path).
     *
     * The body was a byte-for-byte duplicate of ck_test_gemv_q8_0_q8_0,
     * so delegate to it instead of maintaining two copies. */
    ck_test_gemv_q8_0_q8_0(weight_q8_0, input_f32, output, rows, cols);
}
247 
248 void ck_test_gemv_q5_0_q8_0(const void *weight_q5_0,
249  const float *input_f32,
250  float *output,
251  int rows, int cols)
252 {
253  /* This matches llama.cpp's approach:
254  * 1. Quantize input to Q8_0 format
255  * 2. Use quantized dot product (integer math)
256  * 3. Scale at the end
257  */
258  int n_blocks = cols / CK_QK8_0;
259  block_q8_0 *q8_data = (block_q8_0 *)malloc(n_blocks * sizeof(block_q8_0));
260  if (!q8_data) {
261  for (int r = 0; r < rows; r++) output[r] = 0.0f;
262  return;
263  }
264 
265  /* Quantize input to Q8_0 */
266  quantize_row_q8_0(input_f32, q8_data, cols);
267 
268  /* Call the quantized GEMV kernel */
269  gemv_q5_0_q8_0(output, weight_q5_0, q8_data, rows, cols);
270 
271  free(q8_data);
272 }
273 
274 void ck_test_gemv_q8_0_q8_0(const void *weight_q8_0,
275  const float *input_f32,
276  float *output,
277  int rows, int cols)
278 {
279  /* This matches llama.cpp's approach:
280  * 1. Quantize input to Q8_0 format
281  * 2. Use quantized dot product (integer math)
282  * 3. Scale at the end
283  */
284  int n_blocks = cols / CK_QK8_0;
285  block_q8_0 *q8_data = (block_q8_0 *)malloc(n_blocks * sizeof(block_q8_0));
286  if (!q8_data) {
287  for (int r = 0; r < rows; r++) output[r] = 0.0f;
288  return;
289  }
290 
291  /* Quantize input to Q8_0 */
292  quantize_row_q8_0(input_f32, q8_data, cols);
293 
294  /* Call the quantized GEMV kernel */
295  gemv_q8_0_q8_0(output, weight_q8_0, q8_data, rows, cols);
296 
297  free(q8_data);
298 }
299 
300 /* Q5_K GEMV test - uses FP32 activations directly */
301 void ck_test_gemv_q5_k(const void *weight_q5_k,
302  const float *input_f32,
303  float *output,
304  int rows, int cols)
305 {
306  /*
307  * IMPORTANT: gemv_q5_k() expects raw FP32 activations, NOT pre-quantized Q8_K.
308  *
309  * This is different from gemv_q4_k_q8_k() and gemv_q5_0_q8_0() which are
310  * "quantized dot product" kernels that take block_q8_K or block_q8_0 input.
311  *
312  * WHY THIS IS ERROR-PRONE:
313  * When copying from ck_test_gemv_q5_0() (which calls gemv_q5_0_q8_0),
314  * it is natural to assume Q5_K also needs pre-quantization. But the
315  * function name tells you: gemv_q5_k() takes float*, while
316  * gemv_q5_0_q8_0() takes block_q8_0*. If the kernel name does not
317  * have "_q8_0" or "_q8_k" suffix, it expects FP32 input.
318  *
319  * PARITY NOTE:
320  * llama.cpp reference uses ggml_vec_dot_q5_K_q8_K which quantizes
321  * the input to Q8_K internally. Our FP32 path will have slightly
322  * different numerical results. Use tolerance ~1e-2 for comparison.
323  * To get exact parity, implement gemv_q5_k_q8_k() (quantized dot product).
324  */
325  for (int r = 0; r < rows; r++) {
326  gemv_q5_k(&output[r],
327  (const char *)weight_q5_k + r * (cols / CK_QK_K) * sizeof(block_q5_K),
328  input_f32, 1, cols);
329  }
330 }
331 
332 /* Q5_1 GEMV test - uses FP32 activations directly */
333 void ck_test_gemv_q5_1(const void *weight_q5_1,
334  const float *input_f32,
335  float *output,
336  int rows, int cols)
337 {
338  /*
339  * IMPORTANT: gemv_q5_1() expects raw FP32 activations, NOT pre-quantized Q8_0.
340  * See comment in ck_test_gemv_q5_k() above for explanation.
341  */
342  for (int r = 0; r < rows; r++) {
343  gemv_q5_1(&output[r],
344  (const char *)weight_q5_1 + r * (cols / QK5_1) * sizeof(block_q5_1),
345  input_f32, 1, cols);
346  }
347 }
348 
349 /* ============================================================================
350  * Direct Vec Dot Tests (pre-quantized inputs, no FP32 conversion)
351  * ============================================================================ */
352 
/**
 * @brief Direct Q5_0 x Q8_0 dot product (both operands pre-quantized)
 *
 * Bypasses the FP32-to-Q8_0 conversion step entirely, which lets callers
 * isolate kernel bugs from quantization bugs.
 *
 * @param weight_q5_0 Q5_0 quantized weights [cols]
 * @param input_q8_0  Q8_0 quantized input [cols] (pre-quantized!)
 * @param output      Output scalar [1]
 * @param cols        Number of elements (must be a multiple of 32)
 */
void ck_test_vec_dot_q5_0_q8_0(const void *weight_q5_0,
                               const void *input_q8_0,
                               float *output,
                               int cols)
{
    /* Note the kernel's argument order: (n, result, weights, activations). */
    vec_dot_q5_0_q8_0(cols, output, weight_q5_0, input_q8_0);
}
371 
/**
 * @brief Direct Q8_0 x Q8_0 dot product (both operands pre-quantized)
 *
 * Bypasses the FP32-to-Q8_0 conversion step entirely, which lets callers
 * isolate kernel bugs from quantization bugs.
 *
 * @param weight_q8_0 Q8_0 quantized weights [cols]
 * @param input_q8_0  Q8_0 quantized input [cols] (pre-quantized!)
 * @param output      Output scalar [1]
 * @param cols        Number of elements (must be a multiple of 32)
 */
void ck_test_vec_dot_q8_0_q8_0(const void *weight_q8_0,
                               const void *input_q8_0,
                               float *output,
                               int cols)
{
    /* Note the kernel's argument order: (n, result, weights, activations). */
    vec_dot_q8_0_q8_0(cols, output, weight_q8_0, input_q8_0);
}
387 
388 /* ============================================================================
389  * GEMM Tests
390  * ============================================================================ */
391 
392 void ck_test_gemm_q4_k(const void *weight_q4k,
393  const float *input_f32,
394  float *output,
395  int rows, int cols, int n_tokens)
396 {
397  /* Allocate Q8_K buffer for quantized activations */
398  int n_blocks_per_row = cols / CK_QK_K;
399  block_q8_K *q8_data = (block_q8_K *)malloc(n_tokens * n_blocks_per_row * sizeof(block_q8_K));
400  if (!q8_data) {
401  memset(output, 0, n_tokens * rows * sizeof(float));
402  return;
403  }
404 
405  /* Quantize all input tokens */
406  for (int t = 0; t < n_tokens; t++) {
407  quantize_row_q8_k(input_f32 + t * cols,
408  q8_data + t * n_blocks_per_row, cols);
409  }
410 
411  /* Use gemm_nt_q4_k_q8_k: C[M,N] = A[M,K] * B[N,K]^T
412  * Our layout: output[n_tokens, rows] = input[n_tokens, cols] * weight[rows, cols]^T
413  * So: M = n_tokens, N = rows, K = cols
414  */
415  gemm_nt_q4_k_q8_k(q8_data, weight_q4k, NULL, output, n_tokens, rows, cols);
416 
417  free(q8_data);
418 }
419 
420 /**
421  * @brief Test Q6_K x Q8_K GEMM (batch matrix multiply)
422  *
423  * Used for MLP W2 (down projection) with Q6_K weights.
424  */
425 void ck_test_gemm_q6_k(const void *weight_q6k,
426  const float *input_f32,
427  float *output,
428  int rows, int cols, int n_tokens)
429 {
430  /* Allocate Q8_K buffer for quantized activations */
431  int n_blocks_per_row = cols / CK_QK_K;
432  block_q8_K *q8_data = (block_q8_K *)malloc(n_tokens * n_blocks_per_row * sizeof(block_q8_K));
433  if (!q8_data) {
434  memset(output, 0, n_tokens * rows * sizeof(float));
435  return;
436  }
437 
438  /* Quantize all input tokens */
439  for (int t = 0; t < n_tokens; t++) {
440  quantize_row_q8_k(input_f32 + t * cols,
441  q8_data + t * n_blocks_per_row, cols);
442  }
443 
444  /* Use gemm_nt_q6_k_q8_k: C[M,N] = A[M,K] * B[N,K]^T
445  * Our layout: output[n_tokens, rows] = input[n_tokens, cols] * weight[rows, cols]^T
446  * So: M = n_tokens, N = rows, K = cols
447  */
448  gemm_nt_q6_k_q8_k(q8_data, weight_q6k, NULL, output, n_tokens, rows, cols);
449 
450  free(q8_data);
451 }
452 
453 /**
454  * @brief Test Q8_0 x Q8_0 GEMM (batch matrix multiply)
455  *
456  * Used for attention V projection with Q8_0 weights.
457  */
458 void ck_test_gemm_q8_0(const void *weight_q8_0,
459  const float *input_f32,
460  float *output,
461  int rows, int cols, int n_tokens)
462 {
463  /* Allocate Q8_0 buffer for quantized activations */
464  int n_blocks_per_row = cols / CK_QK8_0;
465  block_q8_0 *q8_data = (block_q8_0 *)malloc(n_tokens * n_blocks_per_row * sizeof(block_q8_0));
466  if (!q8_data) {
467  memset(output, 0, n_tokens * rows * sizeof(float));
468  return;
469  }
470 
471  /* Quantize all input tokens */
472  for (int t = 0; t < n_tokens; t++) {
473  quantize_row_q8_0(input_f32 + t * cols,
474  q8_data + t * n_blocks_per_row, cols);
475  }
476 
477  /* Use gemm_nt_q8_0_q8_0: C[M,N] = A[M,K] * B[N,K]^T
478  * Our layout: output[n_tokens, rows] = input[n_tokens, cols] * weight[rows, cols]^T
479  * So: M = n_tokens, N = rows, K = cols
480  */
481  gemm_nt_q8_0_q8_0(q8_data, weight_q8_0, NULL, output, n_tokens, rows, cols);
482 
483  free(q8_data);
484 }
485 
486 /**
487  * @brief Test Q5_0 x Q8_0 GEMM (batch matrix multiply)
488  *
489  * Used for MLP W1 (gate/up projection) and attention Q/K with Q5_0 weights.
490  */
491 void ck_test_gemm_q5_0(const void *weight_q5_0,
492  const float *input_f32,
493  float *output,
494  int rows, int cols, int n_tokens)
495 {
496  /* Allocate Q8_0 buffer for quantized activations */
497  int n_blocks_per_row = cols / CK_QK8_0;
498  block_q8_0 *q8_data = (block_q8_0 *)malloc(n_tokens * n_blocks_per_row * sizeof(block_q8_0));
499  if (!q8_data) {
500  memset(output, 0, n_tokens * rows * sizeof(float));
501  return;
502  }
503 
504  /* Quantize all input tokens */
505  for (int t = 0; t < n_tokens; t++) {
506  quantize_row_q8_0(input_f32 + t * cols,
507  q8_data + t * n_blocks_per_row, cols);
508  }
509 
510  /* Use gemm_nt_q5_0_q8_0: C[M,N] = A[M,K] * B[N,K]^T
511  * Our layout: output[n_tokens, rows] = input[n_tokens, cols] * weight[rows, cols]^T
512  * So: M = n_tokens, N = rows, K = cols
513  */
514  gemm_nt_q5_0_q8_0(q8_data, weight_q5_0, NULL, output, n_tokens, rows, cols);
515 
516  free(q8_data);
517 }
518 
519 /**
520  * @brief Test Q5_K x Q8_K GEMM (batch matrix multiply)
521  *
522  * Used for MLP W1 (gate/up projection) and attention Q/K with Q5_K weights.
523  * gemm_nt_q5_k expects FP32 activations (not quantized).
524  */
525 void ck_test_gemm_q5_k(const void *weight_q5_k,
526  const float *input_f32,
527  float *output,
528  int rows, int cols, int n_tokens)
529 {
530  /* gemm_nt_q5_k expects FP32 activations, not quantized.
531  * Pass input_f32 directly as-is (already FP32).
532  */
533  gemm_nt_q5_k(input_f32, weight_q5_k, NULL, output, n_tokens, rows, cols);
534 }
535 
536 /**
537  * @brief Test Q5_1 x Q8_0 GEMM (batch matrix multiply)
538  *
539  * Used for MLP W1 (gate/up projection) and attention Q/K with Q5_1 weights.
540  * gemm_nt_q5_1 expects FP32 activations (not quantized).
541  */
542 void ck_test_gemm_q5_1(const void *weight_q5_1,
543  const float *input_f32,
544  float *output,
545  int rows, int cols, int n_tokens)
546 {
547  /* gemm_nt_q5_1 expects FP32 activations, not quantized.
548  * Pass input_f32 directly as-is (already FP32).
549  */
550  gemm_nt_q5_1(input_f32, weight_q5_1, NULL, output, n_tokens, rows, cols);
551 }
552 
553 /* ============================================================================
554  * Activation Kernels
555  * ============================================================================ */
556 
557 void ck_test_rmsnorm(const float *input,
558  const float *weight,
559  float *output,
560  int n_tokens, int dim, float eps)
561 {
562  /* CK rmsnorm_forward has aligned_embed_dim parameter
563  * For testing, use dim as aligned_embed_dim (no padding) */
564  rmsnorm_forward(input, weight, output, NULL, n_tokens, dim, dim, eps);
565 }
566 
void ck_test_rope(float *q, float *k,
                  int n_tokens, int n_heads, int n_heads_kv, int head_dim,
                  int pos_offset, float theta)
{
    /* Apply CK's RoPE in place to q and k (token-major [T, H*D] layout).
     *
     * Fix vs. original: allocation sizes computed in size_t (the old
     * int*int products could overflow before widening) and malloc casts
     * removed per C idiom. The reorder/compute logic is unchanged. */

    /* Precompute cos/sin caches covering positions [0, pos_offset + n_tokens). */
    int half_dim = head_dim / 2;
    int max_seq = pos_offset + n_tokens;

    float *cos_cache = malloc((size_t)max_seq * (size_t)half_dim * sizeof *cos_cache);
    float *sin_cache = malloc((size_t)max_seq * (size_t)half_dim * sizeof *sin_cache);
    if (!cos_cache || !sin_cache) {
        free(cos_cache);
        free(sin_cache);
        return; /* best-effort: q/k left untouched on OOM */
    }

    rope_precompute_cache(cos_cache, sin_cache, max_seq, head_dim, theta);

    /* CK RoPE expects head-major layout [num_heads, num_tokens, head_dim];
     * reshape from the caller's [n_tokens, n_heads * head_dim]. */
    float *q_reorder = malloc((size_t)n_heads * (size_t)n_tokens * (size_t)head_dim * sizeof *q_reorder);
    float *k_reorder = malloc((size_t)n_heads_kv * (size_t)n_tokens * (size_t)head_dim * sizeof *k_reorder);

    if (q_reorder && k_reorder) {
        /* Reorder Q: [T, H*D] -> [H, T, D] */
        for (int t = 0; t < n_tokens; t++) {
            for (int h = 0; h < n_heads; h++) {
                for (int d = 0; d < head_dim; d++) {
                    q_reorder[h * n_tokens * head_dim + t * head_dim + d] =
                        q[t * n_heads * head_dim + h * head_dim + d];
                }
            }
        }

        /* Reorder K: [T, H_kv*D] -> [H_kv, T, D] */
        for (int t = 0; t < n_tokens; t++) {
            for (int h = 0; h < n_heads_kv; h++) {
                for (int d = 0; d < head_dim; d++) {
                    k_reorder[h * n_tokens * head_dim + t * head_dim + d] =
                        k[t * n_heads_kv * head_dim + h * head_dim + d];
                }
            }
        }

        /* Apply RoPE (aligned_head_dim == head_dim for testing). */
        rope_forward_qk(q_reorder, k_reorder,
                        cos_cache, sin_cache,
                        n_heads, n_heads_kv, n_tokens,
                        head_dim, head_dim, pos_offset);

        /* Reorder back: [H, T, D] -> [T, H*D] */
        for (int t = 0; t < n_tokens; t++) {
            for (int h = 0; h < n_heads; h++) {
                for (int d = 0; d < head_dim; d++) {
                    q[t * n_heads * head_dim + h * head_dim + d] =
                        q_reorder[h * n_tokens * head_dim + t * head_dim + d];
                }
            }
        }

        for (int t = 0; t < n_tokens; t++) {
            for (int h = 0; h < n_heads_kv; h++) {
                for (int d = 0; d < head_dim; d++) {
                    k[t * n_heads_kv * head_dim + h * head_dim + d] =
                        k_reorder[h * n_tokens * head_dim + t * head_dim + d];
                }
            }
        }
    }

    free(q_reorder);
    free(k_reorder);
    free(cos_cache);
    free(sin_cache);
}
643 
644 void ck_test_rope_interleaved(float *q, float *k,
645  int n_tokens, int n_heads, int n_heads_kv, int head_dim,
646  int pos_offset, float theta)
647 {
648  /* Interleaved RoPE format (matches llama.cpp):
649  * (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos)
650  * Applied to consecutive pairs of elements
651  */
652 
653  /* Precompute inverse frequencies */
654  float *inv_freq = (float *)malloc((head_dim / 2) * sizeof(float));
655  if (!inv_freq) return;
656 
657  for (int i = 0; i < head_dim / 2; i++) {
658  inv_freq[i] = 1.0f / powf(theta, (float)(2 * i) / head_dim);
659  }
660 
661  /* Apply RoPE to Q */
662  for (int t = 0; t < n_tokens; t++) {
663  int pos = pos_offset + t;
664  for (int h = 0; h < n_heads; h++) {
665  float *qh = q + t * n_heads * head_dim + h * head_dim;
666 
667  for (int i = 0; i < head_dim / 2; i++) {
668  float freq = pos * inv_freq[i];
669  float cos_val = cosf(freq);
670  float sin_val = sinf(freq);
671 
672  /* Interleaved format */
673  float x0 = qh[i * 2];
674  float x1 = qh[i * 2 + 1];
675  qh[i * 2] = x0 * cos_val - x1 * sin_val;
676  qh[i * 2 + 1] = x0 * sin_val + x1 * cos_val;
677  }
678  }
679  }
680 
681  /* Apply RoPE to K */
682  for (int t = 0; t < n_tokens; t++) {
683  int pos = pos_offset + t;
684  for (int h = 0; h < n_heads_kv; h++) {
685  float *kh = k + t * n_heads_kv * head_dim + h * head_dim;
686 
687  for (int i = 0; i < head_dim / 2; i++) {
688  float freq = pos * inv_freq[i];
689  float cos_val = cosf(freq);
690  float sin_val = sinf(freq);
691 
692  float x0 = kh[i * 2];
693  float x1 = kh[i * 2 + 1];
694  kh[i * 2] = x0 * cos_val - x1 * sin_val;
695  kh[i * 2 + 1] = x0 * sin_val + x1 * cos_val;
696  }
697  }
698  }
699 
700  free(inv_freq);
701 }
702 
void ck_test_swiglu(const float *gate_up,
                    float *output,
                    int n_tokens, int intermediate_dim)
{
    /* Direct pass-through to the CK SwiGLU kernel (layout and activation
     * semantics as defined in swiglu_kernels.c). */
    swiglu_forward(gate_up, output, n_tokens, intermediate_dim);
}
709 
710 void ck_test_softmax(const float *input, float *output, int n)
711 {
712  /* Find max for numerical stability */
713  float max_val = input[0];
714  for (int i = 1; i < n; i++) {
715  if (input[i] > max_val) max_val = input[i];
716  }
717 
718  /* Compute exp and sum */
719  float sum = 0.0f;
720  for (int i = 0; i < n; i++) {
721  output[i] = expf(input[i] - max_val);
722  sum += output[i];
723  }
724 
725  /* Normalize */
726  float inv_sum = 1.0f / sum;
727  for (int i = 0; i < n; i++) {
728  output[i] *= inv_sum;
729  }
730 }
731 
732 /* ============================================================================
733  * Attention Kernels
734  * ============================================================================ */
735 
736 void ck_test_attention_causal(const float *q,
737  const float *k,
738  const float *v,
739  float *out,
740  int num_heads,
741  int num_kv_heads,
742  int tokens,
743  int seq_len,
744  int head_dim)
745 {
746  /* For prefill, seq_len == tokens, and kv_stride == tokens.
747  * The CK kernel expects strided KV layout with kv_stride_tokens parameter.
748  * For parity testing with contiguous tensors, kv_stride = seq_len.
749  */
751  q, k, v, out,
752  num_heads, num_kv_heads, tokens,
753  head_dim, head_dim, /* aligned_head_dim = head_dim for testing */
754  seq_len /* kv_stride_tokens = seq_len for contiguous KV */
755  );
756 }
757 
758 /**
759  * @brief Test sliding-window attention (prefill)
760  *
761  * Layout (head-major, matching CK-Engine):
762  * Q: [num_heads, tokens, head_dim]
763  * K: [num_kv_heads, seq_len, head_dim]
764  * V: [num_kv_heads, seq_len, head_dim]
765  * out: [num_heads, tokens, head_dim]
766  *
767  * Each token attends only to the last `sliding_window` tokens.
768  */
770  const float *k,
771  const float *v,
772  float *out,
773  int num_heads,
774  int num_kv_heads,
775  int tokens,
776  int seq_len,
777  int head_dim,
778  int sliding_window)
779 {
781  q, k, v, out,
782  num_heads, num_kv_heads, tokens,
783  head_dim, head_dim, /* aligned_head_dim = head_dim for testing */
784  seq_len, /* kv_stride_tokens = seq_len for contiguous KV */
785  sliding_window
786  );
787 }
788 
789 /**
790  * @brief Test sliding-window attention (decode mode)
791  *
792  * Single query token attending to KV cache with sliding window.
793  */
794 void ck_test_attention_decode_sliding(const float *q_token,
795  const float *k_cache,
796  const float *v_cache,
797  float *out_token,
798  int num_heads,
799  int num_kv_heads,
800  int kv_tokens,
801  int cache_capacity,
802  int head_dim,
803  int sliding_window)
804 {
806  q_token, k_cache, v_cache, out_token,
807  num_heads, num_kv_heads,
808  kv_tokens, cache_capacity, head_dim, head_dim,
809  sliding_window
810  );
811 }
812 
/**
 * @brief Test GeGLU activation
 *
 * Computes output = GELU(a) * b, where the input x holds [a, b]
 * concatenated along the last dimension.
 */
void ck_test_geglu(const float *x,
                   float *out,
                   int n_tokens,
                   int dim)
{
    /* Direct pass-through to the CK GeGLU forward kernel. */
    geglu_forward_fp32(x, out, n_tokens, dim);
}
826 
/**
 * @brief Test GeGLU backward pass
 *
 * Computes the input gradients dL/dx given dL/d(out), where
 * out = GELU(a) * b and x = [a, b].
 */
void ck_test_geglu_backward(const float *x,
                            const float *d_out,
                            float *d_x,
                            int n_tokens,
                            int dim)
{
    /* Direct pass-through to the CK GeGLU backward kernel. */
    geglu_backward_fp32(x, d_out, d_x, n_tokens, dim);
}
840 
841 /* ============================================================================
842  * Mega-Fused OutProj + MLP Kernels
843  * ============================================================================ */
844 
845 /* External declaration for mega_fused_outproj_mlp_prefill */
847  float *output,
848  const float *attn_out,
849  const float *residual,
850  const float *ln2_gamma,
851  const void *wo, const float *bo, int wo_dt,
852  const void *w1, const float *b1, int w1_dt,
853  const void *w2, const float *b2, int w2_dt,
854  int tokens,
855  int embed_dim,
856  int aligned_embed_dim,
857  int num_heads,
858  int aligned_head_dim,
859  int intermediate_dim,
860  int aligned_intermediate_dim,
861  float eps,
862  void *scratch);
863 
865  int tokens,
866  int aligned_embed_dim,
867  int num_heads,
868  int aligned_head_dim,
869  int aligned_intermediate_dim);
870 
871 /**
872  * @brief Test mega-fused OutProj + MLP kernel (Q5_0 weights)
873  *
874  * This is a simplified wrapper for parity testing that:
875  * - Uses Q5_0 for W_o and W1 weights
876  * - Uses Q4_K for W2 weights
877  * - Allocates scratch internally
878  *
879  * @param attn_out Attention output [num_heads, tokens, head_dim] (FP32, head-major)
880  * @param residual Residual input [tokens, embed_dim] (FP32)
881  * @param ln2_gamma RMSNorm gamma [embed_dim] (FP32)
882  * @param wo OutProj weights [embed_dim, embed_dim] (Q5_0)
883  * @param w1 MLP W1 weights [2*intermediate, embed_dim] (Q5_0)
884  * @param w2 MLP W2 weights [embed_dim, intermediate] (Q4_K or Q6_K)
885  * @param output Output [tokens, embed_dim] (FP32)
886  * @param tokens Number of tokens
887  * @param num_heads Number of attention heads
888  * @param head_dim Dimension per head
889  * @param embed_dim Embedding dimension (= num_heads * head_dim)
890  * @param intermediate MLP intermediate dimension
891  * @param eps RMSNorm epsilon
892  * @param w2_is_q6k If true, W2 is Q6_K; if false, W2 is Q4_K
893  */
895  const float *attn_out,
896  const float *residual,
897  const float *ln2_gamma,
898  const void *wo,
899  const void *w1,
900  const void *w2,
901  float *output,
902  int tokens,
903  int num_heads,
904  int head_dim,
905  int embed_dim,
906  int intermediate,
907  float eps,
908  int w2_is_q6k)
909 {
910  /* CK uses dtype enum: CK_DT_Q5_0 = 11, CK_DT_Q4_K = 7, CK_DT_Q6_K = 8 */
911  const int CK_DT_Q5_0_VAL = 11;
912  const int CK_DT_Q4_K_VAL = 7;
913  const int CK_DT_Q6_K_VAL = 8;
914 
915  /* For parity testing, aligned = actual (no padding) */
916  int aligned_embed_dim = embed_dim;
917  int aligned_head_dim = head_dim;
918  int aligned_intermediate = intermediate;
919 
920  /* Ensure intermediate is multiple of 256 (QK_K) for K-quants */
921  if ((intermediate % 256) != 0) {
922  aligned_intermediate = ((intermediate + 255) / 256) * 256;
923  }
924 
925  /* Allocate scratch */
926  size_t scratch_size = mega_fused_outproj_mlp_prefill_scratch_size(
927  tokens, aligned_embed_dim, num_heads, aligned_head_dim, aligned_intermediate);
928 
929  void *scratch = malloc(scratch_size);
930  if (!scratch) {
931  return;
932  }
933 
934  /* Call the mega-fused kernel */
936  output,
937  attn_out,
938  residual,
939  ln2_gamma,
940  wo, NULL, CK_DT_Q5_0_VAL, /* W_o with Q5_0 */
941  w1, NULL, CK_DT_Q5_0_VAL, /* W1 with Q5_0 */
942  w2, NULL, w2_is_q6k ? CK_DT_Q6_K_VAL : CK_DT_Q4_K_VAL, /* W2 with Q4_K or Q6_K */
943  tokens,
944  embed_dim,
945  aligned_embed_dim,
946  num_heads,
947  aligned_head_dim,
948  intermediate,
949  aligned_intermediate,
950  eps,
951  scratch
952  );
953 
954  free(scratch);
955 }
956 
957 /* ============================================================================
958  * Utility Functions
959  * ============================================================================ */
960 
962 {
963  return sizeof(block_q4_K);
964 }
965 
967 {
968  return sizeof(block_q6_K);
969 }
970 
972 {
973  return sizeof(block_q8_K);
974 }
975 
976 int ck_get_qk_k(void)
977 {
978  return QK_K;
979 }
980 
982 {
983  return sizeof(block_q5_K);
984 }
985 
987 {
988  return sizeof(block_q5_1);
989 }
990 
991 int ck_get_qk5_1(void)
992 {
993  return QK5_1;
994 }
void ck_test_quantize_q8_k(const float *src, void *dst, int n)
Quantize FP32 to Q8_K (for activations)
void dequant_q4_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_0 row (multiple blocks)
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void ck_test_gemm_q4_k(const void *weight_q4k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q4_K GEMM - batched matrix multiply with quantized weights.
int ck_get_block_q5_k_size(void)
Get Q5_K block size in bytes (176 bytes per 256 weights)
int ck_get_block_q8_k_size(void)
Get Q8_K block size in bytes.
void ck_test_dequant_q6_k(const void *src, float *dst, int n)
Dequantize Q6_K data to FP32.
void ck_test_rope(float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
RoPE (Rotary Position Embedding)
void gemv_q5_1(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.
void swiglu_forward(const float *input, float *output, int tokens, int dim)
int ck_get_qk5_1(void)
Get QK5_1 (elements per Q5_1 block)
void ck_test_gemv_q5_0_q8_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols)
Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
void rope_precompute_cache(float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base)
Definition: rope_kernels.c:52
void gemv_q8_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q8_0 weights based on CPU features.
void ck_test_dequant_q4_0(const void *src, float *dst, int n)
Dequantize Q4_0 data to FP32.
void attention_forward_causal_head_major_gqa_flash_strided_sliding(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)
void ck_test_softmax(const float *input, float *output, int n)
Softmax (simple, non-causal)
void ck_test_rmsnorm(const float *input, const float *weight, float *output, int n_tokens, int dim, float eps)
RMSNorm.
void gemv_q5_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q5_0 weights and Q8_0 input.
void attention_forward_decode_head_major_gqa_flash_sliding(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)
void gemv_q5_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q5_0 weights based on CPU features.
void ck_test_gemm_q5_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Test Q5_0 x Q8_0 GEMM (batch matrix multiply)
void ck_test_dequant_q4_k(const void *src, float *dst, int n)
Dequantize Q4_K data to FP32.
void ck_test_attention_causal(const float *q, const float *k, const float *v, float *out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim)
Multi-head causal attention for prefill (head-major layout)
void ck_test_gemv_q4_k(const void *weight_q4k, const float *input_f32, float *output, int cols)
Q4_K GEMV - dot product of quantized weights and FP32 input.
void mega_fused_outproj_mlp_prefill(float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, int wo_dt, const void *w1, const float *b1, int w1_dt, const void *w2, const float *b2, int w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch)
void quantize_row_q8_k(const float *x, void *vy, int k)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void ck_test_gemv_q8_0_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols)
Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
void ck_test_vec_dot_q8_0_q8_0(const void *weight_q8_0, const void *input_q8_0, float *output, int cols)
Direct Q8_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)
void ck_test_geglu(const float *x, float *out, int n_tokens, int dim)
Test GeGLU activation.
void ck_test_swiglu(const float *gate_up, float *output, int n_tokens, int intermediate_dim)
SwiGLU activation.
void ck_test_gemm_q6_k(const void *weight_q6k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Test Q6_K x Q8_K GEMM (batch matrix multiply)
void ck_test_vec_dot_q5_0_q8_0(const void *weight_q5_0, const void *input_q8_0, float *output, int cols)
Direct Q5_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)
void ck_test_gemm_q5_1(const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Test Q5_1 x Q8_0 GEMM (batch matrix multiply)
void gemv_q6_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
GEMV: y = W @ x where W is Q6_K and x is Q8_K.
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K)
void geglu_forward_fp32(const float *x, float *out, int tokens, int dim)
Definition: gelu_kernels.c:623
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
int ck_get_qk_k(void)
Get QK_K (elements per super-block)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void ck_test_gemm_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Test Q8_0 x Q8_0 GEMM (batch matrix multiply)
void ck_test_gemv_q5_k(const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols)
Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks)
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
void gemm_nt_q8_0_q8_0(const void *A_q8, const void *B_q8, const float *bias, float *C, int M, int N, int K)
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
void ck_test_geglu_backward(const float *x, const float *d_out, float *d_x, int n_tokens, int dim)
Test GeGLU backward.
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:448
void gemv_q8_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q8_0 weights and Q8_0 input.
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
void ck_test_attention_sliding_window(const float *q, const float *k, const float *v, float *out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim, int sliding_window)
Test sliding-window attention (prefill)
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
void ck_test_gemv_q5_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols)
Q5_0 GEMV - matrix-vector multiply with Q5_0 weights.
void dequant_q5_1_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_1 row (multiple blocks)
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
size_t mega_fused_outproj_mlp_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
Get scratch buffer size for mega_fused_outproj_mlp_prefill.
void ck_test_gemv_q5_1(const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols)
Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks)
void dequant_q6_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q6_K row (multiple blocks)
void quantize_row_q8_0(const float *x, void *vy, int k)
Quantize FP32 to Q8_0 format (scalar reference)
void ck_test_attention_decode_sliding(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int sliding_window)
Test sliding-window attention (decode mode)
void ck_test_rope_interleaved(float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
RoPE with interleaved format (for llama.cpp compatibility)
void ck_test_dequant_q5_1(const void *src, float *dst, int n)
Dequantize Q5_1 data to FP32.
int ck_get_block_q4_k_size(void)
Get Q4_K block size in bytes.
int ck_get_block_q6_k_size(void)
Get Q6_K block size in bytes.
void ck_test_gemv_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols)
Q8_0 GEMV - matrix-vector multiply with Q8_0 weights.
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.
void gemm_nt_q5_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void dequant_q4_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_K row (multiple blocks)
void ck_test_gemm_q5_k(const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Test Q5_K x Q8_K GEMM (batch matrix multiply)
void ck_test_gemv_q6_k(const void *weight_q6k, const float *input_f32, float *output, int cols)
Q6_K GEMV.
void ck_test_outproj_mlp_fused_q5_0(const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const void *w1, const void *w2, float *output, int tokens, int num_heads, int head_dim, int embed_dim, int intermediate, float eps, int w2_is_q6k)
Test mega-fused OutProj + MLP kernel (Q5_0 weights)
void geglu_backward_fp32(const float *x, const float *d_out, float *d_x, int n_tokens, int dim)
Definition: gelu_kernels.c:843
int ck_get_block_q5_1_size(void)
Get Q5_1 block size in bytes (24 bytes per 32 weights)
C-Kernel-Engine Parity Testing API.
#define CK_QK_K
Definition: ck_parity_api.h:28
#define CK_QK8_0
Definition: ck_parity_api.h:30
Quantization block structures for weight-only quantization.
#define QK5_1
Definition: ckernel_quant.h:84
#define QK_K
#define C(color)
Definition: show_config.c:39