← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ck_parity_api.h
Go to the documentation of this file.
1 /**
2  * @file ck_parity_api.h
3  * @brief C-Kernel-Engine Parity Testing API
4  *
5  * Exposes individual CK kernels for parity testing against llama.cpp/ggml.
6  * This API mirrors the test-kernel-parity.cpp interface in llama.cpp.
7  *
8  * Usage:
9  * 1. Build as shared library: libck_parity.so
10  * 2. Load from Python using ctypes
11  * 3. Call functions with matching signatures to test-kernel-parity.cpp
12  */
13 
14 #ifndef CK_PARITY_API_H
15 #define CK_PARITY_API_H
16 
17 #include <stddef.h>
18 #include <stdint.h>
19 
20 #ifdef __cplusplus
21 extern "C" {
22 #endif
23 
24 /* ============================================================================
25  * Constants (must match llama.cpp/ggml)
26  * ============================================================================ */
27 
28 #define CK_QK_K 256 /* Elements per K-quant super-block */
29 #define CK_QK4_0 32 /* Elements per Q4_0 block */
30 #define CK_QK8_0 32 /* Elements per Q8_0 block */
31 
32 /* Block sizes in bytes */
33 #define CK_BLOCK_Q4_K_SIZE 144 /* ggml block_q4_K: 2x fp16 (d, dmin) + 12B scales + 128B quants */
34 #define CK_BLOCK_Q6_K_SIZE 210 /* ggml block_q6_K: 128B ql + 64B qh + 16B scales + fp16 d */
35 #define CK_BLOCK_Q8_K_SIZE 292 /* ggml block_q8_K: fp32 d + 256B quants + 16x int16 bsums */
36 #define CK_BLOCK_Q4_0_SIZE 18 /* ggml block_q4_0: fp16 d + 16B quants */
37 #define CK_BLOCK_Q5_K_SIZE 176 /* ggml block_q5_K: 2x fp16 (d, dmin) + 12B scales + 32B qh + 128B qs */
38 #define CK_BLOCK_Q5_1_SIZE 24 /* ggml block_q5_1: 2x fp16 (d, m) + 4B qh + 16B qs */
39 
40 /* ============================================================================
41  * Dequantization Tests
42  * ============================================================================ */
43 
44 /**
45  * @brief Dequantize Q4_K data to FP32
46  * @param src Input Q4_K blocks
47  * @param dst Output FP32 values
48  * @param n Number of elements (must be multiple of 256)
49  */
50 void ck_test_dequant_q4_k(const void *src, float *dst, int n);
51 
52 /**
53  * @brief Dequantize Q6_K data to FP32
54  */
55 void ck_test_dequant_q6_k(const void *src, float *dst, int n);
56 
57 /**
58  * @brief Dequantize Q4_0 data to FP32
59  */
60 void ck_test_dequant_q4_0(const void *src, float *dst, int n);
61 
62 /**
63  * @brief Dequantize Q5_1 data to FP32
64  */
65 void ck_test_dequant_q5_1(const void *src, float *dst, int n);
66 
67 /* ============================================================================
68  * Quantization Tests
69  * ============================================================================ */
70 
71 /**
72  * @brief Quantize FP32 to Q8_K (for activations)
73  * @param src Input FP32 values
74  * @param dst Output Q8_K blocks
75  * @param n Number of elements (must be multiple of 256)
76  */
77 void ck_test_quantize_q8_k(const float *src, void *dst, int n);
78 
79 /* ============================================================================
80  * GEMV (Matrix-Vector) Tests
81  * ============================================================================ */
82 
83 /**
84  * @brief Q4_K GEMV - dot product of quantized weights and FP32 input
85  *
86  * Internally quantizes input to Q8_K, then computes dot product.
87  *
88  * @param weight_q4k Q4_K quantized weights [cols]
89  * @param input_f32 FP32 input vector [cols]
90  * @param output Output scalar [1]
91  * @param cols Number of columns (must be multiple of 256)
92  */
93 void ck_test_gemv_q4_k(const void *weight_q4k,
94  const float *input_f32,
95  float *output,
96  int cols);
97 
98 /**
99  * @brief Q6_K GEMV
100  */
101 void ck_test_gemv_q6_k(const void *weight_q6k,
102  const float *input_f32,
103  float *output,
104  int cols);
105 
106 /**
107  * @brief Q5_0 GEMV - matrix-vector multiply with Q5_0 weights
108  *
109  * @param weight_q5_0 Q5_0 quantized weights [rows * cols]
110  * @param input_f32 FP32 input vector [cols]
111  * @param output FP32 output vector [rows]
112  * @param rows Number of output rows
113  * @param cols Number of columns (must be multiple of 32)
114  */
115 void ck_test_gemv_q5_0(const void *weight_q5_0,
116  const float *input_f32,
117  float *output,
118  int rows, int cols);
119 
120 /**
121  * @brief Q8_0 GEMV - matrix-vector multiply with Q8_0 weights
122  *
123  * @param weight_q8_0 Q8_0 quantized weights [rows * cols]
124  * @param input_f32 FP32 input vector [cols]
125  * @param output FP32 output vector [rows]
126  * @param rows Number of output rows
127  * @param cols Number of columns (must be multiple of 32)
128  */
129 void ck_test_gemv_q8_0(const void *weight_q8_0,
130  const float *input_f32,
131  float *output,
132  int rows, int cols);
133 
134 /**
135  * @brief Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach
136  *
137  * This version quantizes the input to Q8_0 first, then uses integer
138  * dot products (like llama.cpp does). Use this for parity testing.
139  *
140  * @param weight_q5_0 Q5_0 quantized weights [rows * cols]
141  * @param input_f32 FP32 input vector [cols] - will be quantized to Q8_0
142  * @param output FP32 output vector [rows]
143  * @param rows Number of output rows
144  * @param cols Number of columns (must be multiple of 32)
145  */
146 void ck_test_gemv_q5_0_q8_0(const void *weight_q5_0,
147  const float *input_f32,
148  float *output,
149  int rows, int cols);
150 
151 /**
152  * @brief Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach
153  *
154  * This version quantizes the input to Q8_0 first, then uses integer
155  * dot products (like llama.cpp does). Use this for parity testing.
156  *
157  * @param weight_q8_0 Q8_0 quantized weights [rows * cols]
158  * @param input_f32 FP32 input vector [cols] - will be quantized to Q8_0
159  * @param output FP32 output vector [rows]
160  * @param rows Number of output rows
161  * @param cols Number of columns (must be multiple of 32)
162  */
163 void ck_test_gemv_q8_0_q8_0(const void *weight_q8_0,
164  const float *input_f32,
165  float *output,
166  int rows, int cols);
167 
168 /**
169  * @brief Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks)
170  *
171  * Uses Q8_K for activations (like Q4_K).
172  *
173  * @param weight_q5_k Q5_K quantized weights [rows * cols]
174  * @param input_f32 FP32 input vector [cols]
175  * @param output FP32 output vector [rows]
176  * @param rows Number of output rows
177  * @param cols Number of columns (must be multiple of 256)
178  */
179 void ck_test_gemv_q5_k(const void *weight_q5_k,
180  const float *input_f32,
181  float *output,
182  int rows, int cols);
183 
184 /**
185  * @brief Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks)
186  *
187  * Uses Q8_0 for activations (like Q5_0).
188  *
189  * @param weight_q5_1 Q5_1 quantized weights [rows * cols]
190  * @param input_f32 FP32 input vector [cols]
191  * @param output FP32 output vector [rows]
192  * @param rows Number of output rows
193  * @param cols Number of columns (must be multiple of 32)
194  */
195 void ck_test_gemv_q5_1(const void *weight_q5_1,
196  const float *input_f32,
197  float *output,
198  int rows, int cols);
199 
200 /* ============================================================================
201  * Direct Vec Dot Tests (pre-quantized inputs, no FP32 conversion)
202  * ============================================================================ */
203 
204 /**
205  * @brief Direct Q5_0 x Q8_0 dot product (takes pre-quantized Q8_0 input)
206  *
207  * This is a "direct" test that bypasses FP32-to-Q8_0 conversion.
208  * Useful for isolating kernel bugs from quantization bugs.
209  *
210  * @param weight_q5_0 Q5_0 quantized weights [cols]
211  * @param input_q8_0 Q8_0 quantized input [cols] (pre-quantized!)
212  * @param output Output scalar [1]
213  * @param cols Number of elements (must be multiple of 32)
214  */
215 void ck_test_vec_dot_q5_0_q8_0(const void *weight_q5_0,
216  const void *input_q8_0,
217  float *output,
218  int cols);
219 
220 /**
221  * @brief Direct Q8_0 x Q8_0 dot product (takes pre-quantized Q8_0 input)
222  *
223  * @param weight_q8_0 Q8_0 quantized weights [cols]
224  * @param input_q8_0 Q8_0 quantized input [cols] (pre-quantized!)
225  * @param output Output scalar [1]
226  * @param cols Number of elements (must be multiple of 32)
227  */
228 void ck_test_vec_dot_q8_0_q8_0(const void *weight_q8_0,
229  const void *input_q8_0,
230  float *output,
231  int cols);
232 
233 /* ============================================================================
234  * GEMM (Matrix-Matrix) Tests
235  * ============================================================================ */
236 
237 /**
238  * @brief Q4_K GEMM - batched matrix multiply with quantized weights
239  *
240  * Computes: output[t,r] = sum_k(weight[r,k] * input[t,k])
241  *
242  * @param weight_q4k Q4_K quantized weights [rows, cols]
243  * @param input_f32 FP32 input [n_tokens, cols]
244  * @param output FP32 output [n_tokens, rows]
245  * @param rows Number of output rows
246  * @param cols Number of columns (must be multiple of 256)
247  * @param n_tokens Batch size
248  */
249 void ck_test_gemm_q4_k(const void *weight_q4k,
250  const float *input_f32,
251  float *output,
252  int rows, int cols, int n_tokens);
253 
254 /**
255  * @brief Q6_K GEMM - batched matrix multiply with Q6_K weights
256  *
257  * Computes: output[t,r] = sum_k(weight[r,k] * input[t,k])
258  *
259  * @param weight_q6k Q6_K quantized weights [rows, cols]
260  * @param input_f32 FP32 input [n_tokens, cols]
261  * @param output FP32 output [n_tokens, rows]
262  * @param rows Number of output rows
263  * @param cols Number of columns (must be multiple of 256)
264  * @param n_tokens Batch size
265  */
266 void ck_test_gemm_q6_k(const void *weight_q6k,
267  const float *input_f32,
268  float *output,
269  int rows, int cols, int n_tokens);
270 
271 /**
272  * @brief Q5_0 GEMM - batched matrix multiply with Q5_0 weights (32-element blocks)
273  *
274  * Computes: output[t,r] = sum_k(weight[r,k] * input[t,k])
275  *
276  * @param weight_q5_0 Q5_0 quantized weights [rows, cols]
277  * @param input_f32 FP32 input [n_tokens, cols]
278  * @param output FP32 output [n_tokens, rows]
279  * @param rows Number of output rows
280  * @param cols Number of columns (must be multiple of 32)
281  * @param n_tokens Batch size
282  */
283 void ck_test_gemm_q5_0(const void *weight_q5_0,
284  const float *input_f32,
285  float *output,
286  int rows, int cols, int n_tokens);
287 
288 /**
289  * @brief Q8_0 GEMM - batched matrix multiply with Q8_0 weights (32-element blocks)
290  *
291  * Computes: output[t,r] = sum_k(weight[r,k] * input[t,k])
292  *
293  * @param weight_q8_0 Q8_0 quantized weights [rows, cols]
294  * @param input_f32 FP32 input [n_tokens, cols]
295  * @param output FP32 output [n_tokens, rows]
296  * @param rows Number of output rows
297  * @param cols Number of columns (must be multiple of 32)
298  * @param n_tokens Batch size
299  */
300 void ck_test_gemm_q8_0(const void *weight_q8_0,
301  const float *input_f32,
302  float *output,
303  int rows, int cols, int n_tokens);
304 
305 /**
306  * @brief Q5_K GEMM - batched matrix multiply with Q5_K weights (256-element super-blocks)
307  *
308  * Computes: output[t,r] = sum_k(weight[r,k] * input[t,k])
309  * Uses Q8_K for activations.
310  *
311  * @param weight_q5_k Q5_K quantized weights [rows, cols]
312  * @param input_f32 FP32 input [n_tokens, cols]
313  * @param output FP32 output [n_tokens, rows]
314  * @param rows Number of output rows
315  * @param cols Number of columns (must be multiple of 256)
316  * @param n_tokens Batch size
317  */
318 void ck_test_gemm_q5_k(const void *weight_q5_k,
319  const float *input_f32,
320  float *output,
321  int rows, int cols, int n_tokens);
322 
323 /**
324  * @brief Q5_1 GEMM - batched matrix multiply with Q5_1 weights (32-element blocks)
325  *
326  * Computes: output[t,r] = sum_k(weight[r,k] * input[t,k])
327  * Uses Q8_0 for activations.
328  *
329  * @param weight_q5_1 Q5_1 quantized weights [rows, cols]
330  * @param input_f32 FP32 input [n_tokens, cols]
331  * @param output FP32 output [n_tokens, rows]
332  * @param rows Number of output rows
333  * @param cols Number of columns (must be multiple of 32)
334  * @param n_tokens Batch size
335  */
336 void ck_test_gemm_q5_1(const void *weight_q5_1,
337  const float *input_f32,
338  float *output,
339  int rows, int cols, int n_tokens);
340 
341 /* ============================================================================
342  * Activation Kernels
343  * ============================================================================ */
344 
345 /**
346  * @brief RMSNorm
347  *
348  * Computes: output = (input / rms(input)) * weight
349  * where rms(x) = sqrt(mean(x^2) + eps)
350  *
351  * @param input Input tensor [n_tokens, dim]
352  * @param weight Normalization weights [dim]
353  * @param output Output tensor [n_tokens, dim]
354  * @param n_tokens Number of tokens
355  * @param dim Hidden dimension
356  * @param eps Epsilon for numerical stability
357  */
358 void ck_test_rmsnorm(const float *input,
359  const float *weight,
360  float *output,
361  int n_tokens, int dim, float eps);
362 
363 /**
364  * @brief RoPE (Rotary Position Embedding)
365  *
366  * Applies rotary position embeddings to Q and K tensors.
367  *
368  * NOTE: CK uses rotate-half format (split first/second halves)
369  * while some implementations use interleaved format.
370  * The test harness should account for this.
371  *
372  * @param q Query tensor [n_tokens, n_heads * head_dim], modified in-place
373  * @param k Key tensor [n_tokens, n_heads_kv * head_dim], modified in-place
374  * @param n_tokens Number of tokens
375  * @param n_heads Number of query heads
376  * @param n_heads_kv Number of key/value heads
377  * @param head_dim Dimension per head
378  * @param pos_offset Starting position for RoPE
379  * @param theta RoPE base frequency (typically 10000.0)
380  */
381 void ck_test_rope(float *q, float *k,
382  int n_tokens, int n_heads, int n_heads_kv, int head_dim,
383  int pos_offset, float theta);
384 
385 /**
386  * @brief RoPE with interleaved format (for llama.cpp compatibility)
387  *
388  * Uses interleaved format: (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos)
389  */
390 void ck_test_rope_interleaved(float *q, float *k,
391  int n_tokens, int n_heads, int n_heads_kv, int head_dim,
392  int pos_offset, float theta);
393 
394 /**
395  * @brief SwiGLU activation
396  *
397  * Computes: output = SiLU(gate) * up
398  * where SiLU(x) = x * sigmoid(x)
399  *
400  * @param gate_up Input tensor [n_tokens, 2 * intermediate_dim]
401  * Layout: [gate_0..gate_D-1, up_0..up_D-1] per token
402  * @param output Output tensor [n_tokens, intermediate_dim]
403  * @param n_tokens Number of tokens
404  * @param intermediate_dim Intermediate dimension
405  */
406 void ck_test_swiglu(const float *gate_up,
407  float *output,
408  int n_tokens, int intermediate_dim);
409 
410 /**
411  * @brief Softmax (simple, non-causal)
412  *
413  * Computes: output[i] = exp(input[i]) / sum(exp(input))
414  *
415  * @param input Input tensor [n]
416  * @param output Output tensor [n]
417  * @param n Number of elements
418  */
419 void ck_test_softmax(const float *input, float *output, int n);
420 
421 /* ============================================================================
422  * Attention Kernels
423  * ============================================================================ */
424 
425 /**
426  * @brief Multi-head causal attention for prefill (head-major layout)
427  *
428  * Layout (head-major, matches llama.cpp test):
429  * Q: [num_heads, tokens, head_dim]
430  * K: [num_kv_heads, seq_len, head_dim]
431  * V: [num_kv_heads, seq_len, head_dim]
432  * out: [num_heads, tokens, head_dim]
433  *
434  * Supports GQA (grouped-query attention) where num_heads > num_kv_heads.
435  * Causal masking: token t can only attend to positions 0..t (inclusive).
436  *
437  * @param q Query [num_heads, tokens, head_dim]
438  * @param k Key [num_kv_heads, seq_len, head_dim]
439  * @param v Value [num_kv_heads, seq_len, head_dim]
440  * @param out Output [num_heads, tokens, head_dim]
441  * @param num_heads Number of query heads
442  * @param num_kv_heads Number of key/value heads (for GQA)
443  * @param tokens Number of query tokens
444  * @param seq_len Key/value sequence length (for prefill: seq_len == tokens)
445  * @param head_dim Dimension per head
446  */
447 void ck_test_attention_causal(const float *q,
448  const float *k,
449  const float *v,
450  float *out,
451  int num_heads,
452  int num_kv_heads,
453  int tokens,
454  int seq_len,
455  int head_dim);
456 
457 /* ============================================================================
458  * Mega-Fused Kernels
459  * ============================================================================ */
460 
461 /**
462  * @brief Test mega-fused OutProj + MLP kernel (Q5_0 weights)
463  *
464  * This tests the mega_fused_outproj_mlp_prefill kernel which fuses:
465  * 1. Quantize attention output (head-major) to Q8_0
466  * 2. OutProj: attn_out @ W_o (Q5_0) → h1
467  * 3. Residual: h1 += residual
468  * 4. RMSNorm: h1 → ln2_out
469  * 5. MLP: silu(ln2_out @ W_gate) * (ln2_out @ W_up) @ W2
470  * 6. Residual: output += h1
471  *
472  * @param attn_out Attention output [num_heads, tokens, head_dim] (FP32, head-major)
473  * @param residual Residual input [tokens, embed_dim] (FP32)
474  * @param ln2_gamma RMSNorm gamma [embed_dim] (FP32)
475  * @param wo OutProj weights [embed_dim, embed_dim] (Q5_0)
476  * @param w1 MLP W1 weights [2*intermediate, embed_dim] (Q5_0)
477  * @param w2 MLP W2 weights [embed_dim, intermediate] (Q4_K or Q6_K)
478  * @param output Output [tokens, embed_dim] (FP32)
479  * @param tokens Number of tokens
480  * @param num_heads Number of attention heads
481  * @param head_dim Dimension per head
482  * @param embed_dim Embedding dimension (= num_heads * head_dim)
483  * @param intermediate MLP intermediate dimension
484  * @param eps RMSNorm epsilon
485  * @param w2_is_q6k If true, W2 is Q6_K; if false, W2 is Q4_K
486  */
487 void ck_test_outproj_mlp_fused_q5_0(
488  const float *attn_out,
489  const float *residual,
490  const float *ln2_gamma,
491  const void *wo,
492  const void *w1,
493  const void *w2,
494  float *output,
495  int tokens,
496  int num_heads,
497  int head_dim,
498  int embed_dim,
499  int intermediate,
500  float eps,
501  int w2_is_q6k);
502 
503 /* ============================================================================
504  * Utility Functions
505  * ============================================================================ */
506 
507 /**
508  * @brief Get Q4_K block size in bytes
509  */
510 int ck_get_block_q4_k_size(void);
511 
512 /**
513  * @brief Get Q6_K block size in bytes
514  */
515 int ck_get_block_q6_k_size(void);
516 
517 /**
518  * @brief Get Q8_K block size in bytes
519  */
520 int ck_get_block_q8_k_size(void);
521 
522 /**
523  * @brief Get QK_K (elements per super-block)
524  */
525 int ck_get_qk_k(void);
526 
527 /**
528  * @brief Get Q5_K block size in bytes (176 bytes per 256 weights)
529  */
530 int ck_get_block_q5_k_size(void);
531 
532 /**
533  * @brief Get Q5_1 block size in bytes (24 bytes per 32 weights)
534  */
535 int ck_get_block_q5_1_size(void);
536 
537 /**
538  * @brief Get QK5_1 (elements per Q5_1 block)
539  */
540 int ck_get_qk5_1(void);
541 
542 #ifdef __cplusplus
543 }
544 #endif
545 
546 #endif /* CK_PARITY_API_H */
void ck_test_quantize_q8_k(const float *src, void *dst, int n)
Quantize FP32 to Q8_K (for activations)
void ck_test_gemm_q4_k(const void *weight_q4k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q4_K GEMM - batched matrix multiply with quantized weights.
int ck_get_block_q5_k_size(void)
Get Q5_K block size in bytes (176 bytes per 256 weights)
int ck_get_block_q8_k_size(void)
Get Q8_K block size in bytes.
void ck_test_dequant_q6_k(const void *src, float *dst, int n)
Dequantize Q6_K data to FP32.
void ck_test_rope(float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
RoPE (Rotary Position Embedding)
int ck_get_qk5_1(void)
Get QK5_1 (elements per Q5_1 block)
void ck_test_gemv_q5_0_q8_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols)
Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
void ck_test_dequant_q4_0(const void *src, float *dst, int n)
Dequantize Q4_0 data to FP32.
void ck_test_softmax(const float *input, float *output, int n)
Softmax (simple, non-causal)
void ck_test_rmsnorm(const float *input, const float *weight, float *output, int n_tokens, int dim, float eps)
RMSNorm.
void ck_test_gemm_q5_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q5_0 GEMM - batched matrix multiply with Q5_0 weights (32-element blocks)
void ck_test_dequant_q4_k(const void *src, float *dst, int n)
Dequantize Q4_K data to FP32.
void ck_test_attention_causal(const float *q, const float *k, const float *v, float *out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim)
Multi-head causal attention for prefill (head-major layout)
void ck_test_gemv_q4_k(const void *weight_q4k, const float *input_f32, float *output, int cols)
Q4_K GEMV - dot product of quantized weights and FP32 input.
void ck_test_gemv_q8_0_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols)
Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
void ck_test_vec_dot_q8_0_q8_0(const void *weight_q8_0, const void *input_q8_0, float *output, int cols)
Direct Q8_0 x Q8_0 dot product (takes pre-quantized Q8_0 input)
void ck_test_swiglu(const float *gate_up, float *output, int n_tokens, int intermediate_dim)
SwiGLU activation.
void ck_test_gemm_q6_k(const void *weight_q6k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q6_K GEMM - batched matrix multiply with Q6_K weights.
void ck_test_vec_dot_q5_0_q8_0(const void *weight_q5_0, const void *input_q8_0, float *output, int cols)
Direct Q5_0 x Q8_0 dot product (takes pre-quantized Q8_0 input)
void ck_test_gemm_q5_1(const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q5_1 GEMM - batched matrix multiply with Q5_1 weights (32-element blocks)
int ck_get_qk_k(void)
Get QK_K (elements per super-block)
void ck_test_gemm_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q8_0 GEMM - batched matrix multiply with Q8_0 weights (32-element blocks)
void ck_test_gemv_q5_k(const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols)
Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks)
void ck_test_gemv_q5_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols)
Q5_0 GEMV - matrix-vector multiply with Q5_0 weights.
void ck_test_gemv_q5_1(const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols)
Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks)
void ck_test_rope_interleaved(float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
RoPE with interleaved format (for llama.cpp compatibility)
void ck_test_dequant_q5_1(const void *src, float *dst, int n)
Dequantize Q5_1 data to FP32.
int ck_get_block_q4_k_size(void)
Get Q4_K block size in bytes.
int ck_get_block_q6_k_size(void)
Get Q6_K block size in bytes.
void ck_test_gemv_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols)
Q8_0 GEMV - matrix-vector multiply with Q8_0 weights.
void ck_test_gemm_q5_k(const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q5_K GEMM - batched matrix multiply with Q5_K weights (256-element super-blocks)
void ck_test_gemv_q6_k(const void *weight_q6k, const float *input_f32, float *output, int cols)
Q6_K GEMV.
void ck_test_outproj_mlp_fused_q5_0(const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const void *w1, const void *w2, float *output, int tokens, int num_heads, int head_dim, int embed_dim, int intermediate, float eps, int w2_is_q6k)
Test mega-fused OutProj + MLP kernel (Q5_0 weights)
int ck_get_block_q5_1_size(void)
Get Q5_1 block size in bytes (24 bytes per 32 weights)