#ifndef CKERNEL_ENGINE_H
#define CKERNEL_ENGINE_H
void (*sgemm)(int M, int N, int K, const float *A, int lda, const float *B, int ldb,
CKDataType
Supported data types in C-Kernel-Engine.
void gemm_nt_q4_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
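A minimal FP32 sketch of these NT GEMM semantics, with Bf standing in for a dequantized view of the quantized weights (an assumption for illustration; the real kernels operate on packed blocks directly):

static void gemm_nt_ref(const float *A, const float *Bf, const float *bias,
                        float *C, int M, int N, int K) {
    /* Row-major layouts assumed: A[M*K], Bf[N*K], C[M*N], bias[N] or NULL. */
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            float acc = bias ? bias[n] : 0.0f;
            for (int k = 0; k < K; k++)
                acc += A[m * K + k] * Bf[n * K + k];  /* dot of A row with B row */
            C[m * N + n] = acc;
        }
    }
}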
void attention_forward_causal_head_major_gqa_bf16(const uint16_t *q, const uint16_t *k, const uint16_t *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_q, float *scratch_k, float *scratch_v)
void dequant_q4_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_0 row (multiple blocks)
void embedding_forward_q6_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void attention_forward_causal_head_major_gqa_exact(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void axpy_f32(float *y, const float *x, float alpha, int n)
In-place AXPY: y += alpha * x.
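Reference semantics, as a sketch (the shipped kernel is SIMD-vectorized):

static void axpy_ref(float *y, const float *x, float alpha, int n) {
    for (int i = 0; i < n; i++)
        y[i] += alpha * x[i];  /* y = y + alpha * x */
}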
void rmsnorm_forward_int8(const int8_t *input, const float *gamma, int8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
void gemm_q6_k(float *Y, const void *W, const float *X, int M, int N, int K)
void ck_gemm_nt_head_major_q8_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q8_0 weights)
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
int ck_flash_attn_choose_tile_k(int D_h)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
void moe_accumulate_expert_f32(float *output, const float *expert_output, float routing_weight, int hidden_dim)
Accumulate expert output: output += routing_weight * expert_output.
void swiglu_forward_exact(const float *input, float *output, int tokens, int dim)
void rmsnorm_backward_bf16(const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *rstd_cache, uint16_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
void gemv_q5_1(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q5_1 weights based on CPU features.
void gemm_naive_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void gemv_fused_q8_0_bias_dispatch(float *y, const void *W, const float *x, const float *bias, int M, int K)
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void gemm_bias_silu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void backward_causal_softmax_head_major_bf16(uint16_t *d_scores, const uint16_t *weights, int num_heads, int num_tokens, int aligned_context_window, float *scratch_d_scores, float *scratch_weights)
void add_scaled_forward_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, float alpha, size_t n)
void gemm_swiglu_fused(const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K)
void ck_set_num_threads(int num_threads)
void swiglu_backward(const float *input, const float *d_output, float *d_input, int tokens, int dim)
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
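The decode path reduces to a numerically stable online softmax over the cached keys and values. A minimal scalar sketch for one head and one query token, assuming contiguous k[T_k][D_h] and v[T_k][D_h] (layout and naming are assumptions, not the kernel's internals):

#include <math.h>

static void flash_decode_one_head(float *out, const float *q,
                                  const float *k, const float *v,
                                  int T_k, int D_h, float scale) {
    float m = -INFINITY;  /* running max of scores */
    float l = 0.0f;       /* running softmax denominator */
    for (int d = 0; d < D_h; d++) out[d] = 0.0f;
    for (int t = 0; t < T_k; t++) {
        float s = 0.0f;
        for (int d = 0; d < D_h; d++) s += q[d] * k[t * D_h + d];
        s *= scale;
        float m_new = s > m ? s : m;
        float corr  = expf(m - m_new);  /* rescales the old accumulator */
        float p     = expf(s - m_new);
        for (int d = 0; d < D_h; d++)
            out[d] = out[d] * corr + p * v[t * D_h + d];
        l = l * corr + p;
        m = m_new;
    }
    for (int d = 0; d < D_h; d++) out[d] /= l;  /* normalize once at the end */
}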
void gelu_backward_exact(const float *input, const float *d_output, float *d_input, size_t n)
void ck_gemm_nt_head_major_q5_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q5_0 weights, auto-dispatch)
void axpy_zero_f32(float *y, const float *x, float alpha, int n)
Zero output then accumulate: y = 0; y += alpha * x.
void fused_mlp_swiglu_prefill(const float *x, const float *W_gate, const float *W_up, const float *W_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
Fused MLP (Gate + Up + SwiGLU + Down) for prefill.
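Per token this computes out = W_down(silu(x @ W_gate) * (x @ W_up)). A scalar sketch under an assumed [out_dim][in_dim] weight layout (the fused kernel tiles and vectorizes this):

#include <math.h>

static float silu(float x) { return x / (1.0f + expf(-x)); }

static void mlp_swiglu_token_ref(const float *x, const float *W_gate,
                                 const float *W_up, const float *W_down,
                                 float *out, float *h /* [intermediate] */,
                                 int hidden, int intermediate) {
    for (int j = 0; j < intermediate; j++) {
        float g = 0.0f, u = 0.0f;
        for (int i = 0; i < hidden; i++) {
            g += W_gate[j * hidden + i] * x[i];
            u += W_up[j * hidden + i] * x[i];
        }
        h[j] = silu(g) * u;  /* gated intermediate activation */
    }
    for (int i = 0; i < hidden; i++) {
        float acc = 0.0f;
        for (int j = 0; j < intermediate; j++)
            acc += W_down[i * intermediate + j] * h[j];
        out[i] = acc;  /* down projection back to hidden */
    }
}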
void gelu_exact_inplace(float *data, size_t n)
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void add_inplace_bf16(uint16_t *a, const uint16_t *b, size_t n)
void attention_backward_causal_head_major_gqa_bf16(const uint16_t *d_output, float *d_x, const uint16_t *q, const uint16_t *k, const uint16_t *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_d_output, float *scratch_q, float *scratch_k, float *scratch_v)
void gemm_nn_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void fused_mlp_swiglu_decode_tiled(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void kv_cache_repack_head_major_inplace(float *buf, int num_heads, int tokens, int cache_capacity, int aligned_head_dim)
void fused_mlp_swiglu_decode(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
void gemm_bias_relu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void layernorm_naive_serial_matched_precision(const float *input, const float *gamma, const float *beta, float *output, float *mean_cache, float *rstd_cache, int tokens, int d_model, float eps)
void rope_precompute_cache(float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base)
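A plausible fill for the RoPE caches, one cos/sin entry per (position, dimension pair); the [pos][head_dim/2] layout is an assumption about this header's convention:

#include <math.h>

static void rope_cache_ref(float *cos_cache, float *sin_cache,
                           int max_seq_len, int head_dim, float base) {
    int half = head_dim / 2;
    for (int pos = 0; pos < max_seq_len; pos++) {
        for (int i = 0; i < half; i++) {
            float theta = pos * powf(base, -2.0f * i / head_dim);
            cos_cache[pos * half + i] = cosf(theta);
            sin_cache[pos * half + i] = sinf(theta);
        }
    }
}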
void gemv_q6_k(float *y, const void *W, const float *x, int M, int K)
void topk_batched_f32(const float *scores, int num_tokens, int n_experts, int k, int *indices, float *weights)
Batched top-K selection for multiple tokens.
void backward_causal_softmax_head_major(float *d_scores, const float *weights, int num_heads, int num_tokens, int aligned_context_window)
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim, int aligned_intermediate_dim)
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.
void attention_forward_causal_head_major(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void gemv_q8_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q8_0 weights based on CPU features.
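One plausible shape for such a dispatcher, sketched with hypothetical variant names (gemv_q8_0_avx512/_avx2/_scalar are illustrations, not the library's actual internals); __builtin_cpu_supports is a GCC/Clang x86 builtin:

void gemv_q8_0_avx512(float *y, const void *W, const float *x, int M, int K); /* hypothetical */
void gemv_q8_0_avx2(float *y, const void *W, const float *x, int M, int K);   /* hypothetical */
void gemv_q8_0_scalar(float *y, const void *W, const float *x, int M, int K); /* hypothetical */

void gemv_q8_0_dispatch_sketch(float *y, const void *W, const float *x,
                               int M, int K) {
#if defined(__x86_64__) && (defined(__GNUC__) || defined(__clang__))
    if (__builtin_cpu_supports("avx512f")) { gemv_q8_0_avx512(y, W, x, M, K); return; }
    if (__builtin_cpu_supports("avx2"))    { gemv_q8_0_avx2(y, W, x, M, K);   return; }
#endif
    gemv_q8_0_scalar(y, W, x, M, K);
}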
void attention_forward_causal_head_major_gqa(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void patch2im(const float *d_patches, float *d_image, int C, int H, int W, int P)
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
void swiglu_forward_bf16(const uint16_t *input, uint16_t *output, int tokens, int dim)
void rope_backward_bf16(const uint16_t *d_out, uint16_t *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_d_out, float *scratch_d_x)
void rope_backward_qk(const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void relu_backward(const float *input, const float *d_output, float *d_input, size_t n)
void mlp_token_parallel_bf16(const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
void attention_forward_causal_head_major_gqa_flash_strided_sliding(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)
void geglu_backward_fp32(const float *x, const float *d_out, float *d_x, int tokens, int dim)
void embedding_forward_q4_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void gemm_nt_q4_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q4_1 weights: C = A @ B^T.
void dequant_q5_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_0 row (multiple blocks)
void add_forward_f32(const float *a, const float *b, float *y, size_t n)
void rmsnorm_forward_bf16(const uint16_t *input, const float *gamma, uint16_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void gemv_q5_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q5_0 weights and Q8_0 input.
void add_inplace_f32(float *a, const float *b, size_t n)
void attention_forward_decode_head_major_gqa_flash_sliding(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)
void relu_forward_inplace_bf16(uint16_t *data, size_t n)
void gemv_q5_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q5_0 weights based on CPU features.
void im2patch_bf16(const uint16_t *image, uint16_t *patches, int C, int H, int W, int P)
void attention_forward_causal_head_major_exact(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void gemm_microkernel(const float *A, const float *B, float *C, int M, int N, int K, int B_transposed)
void mlp_token_parallel(const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
void softmax_cross_entropy_loss_bf16(const uint16_t *logits, const int32_t *targets, int tokens, int vocab_size, uint16_t *d_logits, float *loss_out, float *scratch_logits, float *scratch_d_logits)
void rope_backward_qk_bf16(const uint16_t *d_q_out, const uint16_t *d_k_out, uint16_t *d_q, uint16_t *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_dq_out, float *scratch_dq, float *scratch_dk_out, float *scratch_dk)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void weighted_sum_f32(float *y, const float **vectors, const float *weights, int k, int n)
Weighted sum of k vectors: y = sum_i(weights[i] * vectors[i])
void rmsnorm_backward_int4(const uint8_t *d_output, const uint8_t *input, const float *gamma, const float *rstd_cache, uint8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
CKMathBackend ckernel_backend_native(void)
void embedding_forward_bf16(const int32_t *token_ids, int token_count, int vocab_size, const uint16_t *token_embeddings, const uint16_t *pos_embeddings, uint16_t *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void gelu_backward_fast_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
void causal_softmax_head_major_exact(float *scores, int num_heads, int num_tokens, int aligned_context_window)
void gemm_q6_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)
GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K.
void gemv_q4_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
void gemm_nt_q5_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void rope_backward_inplace(float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void layernorm_naive_serial(const float *input, const float *gamma, const float *beta, float *output, float *mean_cache, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void fc1_backward_kernel(const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
void gemm_nt_q6_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void swiglu_backward_exact(const float *input, const float *d_output, float *d_input, int tokens, int dim)
void swiglu_backward_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, int tokens, int dim)
void mlp_token_parallel_bf16_fp32act(const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_input_f, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
void gelu_backward_exact_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
void gemv_q6_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
GEMV: y = W @ x where W is Q6_K and x is Q8_K.
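Typical call pattern: quantize the activation vector once with quantize_row_q8_k, then reuse it across rows. Allocating K floats is a deliberate over-allocation, since packed Q8_K data is smaller than the FP32 input (real code would size from the library's block structs):

#include <stdlib.h>

void gemv_q6_k_q8_k_example(float *y, const void *W, const float *x,
                            int M, int K) {
    void *x_q8 = malloc((size_t)K * sizeof(float));  /* safe upper bound */
    quantize_row_q8_k(x, x_q8, K);     /* FP32 -> Q8_K, once per vector */
    gemv_q6_k_q8_k(y, W, x_q8, M, K);  /* y = W @ x */
    free(x_q8);
}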
void embedding_backward_bf16(const int32_t *token_ids, int token_count, const uint16_t *d_output, uint16_t *d_token_embeddings, uint16_t *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K)
void embedding_backward(const int32_t *token_ids, int token_count, const float *d_output, float *d_token_embeddings, float *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
int ck_get_physical_cores(void)
void mlp_token_parallel_exact(const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
void geglu_forward_fp32(const float *x, float *out, int tokens, int dim)
void relu_forward_bf16(const uint16_t *input, uint16_t *output, size_t n)
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
void gemm_tn_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q8_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void relu_forward(const float *input, float *output, size_t n)
void im2patch(const float *image, float *patches, int C, int H, int W, int P)
void fused_rmsnorm_qkv_prefill(const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill.
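The normalization step applied to each token before the Q/K/V projections, as a sketch of the math only (the fused kernel keeps the normalized token in its scratch buffer rather than materializing a full x_norm matrix):

#include <math.h>

static void rmsnorm_token_ref(const float *x, const float *gamma,
                              float *x_norm, int d, float eps) {
    float ss = 0.0f;
    for (int i = 0; i < d; i++) ss += x[i] * x[i];
    float rstd = 1.0f / sqrtf(ss / d + eps);       /* 1 / RMS(x) */
    for (int i = 0; i < d; i++)
        x_norm[i] = x[i] * rstd * gamma[i];        /* scale by learned gain */
}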
void quantize_batch_q8_k(const float *x, void *y, int num_rows, int k)
Batch quantize FP32 to Q8_K format (row-major output)
void vec_dot_q6_k_q8_k(int n, float *s, const void *vx, const void *vy)
Q6_K x Q8_K dot product (single row)
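A row-major GEMV decomposes into repeated calls of this single-row dot product; ck_q6_k_row_bytes is a hypothetical per-row stride helper used only for illustration:

#include <stddef.h>

size_t ck_q6_k_row_bytes(int k);  /* hypothetical: packed bytes per K-element row */

static void gemv_via_vec_dot(float *y, const void *W, const void *x_q8,
                             int M, int K) {
    const unsigned char *w = (const unsigned char *)W;
    for (int m = 0; m < M; m++)
        vec_dot_q6_k_q8_k(K, &y[m], w + (size_t)m * ck_q6_k_row_bytes(K), x_q8);
}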
void fc2_backward_kernel(const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)
void quantize_row_q8_k(const float *x, void *y, int k)
void axpy_2d_f32(float *Y, const float *X, float alpha, int num_tokens, int dim, int y_stride, int x_stride)
Batched AXPY for 2D tensors: Y[t,:] += alpha * X[t,:].
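Reference semantics as a sketch, taking the strides as element counts between consecutive token rows (an assumption):

static void axpy_2d_ref(float *Y, const float *X, float alpha,
                        int num_tokens, int dim, int y_stride, int x_stride) {
    for (int t = 0; t < num_tokens; t++)
        for (int d = 0; d < dim; d++)
            Y[t * y_stride + d] += alpha * X[t * x_stride + d];
}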
void ck_set_strict_parity(int enabled)
void rmsnorm_backward_int8(const int8_t *d_output, const int8_t *input, const float *gamma, const float *rstd_cache, int8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
void layernorm_backward_kernel(const float *d_output, const float *input, const float *gamma, const float *mean, const float *rstd, float *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim)
void sigmoid_forward_bf16(const uint16_t *input, uint16_t *output, size_t n, float *scratch_input, float *scratch_output)
void gemm_blocked_serial_bf16(const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
void sigmoid_backward_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
void gemm_nn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void causal_softmax_head_major(float *scores, int num_heads, int num_tokens, int aligned_context_window)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
int argmax_f32(const float *scores, int n)
Find index of maximum value.
void unfused_rmsnorm_qkv_prefill(const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *x_norm, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps)
Unfused version for benchmarking comparison.
void add_forward_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, size_t n)
void add_forward_2d_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, int tokens, int dim, int aligned_dim)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void gemm_avx512_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q4_K weights based on available SIMD.
void gelu_fast_inplace_bf16(uint16_t *data, size_t n, float *scratch)
void dequant_q8_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q8_0 row (multiple blocks)
void rope_forward_bf16(uint16_t *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch)
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
void layernorm_backward_kernel_bf16(const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *mean, const float *rstd, uint16_t *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
void gemv_q6_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
Parallel reference GEMV for Q6_K × Q8_K.
float sigmoid_scalar(float x)
void attention_forward_causal_head_major_gqa_flash(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
void ck_attention_flash_decode_wrapper(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
Wrapper to call TRUE flash attention from orchestration layer.
void sigmoid_backward(const float *input, const float *d_output, float *d_input, size_t n)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)
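A sketch of the per-block math, assuming the common 32-element block with one scale per block (the library's actual block struct may differ, e.g. an FP16 scale):

#include <math.h>
#include <stdint.h>

#define QB 32
typedef struct { float d; int8_t qs[QB]; } block_q8_0_sketch;  /* assumed layout */

static void quantize_row_q8_0_ref(const float *x, block_q8_0_sketch *y, int k) {
    for (int b = 0; b < k / QB; b++) {
        float amax = 0.0f;
        for (int i = 0; i < QB; i++) {
            float a = fabsf(x[b * QB + i]);
            if (a > amax) amax = a;
        }
        float d = amax / 127.0f;            /* scale so the max maps to +/-127 */
        float inv = d != 0.0f ? 1.0f / d : 0.0f;
        y[b].d = d;
        for (int i = 0; i < QB; i++)
            y[b].qs[i] = (int8_t)lroundf(x[b * QB + i] * inv);
    }
}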
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
size_t fused_rmsnorm_qkv_scratch_size(int hidden)
Get scratch buffer size for fused_rmsnorm_qkv_prefill.
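The scratch protocol used throughout this header: query the size, allocate once, reuse across calls. Dimensions and eps below are illustrative:

#include <stdlib.h>

void qkv_prefill_example(const float *x, const float *gamma,
                         const float *Wq, const float *Wk, const float *Wv,
                         float *Q, float *K, float *V,
                         int seq_len, int hidden, int q_dim, int kv_dim) {
    float *scratch = (float *)malloc(fused_rmsnorm_qkv_scratch_size(hidden));
    fused_rmsnorm_qkv_prefill(x, gamma, Wq, Wk, Wv, Q, K, V,
                              seq_len, hidden, q_dim, kv_dim, 1e-5f, scratch);
    free(scratch);
}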
void gemv_q8_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q8_0 weights and Q8_0 input.
void gemm_nt_q8_0_q8_0(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
NT GEMM: C = A @ B^T where A is Q8_0-quantized activations and B is Q8_0 weights, with optional bias.
void gemm_fine_grained_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void relu_forward_inplace(float *data, size_t n)
void gelu_fast_inplace(float *data, size_t n)
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q5_0_parallel(float *y, const void *W, const float *x, int M, int K, int ith, int nth)
Parallel reference GEMV for Q5_0 × FP32.
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void geglu_forward_bf16(const uint16_t *x, uint16_t *out, int tokens, int dim, float *scratch)
void gemm_q4_k(float *Y, const void *W, const float *X, int M, int N, int K)
Auto-dispatch GEMM based on available SIMD.
void gelu_backward_scalar(const float *input, const float *d_output, float *d_input, size_t n)
void rmsnorm_forward_int4(const uint8_t *input, const float *gamma, uint8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
void gemm_q4_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)
void gemm_microkernel_blocked(const float *A, const float *B, float *C, int M, int N, int K)
void dequant_q5_1_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_1 row (multiple blocks)
void fused_mlp_swiglu_decode_v2(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
void dequant_q4_1_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_1 row (multiple blocks)
void gemv_fused_q5_0_bias_dispatch(float *y, const void *W, const float *x, const float *bias, int M, int K)
void fused_mlp_swiglu_prefill_bias(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *B_gate, const float *B_up, const float *B_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
Fused MLP (Gate + Up + SwiGLU + Down) for prefill with biases.
void layernorm_forward_rolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps)
void embedding_forward_q8_0(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
int ck_flash_attn_fast_exp_kind(void)
void attention_backward_causal_head_major(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void embedding_forward(const int32_t *token_ids, int token_count, int vocab_size, const float *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void layernorm_forward_unrolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps, float *scratch_input, float *scratch_output)
void gemm_bias_gelu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void fused_mlp_swiglu_prefill_w1w2_quant(const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch)
Quantized fused MLP for prefill (W1=gate+up, W2=down)
void rope_forward_qk_bf16(uint16_t *q, uint16_t *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_q, float *scratch_k)
void attention_backward_causal_head_major_gqa(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void topk_f32(const float *scores, int n, int k, int *indices, float *values)
Find top-K indices and values from a score vector.
void gemm_microkernel_blocked_bt(const float *A, const float *B, float *C, int M, int N, int K)
void dequant_q6_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q6_K row (multiple blocks)
void gemv_q6_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q6_K × Q8_K.
void gemm_tn_blocked(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void topk_softmax_f32(const float *scores, int n, int k, int *indices, float *weights)
Find top-K indices with softmax-normalized weights.
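A sketch of the selection plus normalization (O(n*k) selection kept simple for clarity; the shipped kernel may use a partial heap):

#include <math.h>

static void topk_softmax_ref(const float *scores, int n, int k,
                             int *indices, float *weights) {
    for (int j = 0; j < k; j++) {          /* pick the j-th largest score */
        int best = -1;
        for (int i = 0; i < n; i++) {
            int taken = 0;
            for (int p = 0; p < j; p++) if (indices[p] == i) taken = 1;
            if (!taken && (best < 0 || scores[i] > scores[best])) best = i;
        }
        indices[j] = best;
    }
    float m = scores[indices[0]], sum = 0.0f;  /* indices[0] holds the max */
    for (int j = 0; j < k; j++) {
        weights[j] = expf(scores[indices[j]] - m);
        sum += weights[j];
    }
    for (int j = 0; j < k; j++) weights[j] /= sum;  /* softmax over the top-k */
}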
void gemv_q5_0_parallel_simd(float *y, const void *W, const float *x, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q5_0 × FP32 with prefetching.
void kv_cache_store(float *__restrict kv_cache_k, float *__restrict kv_cache_v, const float *__restrict k, const float *__restrict v, int layer, int pos, int num_kv_heads, int head_dim, int max_seq_len)
void gelu_backward_fast(const float *input, const float *d_output, float *d_input, size_t n)
void softmax_cross_entropy_loss(const float *logits, const int32_t *targets, int tokens, int vocab_size, float *d_logits, float *loss_out)
void causal_softmax_head_major_bf16(uint16_t *scores, int num_heads, int num_tokens, int aligned_context_window, float *scratch)
int ck_get_num_threads(void)
void patch2im_bf16(const uint16_t *d_patches, uint16_t *d_image, int C, int H, int W, int P)
int ck_strict_parity_enabled(void)
void add_backward_bf16(const uint16_t *d_y, uint16_t *d_a, uint16_t *d_b, size_t n)
void gemm_nt_q5_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void rope_backward(const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void dequant_q4_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_K row (multiple blocks)
void scal_copy_f32(float *y, const float *x, float alpha, int n)
Scaled copy: y = alpha * x.
void gemm_microkernel_packed(const float *A, const float *B, float *C, int M, int N, int K)
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void rope_forward_strided(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens)
size_t fused_mlp_swiglu_scratch_size(int intermediate)
Get scratch buffer size for fused_mlp_swiglu_prefill.
void sigmoid_forward(const float *input, float *output, size_t n)
void layernorm_forward_rolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
void relu_backward_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n)
void quantize_batch_q8_0(const float *x, void *y, int num_rows, int k)
Batch quantize FP32 to Q8_0 format (row-major output)
void gemv_q4_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q4_0 weights based on CPU features.
void gemv_q4_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
void layernorm_forward_unrolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)
void gemm_nn_blocked(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void add_scaled_inplace_bf16(uint16_t *a, const uint16_t *b, float alpha, size_t n)
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
void gemm_tn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Quantization block structures for weight-only quantization.
Mega-Fused Attention Kernel.