14 #ifndef CK_PARITY_API_H
15 #define CK_PARITY_API_H
/* Byte sizes of the serialized quantization blocks used by the parity API.
 * K-quants (Q4_K/Q5_K/Q6_K/Q8_K) pack 256 weights per super-block;
 * Q5_1 packs 32 weights per block (Q4_0 presumably 32 as well — matches
 * llama.cpp layouts; confirm against the block struct definitions). */
#define CK_BLOCK_Q4_K_SIZE 144  /* Q4_K: 144 bytes per 256-weight super-block */
#define CK_BLOCK_Q6_K_SIZE 210  /* Q6_K: 210 bytes per 256-weight super-block */
#define CK_BLOCK_Q8_K_SIZE 292  /* Q8_K: 292 bytes per 256-weight super-block */
#define CK_BLOCK_Q4_0_SIZE 18   /* Q4_0: 18 bytes per block */
#define CK_BLOCK_Q5_K_SIZE 176  /* Q5_K: 176 bytes per 256-weight super-block */
#define CK_BLOCK_Q5_1_SIZE 24   /* Q5_1: 24 bytes per 32-weight block */
94 const float *input_f32,
102 const float *input_f32,
116 const float *input_f32,
130 const float *input_f32,
147 const float *input_f32,
164 const float *input_f32,
180 const float *input_f32,
196 const float *input_f32,
216 const void *input_q8_0,
229 const void *input_q8_0,
250 const float *input_f32,
252 int rows,
int cols,
int n_tokens);
267 const float *input_f32,
269 int rows,
int cols,
int n_tokens);
284 const float *input_f32,
286 int rows,
int cols,
int n_tokens);
301 const float *input_f32,
303 int rows,
int cols,
int n_tokens);
319 const float *input_f32,
321 int rows,
int cols,
int n_tokens);
337 const float *input_f32,
339 int rows,
int cols,
int n_tokens);
361 int n_tokens,
int dim,
float eps);
382 int n_tokens,
int n_heads,
int n_heads_kv,
int head_dim,
383 int pos_offset,
float theta);
391 int n_tokens,
int n_heads,
int n_heads_kv,
int head_dim,
392 int pos_offset,
float theta);
408 int n_tokens,
int intermediate_dim);
488 const float *attn_out,
489 const float *residual,
490 const float *ln2_gamma,
void ck_test_quantize_q8_k(const float *src, void *dst, int n)
Quantize FP32 to Q8_K (for activations)
void ck_test_gemm_q4_k(const void *weight_q4k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q4_K GEMM - batched matrix multiply with quantized weights.
int ck_get_block_q5_k_size(void)
Get Q5_K block size in bytes (176 bytes per 256 weights)
int ck_get_block_q8_k_size(void)
Get Q8_K block size in bytes (292 bytes per 256 weights).
void ck_test_dequant_q6_k(const void *src, float *dst, int n)
Dequantize Q6_K data to FP32.
void ck_test_rope(float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
RoPE (Rotary Position Embedding)
int ck_get_qk5_1(void)
Get QK5_1 (elements per Q5_1 block)
void ck_test_gemv_q5_0_q8_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols)
Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
void ck_test_dequant_q4_0(const void *src, float *dst, int n)
Dequantize Q4_0 data to FP32.
void ck_test_softmax(const float *input, float *output, int n)
Softmax (simple, non-causal)
void ck_test_rmsnorm(const float *input, const float *weight, float *output, int n_tokens, int dim, float eps)
RMSNorm.
void ck_test_gemm_q5_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q5_0 GEMM - batched matrix multiply with Q5_0 weights (32-element blocks)
void ck_test_dequant_q4_k(const void *src, float *dst, int n)
Dequantize Q4_K data to FP32.
void ck_test_attention_causal(const float *q, const float *k, const float *v, float *out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim)
Multi-head causal attention for prefill (head-major layout)
void ck_test_gemv_q4_k(const void *weight_q4k, const float *input_f32, float *output, int cols)
Q4_K GEMV - dot product of quantized weights and FP32 input.
void ck_test_gemv_q8_0_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols)
Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
void ck_test_vec_dot_q8_0_q8_0(const void *weight_q8_0, const void *input_q8_0, float *output, int cols)
Direct Q8_0 x Q8_0 dot product (takes pre-quantized Q8_0 input)
void ck_test_swiglu(const float *gate_up, float *output, int n_tokens, int intermediate_dim)
SwiGLU activation.
void ck_test_gemm_q6_k(const void *weight_q6k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q6_K GEMM - batched matrix multiply with Q6_K weights.
void ck_test_vec_dot_q5_0_q8_0(const void *weight_q5_0, const void *input_q8_0, float *output, int cols)
Direct Q5_0 x Q8_0 dot product (takes pre-quantized Q8_0 input)
void ck_test_gemm_q5_1(const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q5_1 GEMM - batched matrix multiply with Q5_1 weights (32-element blocks)
int ck_get_qk_k(void)
Get QK_K (elements per super-block)
void ck_test_gemm_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q8_0 GEMM - batched matrix multiply with Q8_0 weights (32-element blocks)
void ck_test_gemv_q5_k(const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols)
Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks)
void ck_test_gemv_q5_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols)
Q5_0 GEMV - matrix-vector multiply with Q5_0 weights.
void ck_test_gemv_q5_1(const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols)
Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks)
void ck_test_rope_interleaved(float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
RoPE with interleaved format (for llama.cpp compatibility)
void ck_test_dequant_q5_1(const void *src, float *dst, int n)
Dequantize Q5_1 data to FP32.
int ck_get_block_q4_k_size(void)
Get Q4_K block size in bytes (144 bytes per 256 weights).
int ck_get_block_q6_k_size(void)
Get Q6_K block size in bytes (210 bytes per 256 weights).
void ck_test_gemv_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols)
Q8_0 GEMV - matrix-vector multiply with Q8_0 weights.
void ck_test_gemm_q5_k(const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q5_K GEMM - batched matrix multiply with Q5_K weights (256-element super-blocks)
void ck_test_gemv_q6_k(const void *weight_q6k, const float *input_f32, float *output, int cols)
Q6_K GEMV.
void ck_test_outproj_mlp_fused_q5_0(const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const void *w1, const void *w2, float *output, int tokens, int num_heads, int head_dim, int embed_dim, int intermediate, float eps, int w2_is_q6k)
Test mega-fused OutProj + MLP kernel (Q5_0 weights)
int ck_get_block_q5_1_size(void)
Get Q5_1 block size in bytes (24 bytes per 32 weights)