/**
 * @brief Dequantize a Q4_K row (multiple blocks) into FP32.
 *
 * @param src        Read-only quantized source data.
 * @param dst        Destination buffer receiving the dequantized floats.
 * @param n_elements Element count for the row (presumably the number of
 *                   floats written to dst -- confirm against the definition).
 */
extern void dequant_q4_k_row(const void *src, float *dst, size_t n_elements);
/**
 * @brief Dequantize a Q6_K row (multiple blocks) into FP32.
 *
 * @param src        Read-only quantized source data.
 * @param dst        Destination buffer receiving the dequantized floats.
 * @param n_elements Element count for the row (presumably the number of
 *                   floats written to dst -- confirm against the definition).
 */
extern void dequant_q6_k_row(const void *src, float *dst, size_t n_elements);
/**
 * @brief Dequantize a Q4_0 row (multiple blocks) into FP32.
 *
 * @param src        Read-only quantized source data.
 * @param dst        Destination buffer receiving the dequantized floats.
 * @param n_elements Element count for the row (presumably the number of
 *                   floats written to dst -- confirm against the definition).
 */
extern void dequant_q4_0_row(const void *src, float *dst, size_t n_elements);
/**
 * @brief Dequantize a Q5_1 row (multiple blocks) into FP32.
 *
 * @param src        Read-only quantized source data.
 * @param dst        Destination buffer receiving the dequantized floats.
 * @param n_elements Element count for the row (presumably the number of
 *                   floats written to dst -- confirm against the definition).
 */
extern void dequant_q5_1_row(const void *src, float *dst, size_t n_elements);
/**
 * @brief GEMV kernel taking Q4_K weights and a pre-quantized Q8_K input
 *        (formats per the function and parameter names).
 *
 * NOTE(review): presumably computes y = W @ x like the Q6_K variant
 * declared in this file -- confirm against the kernel definition.
 *
 * @param y    Output vector (FP32).
 * @param W    Read-only quantized weight matrix.
 * @param x_q8 Read-only quantized input vector.
 * @param M    Matrix dimension (presumably rows of W / length of y -- confirm).
 * @param K    Matrix dimension (presumably columns of W / length of x -- confirm).
 */
extern void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8,
                           int M, int K);
28 float *
C,
int M,
int N,
int K);
/**
 * @brief GEMV: y = W @ x where W is Q6_K and x is Q8_K.
 *
 * @param y    Output vector (FP32).
 * @param W    Read-only Q6_K weight matrix.
 * @param x_q8 Read-only Q8_K input vector.
 * @param M    Matrix dimension (presumably rows of W / length of y -- confirm).
 * @param K    Matrix dimension (presumably columns of W / length of x -- confirm).
 */
extern void gemv_q6_k_q8_k(float *y, const void *W, const void *x_q8,
                           int M, int K);
33 float *
C,
int M,
int N,
int K);
/**
 * @brief GEMM on two Q8_0 operands with an optional bias.
 *
 * NOTE(review): the "nt" kernels documented elsewhere in this file compute
 * C = A @ B^T; presumably the same convention applies here -- confirm.
 *
 * @param A_q8 Read-only Q8_0 operand A.
 * @param B_q8 Read-only Q8_0 operand B.
 * @param bias Optional bias vector.
 * @param C    Output matrix (FP32).
 * @param M    GEMM dimension (presumably rows of A and C -- confirm).
 * @param N    GEMM dimension (presumably rows of B / columns of C -- confirm).
 * @param K    GEMM dimension (presumably the shared inner dimension -- confirm).
 */
extern void gemm_nt_q8_0_q8_0(const void *A_q8, const void *B_q8,
                              const float *bias, float *C,
                              int M, int N, int K);
/**
 * @brief Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
 *
 * @param A_q8 Read-only Q8_0 activations.
 * @param B_q5 Read-only Q5_0 weights.
 * @param bias Bias vector (presumably optional, as for the Q8_0 x Q8_0
 *             variant -- confirm).
 * @param C    Output matrix (FP32).
 * @param M    GEMM dimension (presumably the token/batch count -- confirm).
 * @param N    GEMM dimension (presumably weight rows / output columns -- confirm).
 * @param K    GEMM dimension (presumably the shared inner dimension -- confirm).
 */
extern void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5,
                              const float *bias, float *C,
                              int M, int N, int K);
/**
 * @brief GEMV kernel taking Q5_K weights (per the name) and an FP32 input.
 *
 * @param y Output vector (FP32).
 * @param W Read-only quantized weight matrix.
 * @param x Read-only FP32 input vector.
 * @param M Matrix dimension (presumably rows of W / length of y -- confirm).
 * @param K Matrix dimension (presumably columns of W / length of x -- confirm).
 */
extern void gemv_q5_k(float *y, const void *W, const float *x, int M, int K);
/**
 * @brief GEMM kernel taking FP32 activations and Q5_K weights (per the name).
 *
 * The parity wrapper in this file calls it as
 * (input_f32, weight_q5_k, NULL, output, n_tokens, rows, cols).
 *
 * @param A    Read-only FP32 activations.
 * @param B    Read-only quantized weights.
 * @param bias Bias vector; the caller in this file passes NULL.
 * @param C    Output matrix (FP32).
 * @param M    First GEMM dimension (n_tokens at the call site in this file).
 * @param N    Second GEMM dimension (weight rows at the call site in this file).
 * @param K    Third GEMM dimension (weight cols at the call site in this file).
 */
extern void gemm_nt_q5_k(const float *A, const void *B, const float *bias,
                         float *C, int M, int N, int K);
/**
 * @brief Auto-dispatch GEMV taking Q5_1 weights (per the name) and an FP32
 *        input.
 *
 * @param y Output vector (FP32).
 * @param W Read-only quantized weight matrix.
 * @param x Read-only FP32 input vector.
 * @param M Matrix dimension (presumably rows of W / length of y -- confirm).
 * @param K Matrix dimension (presumably columns of W / length of x -- confirm).
 */
extern void gemv_q5_1(float *y, const void *W, const float *x, int M, int K);
/**
 * @brief GEMM with transposed Q5_1 weights: C = A @ B^T.
 *
 * The parity wrapper in this file calls it as
 * (input_f32, weight_q5_1, NULL, output, n_tokens, rows, cols).
 *
 * @param A    Read-only FP32 activations.
 * @param B    Read-only Q5_1 weights.
 * @param bias Bias vector; the caller in this file passes NULL.
 * @param C    Output matrix (FP32).
 * @param M    First GEMM dimension (n_tokens at the call site in this file).
 * @param N    Second GEMM dimension (weight rows at the call site in this file).
 * @param K    Third GEMM dimension (weight cols at the call site in this file).
 */
extern void gemm_nt_q5_1(const float *A, const void *B, const float *bias,
                         float *C, int M, int N, int K);
/**
 * @brief Auto-dispatch GEMV for Q5_0 weights based on CPU features.
 *
 * @param y Output vector (FP32).
 * @param W Read-only Q5_0 weight matrix.
 * @param x Read-only FP32 input vector.
 * @param M Matrix dimension (presumably rows of W / length of y -- confirm).
 * @param K Matrix dimension (presumably columns of W / length of x -- confirm).
 */
extern void gemv_q5_0(float *y, const void *W, const float *x, int M, int K);
/**
 * @brief Auto-dispatch GEMV for Q8_0 weights based on CPU features.
 *
 * @param y Output vector (FP32).
 * @param W Read-only Q8_0 weight matrix.
 * @param x Read-only FP32 input vector.
 * @param M Matrix dimension (presumably rows of W / length of y -- confirm).
 * @param K Matrix dimension (presumably columns of W / length of x -- confirm).
 */
extern void gemv_q8_0(float *y, const void *W, const float *x, int M, int K);
/**
 * @brief Matrix-vector multiply with Q5_0 weights and Q8_0 input.
 *
 * @param y    Output vector (FP32).
 * @param W    Read-only Q5_0 weight matrix.
 * @param x_q8 Read-only Q8_0 input vector.
 * @param M    Matrix dimension (presumably rows of W / length of y -- confirm).
 * @param K    Matrix dimension (presumably columns of W / length of x -- confirm).
 */
extern void gemv_q5_0_q8_0(float *y, const void *W, const void *x_q8,
                           int M, int K);
/**
 * @brief Matrix-vector multiply with Q8_0 weights and Q8_0 input.
 *
 * @param y    Output vector (FP32).
 * @param W    Read-only Q8_0 weight matrix.
 * @param x_q8 Read-only Q8_0 input vector.
 * @param M    Matrix dimension (presumably rows of W / length of y -- confirm).
 * @param K    Matrix dimension (presumably columns of W / length of x -- confirm).
 */
extern void gemv_q8_0_q8_0(float *y, const void *W, const void *x_q8,
                           int M, int K);
70 float *output,
float *rstd_cache,
71 int tokens,
int d_model,
int aligned_embed_dim,
float eps);
75 const float *cos_cache,
const float *sin_cache,
76 int num_heads,
int num_kv_heads,
int num_tokens,
77 int head_dim,
int aligned_head_dim,
int pos_offset);
79 int max_seq_len,
int head_dim,
float base);
/**
 * @brief Forward pass of the SwiGLU activation kernel (per the name).
 *
 * NOTE(review): the layout of input (e.g. how gate and up halves are
 * packed) is not visible from this declaration -- confirm against the
 * kernel definition before relying on it.
 *
 * @param input  Read-only FP32 input buffer.
 * @param output FP32 output buffer.
 * @param tokens Token count.
 * @param dim    Feature dimension (presumably the per-token output width --
 *               confirm).
 */
extern void swiglu_forward(const float *input, float *output,
                           int tokens, int dim);
86 const float *q,
const float *k,
const float *v,
float *output,
87 int num_heads,
int num_kv_heads,
int num_tokens,
88 int head_dim,
int aligned_head_dim,
int kv_stride_tokens);
92 const float *q,
const float *k,
const float *v,
float *output,
93 int num_heads,
int num_kv_heads,
int num_tokens,
94 int head_dim,
int aligned_head_dim,
int kv_stride_tokens,
98 const float *q_token,
const float *k_cache,
const float *v_cache,
99 float *out_token,
int num_heads,
int num_kv_heads,
100 int kv_tokens,
int cache_capacity,
int head_dim,
101 int aligned_head_dim,
int sliding_window);
106 int n_tokens,
int dim);
146 const float *input_f32,
168 const float *input_f32,
174 float *weight_f32 = (
float *)malloc(cols *
sizeof(
float));
184 for (
int i = 0; i < cols; i++) {
185 sum += (double)weight_f32[i] * (
double)input_f32[i];
187 *output = (float)sum;
193 const float *input_f32,
207 for (
int r = 0; r < rows; r++) output[r] = 0.0f;
221 const float *input_f32,
235 for (
int r = 0; r < rows; r++) output[r] = 0.0f;
249 const float *input_f32,
261 for (
int r = 0; r < rows; r++) output[r] = 0.0f;
275 const float *input_f32,
287 for (
int r = 0; r < rows; r++) output[r] = 0.0f;
302 const float *input_f32,
325 for (
int r = 0; r < rows; r++) {
334 const float *input_f32,
342 for (
int r = 0; r < rows; r++) {
365 const void *input_q8_0,
381 const void *input_q8_0,
393 const float *input_f32,
395 int rows,
int cols,
int n_tokens)
398 int n_blocks_per_row = cols /
CK_QK_K;
401 memset(output, 0, n_tokens * rows *
sizeof(
float));
406 for (
int t = 0; t < n_tokens; t++) {
408 q8_data + t * n_blocks_per_row, cols);
426 const float *input_f32,
428 int rows,
int cols,
int n_tokens)
431 int n_blocks_per_row = cols /
CK_QK_K;
434 memset(output, 0, n_tokens * rows *
sizeof(
float));
439 for (
int t = 0; t < n_tokens; t++) {
441 q8_data + t * n_blocks_per_row, cols);
459 const float *input_f32,
461 int rows,
int cols,
int n_tokens)
464 int n_blocks_per_row = cols /
CK_QK8_0;
467 memset(output, 0, n_tokens * rows *
sizeof(
float));
472 for (
int t = 0; t < n_tokens; t++) {
474 q8_data + t * n_blocks_per_row, cols);
492 const float *input_f32,
494 int rows,
int cols,
int n_tokens)
497 int n_blocks_per_row = cols /
CK_QK8_0;
500 memset(output, 0, n_tokens * rows *
sizeof(
float));
505 for (
int t = 0; t < n_tokens; t++) {
507 q8_data + t * n_blocks_per_row, cols);
526 const float *input_f32,
528 int rows,
int cols,
int n_tokens)
533 gemm_nt_q5_k(input_f32, weight_q5_k, NULL, output, n_tokens, rows, cols);
543 const float *input_f32,
545 int rows,
int cols,
int n_tokens)
550 gemm_nt_q5_1(input_f32, weight_q5_1, NULL, output, n_tokens, rows, cols);
560 int n_tokens,
int dim,
float eps)
568 int n_tokens,
int n_heads,
int n_heads_kv,
int head_dim,
569 int pos_offset,
float theta)
572 int half_dim = head_dim / 2;
573 int max_seq = pos_offset + n_tokens;
575 float *cos_cache = (
float *)malloc(max_seq * half_dim *
sizeof(
float));
576 float *sin_cache = (
float *)malloc(max_seq * half_dim *
sizeof(
float));
577 if (!cos_cache || !sin_cache) {
588 float *q_reorder = (
float *)malloc(n_heads * n_tokens * head_dim *
sizeof(
float));
589 float *k_reorder = (
float *)malloc(n_heads_kv * n_tokens * head_dim *
sizeof(
float));
591 if (q_reorder && k_reorder) {
593 for (
int t = 0; t < n_tokens; t++) {
594 for (
int h = 0; h < n_heads; h++) {
595 for (
int d = 0; d < head_dim; d++) {
596 q_reorder[h * n_tokens * head_dim + t * head_dim + d] =
597 q[t * n_heads * head_dim + h * head_dim + d];
603 for (
int t = 0; t < n_tokens; t++) {
604 for (
int h = 0; h < n_heads_kv; h++) {
605 for (
int d = 0; d < head_dim; d++) {
606 k_reorder[h * n_tokens * head_dim + t * head_dim + d] =
607 k[t * n_heads_kv * head_dim + h * head_dim + d];
614 cos_cache, sin_cache,
615 n_heads, n_heads_kv, n_tokens,
616 head_dim, head_dim, pos_offset);
619 for (
int t = 0; t < n_tokens; t++) {
620 for (
int h = 0; h < n_heads; h++) {
621 for (
int d = 0; d < head_dim; d++) {
622 q[t * n_heads * head_dim + h * head_dim + d] =
623 q_reorder[h * n_tokens * head_dim + t * head_dim + d];
628 for (
int t = 0; t < n_tokens; t++) {
629 for (
int h = 0; h < n_heads_kv; h++) {
630 for (
int d = 0; d < head_dim; d++) {
631 k[t * n_heads_kv * head_dim + h * head_dim + d] =
632 k_reorder[h * n_tokens * head_dim + t * head_dim + d];
645 int n_tokens,
int n_heads,
int n_heads_kv,
int head_dim,
646 int pos_offset,
float theta)
654 float *inv_freq = (
float *)malloc((head_dim / 2) *
sizeof(float));
655 if (!inv_freq)
return;
657 for (
int i = 0; i < head_dim / 2; i++) {
658 inv_freq[i] = 1.0f / powf(theta, (
float)(2 * i) / head_dim);
662 for (
int t = 0; t < n_tokens; t++) {
663 int pos = pos_offset + t;
664 for (
int h = 0; h < n_heads; h++) {
665 float *qh = q + t * n_heads * head_dim + h * head_dim;
667 for (
int i = 0; i < head_dim / 2; i++) {
668 float freq = pos * inv_freq[i];
669 float cos_val = cosf(freq);
670 float sin_val = sinf(freq);
673 float x0 = qh[i * 2];
674 float x1 = qh[i * 2 + 1];
675 qh[i * 2] = x0 * cos_val - x1 * sin_val;
676 qh[i * 2 + 1] = x0 * sin_val + x1 * cos_val;
682 for (
int t = 0; t < n_tokens; t++) {
683 int pos = pos_offset + t;
684 for (
int h = 0; h < n_heads_kv; h++) {
685 float *kh = k + t * n_heads_kv * head_dim + h * head_dim;
687 for (
int i = 0; i < head_dim / 2; i++) {
688 float freq = pos * inv_freq[i];
689 float cos_val = cosf(freq);
690 float sin_val = sinf(freq);
692 float x0 = kh[i * 2];
693 float x1 = kh[i * 2 + 1];
694 kh[i * 2] = x0 * cos_val - x1 * sin_val;
695 kh[i * 2 + 1] = x0 * sin_val + x1 * cos_val;
705 int n_tokens,
int intermediate_dim)
713 float max_val = input[0];
714 for (
int i = 1; i < n; i++) {
715 if (input[i] > max_val) max_val = input[i];
720 for (
int i = 0; i < n; i++) {
721 output[i] = expf(input[i] - max_val);
726 float inv_sum = 1.0f / sum;
727 for (
int i = 0; i < n; i++) {
728 output[i] *= inv_sum;
752 num_heads, num_kv_heads, tokens,
782 num_heads, num_kv_heads, tokens,
795 const float *k_cache,
796 const float *v_cache,
806 q_token, k_cache, v_cache, out_token,
807 num_heads, num_kv_heads,
808 kv_tokens, cache_capacity, head_dim, head_dim,
848 const float *attn_out,
849 const float *residual,
850 const float *ln2_gamma,
851 const void *wo,
const float *bo,
int wo_dt,
852 const void *w1,
const float *b1,
int w1_dt,
853 const void *w2,
const float *b2,
int w2_dt,
856 int aligned_embed_dim,
858 int aligned_head_dim,
859 int intermediate_dim,
860 int aligned_intermediate_dim,
866 int aligned_embed_dim,
868 int aligned_head_dim,
869 int aligned_intermediate_dim);
895 const float *attn_out,
896 const float *residual,
897 const float *ln2_gamma,
911 const int CK_DT_Q5_0_VAL = 11;
912 const int CK_DT_Q4_K_VAL = 7;
913 const int CK_DT_Q6_K_VAL = 8;
916 int aligned_embed_dim = embed_dim;
917 int aligned_head_dim = head_dim;
918 int aligned_intermediate = intermediate;
921 if ((intermediate % 256) != 0) {
922 aligned_intermediate = ((intermediate + 255) / 256) * 256;
927 tokens, aligned_embed_dim, num_heads, aligned_head_dim, aligned_intermediate);
929 void *scratch = malloc(scratch_size);
940 wo, NULL, CK_DT_Q5_0_VAL,
941 w1, NULL, CK_DT_Q5_0_VAL,
942 w2, NULL, w2_is_q6k ? CK_DT_Q6_K_VAL : CK_DT_Q4_K_VAL,
949 aligned_intermediate,
void ck_test_quantize_q8_k(const float *src, void *dst, int n)
Quantize FP32 to Q8_K (for activations)
void dequant_q4_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_0 row (multiple blocks)
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void ck_test_gemm_q4_k(const void *weight_q4k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Q4_K GEMM - batched matrix multiply with quantized weights.
int ck_get_block_q5_k_size(void)
Get Q5_K block size in bytes (176 bytes per 256 weights)
int ck_get_block_q8_k_size(void)
Get Q8_K block size in bytes.
void ck_test_dequant_q6_k(const void *src, float *dst, int n)
Dequantize Q6_K data to FP32.
void ck_test_rope(float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
RoPE (Rotary Position Embedding)
void gemv_q5_1(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q5_1 weights.
void swiglu_forward(const float *input, float *output, int tokens, int dim)
int ck_get_qk5_1(void)
Get QK5_1 (elements per Q5_1 block)
void ck_test_gemv_q5_0_q8_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols)
Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
void rope_precompute_cache(float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base)
void gemv_q8_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q8_0 weights based on CPU features.
void ck_test_dequant_q4_0(const void *src, float *dst, int n)
Dequantize Q4_0 data to FP32.
void attention_forward_causal_head_major_gqa_flash_strided_sliding(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)
void ck_test_softmax(const float *input, float *output, int n)
Softmax (simple, non-causal)
void ck_test_rmsnorm(const float *input, const float *weight, float *output, int n_tokens, int dim, float eps)
RMSNorm.
void gemv_q5_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q5_0 weights and Q8_0 input.
void attention_forward_decode_head_major_gqa_flash_sliding(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)
void gemv_q5_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q5_0 weights based on CPU features.
void ck_test_gemm_q5_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Test Q5_0 x Q8_0 GEMM (batch matrix multiply)
void ck_test_dequant_q4_k(const void *src, float *dst, int n)
Dequantize Q4_K data to FP32.
void ck_test_attention_causal(const float *q, const float *k, const float *v, float *out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim)
Multi-head causal attention for prefill (head-major layout)
void ck_test_gemv_q4_k(const void *weight_q4k, const float *input_f32, float *output, int cols)
Q4_K GEMV - dot product of quantized weights and FP32 input.
void mega_fused_outproj_mlp_prefill(float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, int wo_dt, const void *w1, const float *b1, int w1_dt, const void *w2, const float *b2, int w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch)
void quantize_row_q8_k(const float *x, void *vy, int k)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void ck_test_gemv_q8_0_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols)
Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
void ck_test_vec_dot_q8_0_q8_0(const void *weight_q8_0, const void *input_q8_0, float *output, int cols)
Direct Q8_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)
void ck_test_geglu(const float *x, float *out, int n_tokens, int dim)
Test GeGLU activation.
void ck_test_swiglu(const float *gate_up, float *output, int n_tokens, int intermediate_dim)
SwiGLU activation.
void ck_test_gemm_q6_k(const void *weight_q6k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Test Q6_K x Q8_K GEMM (batch matrix multiply)
void ck_test_vec_dot_q5_0_q8_0(const void *weight_q5_0, const void *input_q8_0, float *output, int cols)
Direct Q5_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)
void ck_test_gemm_q5_1(const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols, int n_tokens)
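Test Q5_1 GEMM (batch matrix multiply). The wrapper passes the FP32 input straight to gemm_nt_q5_1; any Q8_0 quantization would happen inside that kernel.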
void gemv_q6_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
GEMV: y = W @ x where W is Q6_K and x is Q8_K.
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K)
void geglu_forward_fp32(const float *x, float *out, int tokens, int dim)
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
int ck_get_qk_k(void)
Get QK_K (elements per super-block)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void ck_test_gemm_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols, int n_tokens)
Test Q8_0 x Q8_0 GEMM (batch matrix multiply)
void ck_test_gemv_q5_k(const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols)
Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks)
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
void gemm_nt_q8_0_q8_0(const void *A_q8, const void *B_q8, const float *bias, float *C, int M, int N, int K)
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
void ck_test_geglu_backward(const float *x, const float *d_out, float *d_x, int n_tokens, int dim)
Test GeGLU backward.
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void gemv_q8_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q8_0 weights and Q8_0 input.
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
void ck_test_attention_sliding_window(const float *q, const float *k, const float *v, float *out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim, int sliding_window)
Test sliding-window attention (prefill)
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
void ck_test_gemv_q5_0(const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols)
Q5_0 GEMV - matrix-vector multiply with Q5_0 weights.
void dequant_q5_1_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_1 row (multiple blocks)
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
size_t mega_fused_outproj_mlp_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
Get scratch buffer size for mega_fused_outproj_mlp_prefill.
void ck_test_gemv_q5_1(const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols)
Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks)
void dequant_q6_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q6_K row (multiple blocks)
void quantize_row_q8_0(const float *x, void *vy, int k)
Quantize FP32 to Q8_0 format (scalar reference)
void ck_test_attention_decode_sliding(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int sliding_window)
Test sliding-window attention (decode mode)
void ck_test_rope_interleaved(float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
RoPE with interleaved format (for llama.cpp compatibility)
void ck_test_dequant_q5_1(const void *src, float *dst, int n)
Dequantize Q5_1 data to FP32.
int ck_get_block_q4_k_size(void)
Get Q4_K block size in bytes.
int ck_get_block_q6_k_size(void)
Get Q6_K block size in bytes.
void ck_test_gemv_q8_0(const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols)
Q8_0 GEMV - matrix-vector multiply with Q8_0 weights.
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.
void gemm_nt_q5_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void dequant_q4_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_K row (multiple blocks)
void ck_test_gemm_q5_k(const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
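Test Q5_K GEMM (batch matrix multiply). The wrapper passes the FP32 input straight to gemm_nt_q5_k; any Q8_K quantization would happen inside that kernel.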
void ck_test_gemv_q6_k(const void *weight_q6k, const float *input_f32, float *output, int cols)
Q6_K GEMV.
void ck_test_outproj_mlp_fused_q5_0(const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const void *w1, const void *w2, float *output, int tokens, int num_heads, int head_dim, int embed_dim, int intermediate, float eps, int w2_is_q6k)
Test mega-fused OutProj + MLP kernel (Q5_0 weights)
void geglu_backward_fp32(const float *x, const float *d_out, float *d_x, int n_tokens, int dim)
int ck_get_block_q5_1_size(void)
Get Q5_1 block size in bytes (24 bytes per 32 weights)
C-Kernel-Engine Parity Testing API.
Quantization block structures for weight-only quantization.