#ifndef CKERNEL_ORCHESTRATION_H
#define CKERNEL_ORCHESTRATION_H
int aligned_embed_dim);
const float *wq,
const float *bq,
const float *wk,
const float *bk,
const float *wv,
const float *bv,
float *q,
float *k,
float *v,
int kv_stride_tokens,
int aligned_embed_dim,
int aligned_head_dim);
const float *wq,
const float *bq,
const float *wk,
const float *bk,
const float *wv,
const float *bv,
int aligned_embed_dim,
int aligned_head_dim);
int aligned_embed_dim,
int aligned_head_dim);
int aligned_embed_dim,
int aligned_head_dim);
int aligned_embed_dim,
int aligned_intermediate_dim);
int aligned_embed_dim,
int aligned_intermediate_dim);
int aligned_embed_dim,
int aligned_intermediate_dim);
int aligned_embed_dim);
const float *attn_out,
int aligned_embed_dim,
int aligned_head_dim);
int aligned_embed_dim,
int aligned_head_dim,
CKDataType
Enumeration of the data types supported by C-Kernel-Engine.
void ck_layer_backward_rmsnorm_swiglu(const CKLayerBackwardParams *p)
void ck_mlp_swiglu_forward_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *swiglu_row, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_attention_project_head_major(const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_mlp_swiglu_forward(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_residual_add_backward(const float *d_out, float *d_a, float *d_b, int tokens, int aligned_embed_dim)
void ck_layer_forward_rmsnorm_swiglu_quant(const CKLayerForwardParamsQ4K *p)
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_layer_forward_rmsnorm_swiglu_decode(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_layer_forward_rmsnorm_swiglu_q4_k(const CKLayerForwardParamsQ4K *p)
void ck_qkv_project_head_major(const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_gemm_nt_quant(const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
void ck_qkv_project_head_major_backward(const float *d_q, const float *d_k, const float *d_v, const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *d_input, float *d_wq, float *d_bq, float *d_wk, float *d_bk, float *d_wv, float *d_bv, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
void ck_qkv_project_head_major_token(const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_layer_forward_rmsnorm_swiglu(const CKLayerForwardParams *p)
void ck_layer_forward_rmsnorm_swiglu_decode_fused(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_layer_forward_rmsnorm_swiglu_decode_q4_k(const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
void ck_layer_forward_rmsnorm_swiglu_ref(const CKLayerForwardParams *p)
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_attention_project_head_major_backward(const float *d_out, const float *attn_out, const float *wo, float *d_attn_out, float *d_wo, float *d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_attention_project_head_major_decode_token(const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_layer_forward_rmsnorm_swiglu_decode_quant(const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
int aligned_context_window
int aligned_intermediate_dim
int aligned_intermediate_dim
int aligned_context_window
int aligned_context_window
int aligned_intermediate_dim