Go to the source code of this file.
| Data Structures | |
| struct | CKLayerBackwardParams |
| struct | CKLayerForwardParams |
| struct | CKLayerForwardParamsQ4K |
| Functions | |
| void | ck_attention_project_head_major (const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| void | ck_attention_project_head_major_backward (const float *d_out, const float *attn_out, const float *wo, float *d_attn_out, float *d_wo, float *d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| void | ck_attention_project_head_major_decode_token (const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| void | ck_gemm_nt_quant (const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype) |
| void | ck_layer_backward_rmsnorm_swiglu (const CKLayerBackwardParams *p) |
| void | ck_layer_forward_rmsnorm_swiglu (const CKLayerForwardParams *p) |
| void | ck_layer_forward_rmsnorm_swiglu_decode (const CKLayerForwardParams *p, int token_index, int cache_capacity) |
| void | ck_layer_forward_rmsnorm_swiglu_decode_fused (const CKLayerForwardParams *p, int token_index, int cache_capacity) |
| void | ck_layer_forward_rmsnorm_swiglu_decode_fused_attn (const CKLayerForwardParams *p, int token_index, int cache_capacity) |
| void | ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp (const CKLayerForwardParams *p, int token_index, int cache_capacity) |
| void | ck_layer_forward_rmsnorm_swiglu_decode_q4_k (const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity) |
| void | ck_layer_forward_rmsnorm_swiglu_decode_quant (const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity) |
| void | ck_layer_forward_rmsnorm_swiglu_q4_k (const CKLayerForwardParamsQ4K *p) |
| void | ck_layer_forward_rmsnorm_swiglu_quant (const CKLayerForwardParamsQ4K *p) |
| void | ck_layer_forward_rmsnorm_swiglu_ref (const CKLayerForwardParams *p) |
| void | ck_mlp_swiglu_forward (const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim) |
| void | ck_mlp_swiglu_forward_fully_fused_token (const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim) |
| void | ck_mlp_swiglu_forward_fused_token (const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *swiglu_row, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim) |
| void | ck_qkv_project_head_major (const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim) |
| void | ck_qkv_project_head_major_backward (const float *d_q, const float *d_k, const float *d_v, const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *d_input, float *d_wq, float *d_bq, float *d_wk, float *d_bk, float *d_wv, float *d_bv, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads) |
| void | ck_qkv_project_head_major_token (const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim) |
| void | ck_residual_add_backward (const float *d_out, float *d_a, float *d_b, int tokens, int aligned_embed_dim) |
| void | ck_residual_add_token_major (const float *a, const float *b, float *out, int tokens, int aligned_embed_dim) |
This header declares the v6.5 orchestration functions, which are NO LONGER USED. v6.6 replaces this hardcoded orchestration with IR Lower 3 plus code generation.
v6.6 Architecture (REPLACEMENT):
Deprecated functions (NOT used in v6.6):
To remove completely:
Last used: v6.5
Definition in file ckernel_orchestration.h.
| void ck_attention_project_head_major | ( | const float * | attn_out, |
| const float * | wo, | ||
| const float * | bo, | ||
| float * | out, | ||
| float * | scratch, | ||
| int | tokens, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim | ||
| ) |
Definition at line 730 of file ckernel_orchestration.c.
References ck_add_inplace(), and gemm_blocked_serial().
Referenced by ck_attention_project_head_major_quant(), and ck_layer_forward_rmsnorm_swiglu().
| void ck_attention_project_head_major_backward | ( | const float * | d_out, |
| const float * | attn_out, | ||
| const float * | wo, | ||
| float * | d_attn_out, | ||
| float * | d_wo, | ||
| float * | d_bo, | ||
| int | tokens, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim | ||
| ) |
Definition at line 800 of file ckernel_orchestration.c.
References fc2_backward_kernel().
Referenced by ck_layer_backward_rmsnorm_swiglu().
| void ck_attention_project_head_major_decode_token | ( | const float * | attn_token, |
| const float * | wo, | ||
| const float * | bo, | ||
| float * | out_token, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim | ||
| ) |
Definition at line 115 of file attention_decode_fused.c.
References ck_dot_f32().
Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_quant().
| void ck_gemm_nt_quant | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K, | ||
| CKDataType | dtype | ||
| ) |
Definition at line 335 of file ckernel_orchestration.c.
References C, CK_DT_FP32, CK_DT_Q4_0, CK_DT_Q4_1, CK_DT_Q4_K, CK_DT_Q5_0, CK_DT_Q5_1, CK_DT_Q6_K, CK_DT_Q8_0, gemm_blocked_serial(), gemm_nt_q4_0(), gemm_nt_q4_1(), gemm_nt_q4_k(), gemm_nt_q5_0(), gemm_nt_q5_1(), gemm_nt_q6_k(), and gemm_nt_q8_0().
Referenced by ck_attention_project_head_major_quant(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_mlp_swiglu_forward_quant(), ck_qkv_project_head_major_quant(), ck_qkv_project_head_major_token_quant(), and mega_fused_attention_prefill().
| void ck_layer_backward_rmsnorm_swiglu | ( | const CKLayerBackwardParams * | p | ) |
Definition at line 2677 of file ckernel_orchestration.c.
References CKLayerBackwardParams::aligned_context_window, CKLayerBackwardParams::aligned_embed_dim, CKLayerBackwardParams::aligned_head_dim, CKLayerBackwardParams::aligned_intermediate_dim, attention_backward_causal_head_major_gqa(), CKLayerBackwardParams::attn_out, CKLayerBackwardParams::bk, CKLayerBackwardParams::bq, CKLayerBackwardParams::bv, ck_add_inplace(), ck_attention_project_head_major_backward(), ck_qkv_project_head_major_backward(), ck_residual_add_backward(), CKLayerBackwardParams::d_attn_out, CKLayerBackwardParams::d_b1, CKLayerBackwardParams::d_b2, CKLayerBackwardParams::d_bk, CKLayerBackwardParams::d_bo, CKLayerBackwardParams::d_bq, CKLayerBackwardParams::d_bv, CKLayerBackwardParams::d_fc1_out, CKLayerBackwardParams::d_input, CKLayerBackwardParams::d_k, CKLayerBackwardParams::d_ln1_gamma, CKLayerBackwardParams::d_ln1_out, CKLayerBackwardParams::d_ln2_gamma, CKLayerBackwardParams::d_ln2_out, CKLayerBackwardParams::d_mlp_out, CKLayerBackwardParams::d_output, CKLayerBackwardParams::d_proj_tmp, CKLayerBackwardParams::d_q, CKLayerBackwardParams::d_residual1, CKLayerBackwardParams::d_scores, CKLayerBackwardParams::d_swiglu_out, CKLayerBackwardParams::d_v, CKLayerBackwardParams::d_w1, CKLayerBackwardParams::d_w2, CKLayerBackwardParams::d_wk, CKLayerBackwardParams::d_wo, CKLayerBackwardParams::d_wq, CKLayerBackwardParams::d_wv, CKLayerBackwardParams::embed_dim, fc1_backward_kernel(), CKLayerBackwardParams::fc1_out, fc2_backward_kernel(), CKLayerBackwardParams::head_dim, CKLayerBackwardParams::input, CKLayerBackwardParams::k, CKLayerBackwardParams::ln1_gamma, CKLayerBackwardParams::ln1_out, CKLayerBackwardParams::ln1_rstd, CKLayerBackwardParams::ln2_gamma, CKLayerBackwardParams::ln2_out, CKLayerBackwardParams::ln2_rstd, CKLayerBackwardParams::num_heads, CKLayerBackwardParams::num_kv_heads, CKLayerBackwardParams::q, CKLayerBackwardParams::residual1, rmsnorm_backward(), rope_backward_qk(), CKLayerBackwardParams::rope_cos, CKLayerBackwardParams::rope_pos_offset, 
CKLayerBackwardParams::rope_sin, CKLayerBackwardParams::scores, swiglu_backward(), CKLayerBackwardParams::swiglu_out, CKLayerBackwardParams::tokens, CKLayerBackwardParams::v, CKLayerBackwardParams::w1, CKLayerBackwardParams::w2, CKLayerBackwardParams::wk, CKLayerBackwardParams::wo, CKLayerBackwardParams::wq, and CKLayerBackwardParams::wv.
| void ck_layer_forward_rmsnorm_swiglu | ( | const CKLayerForwardParams * | p | ) |
Definition at line 996 of file ckernel_orchestration.c.
References CKLayerForwardParams::aligned_context_window, CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParams::attn_out, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_project_head_major(), ck_mlp_swiglu_forward(), ck_qkv_project_head_major(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_scratch, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::q, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::scores, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::tokens, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.
| void ck_layer_forward_rmsnorm_swiglu_decode | ( | const CKLayerForwardParams * | p, |
| int | token_index, | ||
| int | cache_capacity | ||
| ) |
Definition at line 1289 of file ckernel_orchestration.c.
References CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_flash_decode_wrapper(), ck_attention_project_head_major_decode_token(), ck_mlp_swiglu_forward(), ck_qkv_project_head_major_token(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, kv_cache_write_head_major(), CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.
| void ck_layer_forward_rmsnorm_swiglu_decode_fused | ( | const CKLayerForwardParams * | p, |
| int | token_index, | ||
| int | cache_capacity | ||
| ) |
Definition at line 1449 of file ckernel_orchestration.c.
References CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_flash_decode_wrapper(), ck_attention_project_head_major_decode_token(), ck_mlp_swiglu_forward_fully_fused_token(), ck_qkv_project_head_major_token(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, kv_cache_write_head_major(), CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.
| void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn | ( | const CKLayerForwardParams * | p, |
| int | token_index, | ||
| int | cache_capacity | ||
| ) |
Definition at line 343 of file attention_decode_fused.c.
References ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().
| void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp | ( | const CKLayerForwardParams * | p, |
| int | token_index, | ||
| int | cache_capacity | ||
| ) |
Definition at line 353 of file attention_decode_fused.c.
References ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().
| void ck_layer_forward_rmsnorm_swiglu_decode_q4_k | ( | const CKLayerForwardParamsQ4K * | p, |
| int | token_index, | ||
| int | cache_capacity | ||
| ) |
Definition at line 2117 of file ckernel_orchestration.c.
References CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, attention_forward_decode_head_major_gqa_regular(), CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_flash_decode_wrapper(), ck_debug_check_buffer(), ck_debug_check_q4k_weights(), ck_debug_check_q8k(), ck_mlp_swiglu_forward_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_q8k_activations_enabled(), ck_qkv_project_head_major_token_q4_k(), ck_qkv_project_head_major_token_q4_k_q8_k(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, gemm_nt_q4_k(), gemm_nt_q4_k_q8_k(), CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, kv_cache_write_head_major(), CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_tmp, QK_K, quantize_row_q8_k(), CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wq, and CKLayerForwardParamsQ4K::wv.
| void ck_layer_forward_rmsnorm_swiglu_decode_quant | ( | const CKLayerForwardParamsQ4K * | p, |
| int | token_index, | ||
| int | cache_capacity | ||
| ) |
Definition at line 2512 of file ckernel_orchestration.c.
References CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_flash_decode_wrapper(), ck_attention_project_head_major_decode_token(), CK_DT_FP32, ck_gemm_nt_quant(), ck_mlp_swiglu_forward_quant(), ck_qkv_project_head_major_token_quant(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, kv_cache_write_head_major(), CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_scratch, CKLayerForwardParamsQ4K::proj_tmp, CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w1_dtype, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::w2_dtype, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wk_dtype, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wo_dtype, CKLayerForwardParamsQ4K::wq, CKLayerForwardParamsQ4K::wq_dtype, CKLayerForwardParamsQ4K::wv, and CKLayerForwardParamsQ4K::wv_dtype.
| void ck_layer_forward_rmsnorm_swiglu_q4_k | ( | const CKLayerForwardParamsQ4K * | p | ) |
Definition at line 1910 of file ckernel_orchestration.c.
References CKLayerForwardParamsQ4K::aligned_context_window, CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParamsQ4K::attn_out, CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_project_head_major_q4_k(), ck_attention_project_head_major_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_q8k_activations_enabled(), ck_qkv_project_head_major_q4_k(), ck_qkv_project_head_major_q4_k_q8_k(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_scratch, CKLayerForwardParamsQ4K::proj_tmp, CKLayerForwardParamsQ4K::q, QK_K, CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::scores, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::tokens, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wq, and CKLayerForwardParamsQ4K::wv.
| void ck_layer_forward_rmsnorm_swiglu_quant | ( | const CKLayerForwardParamsQ4K * | p | ) |
Definition at line 2401 of file ckernel_orchestration.c.
References CKLayerForwardParamsQ4K::aligned_context_window, CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParamsQ4K::attn_out, CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_project_head_major_quant(), ck_mlp_swiglu_forward_quant(), ck_qkv_project_head_major_quant(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_scratch, CKLayerForwardParamsQ4K::proj_tmp, CKLayerForwardParamsQ4K::q, CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::scores, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::tokens, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w1_dtype, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::w2_dtype, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wk_dtype, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wo_dtype, CKLayerForwardParamsQ4K::wq, CKLayerForwardParamsQ4K::wq_dtype, CKLayerForwardParamsQ4K::wv, and CKLayerForwardParamsQ4K::wv_dtype.
| void ck_layer_forward_rmsnorm_swiglu_ref | ( | const CKLayerForwardParams * | p | ) |
Definition at line 1104 of file ckernel_orchestration.c.
References CKLayerForwardParams::aligned_context_window, CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParams::attn_out, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_project_head_major_ref(), ck_mlp_swiglu_forward_ref(), ck_qkv_project_head_major_ref(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_scratch, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::q, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::scores, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::tokens, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.
| void ck_mlp_swiglu_forward | ( | const float * | input, |
| const float * | w1, | ||
| const float * | b1, | ||
| const float * | w2, | ||
| const float * | b2, | ||
| float * | fc1_out, | ||
| float * | swiglu_out, | ||
| float * | output, | ||
| int | tokens, | ||
| int | aligned_embed_dim, | ||
| int | aligned_intermediate_dim | ||
| ) |
Definition at line 952 of file ckernel_orchestration.c.
References gemm_blocked_serial(), and swiglu_forward().
Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().
| void ck_mlp_swiglu_forward_fully_fused_token | ( | const float * | input_row, |
| const float * | w1, | ||
| const float * | b1, | ||
| const float * | w2, | ||
| const float * | b2, | ||
| float * | output_row, | ||
| int | aligned_embed_dim, | ||
| int | aligned_intermediate_dim | ||
| ) |
Definition at line 1247 of file ckernel_orchestration.c.
References fused_mlp_swiglu_decode_v2().
Referenced by ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().
| void ck_mlp_swiglu_forward_fused_token | ( | const float * | input_row, |
| const float * | w1, | ||
| const float * | b1, | ||
| const float * | w2, | ||
| const float * | b2, | ||
| float * | swiglu_row, | ||
| float * | output_row, | ||
| int | aligned_embed_dim, | ||
| int | aligned_intermediate_dim | ||
| ) |
Definition at line 1212 of file ckernel_orchestration.c.
References gemm_blocked_serial(), and gemm_swiglu_fused().
| void ck_qkv_project_head_major | ( | const float * | input, |
| const float * | wq, | ||
| const float * | bq, | ||
| const float * | wk, | ||
| const float * | bk, | ||
| const float * | wv, | ||
| const float * | bv, | ||
| float * | q, | ||
| float * | k, | ||
| float * | v, | ||
| int | tokens, | ||
| int | kv_stride_tokens, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | aligned_head_dim | ||
| ) |
Definition at line 168 of file ckernel_orchestration.c.
References gemm_blocked_serial().
Referenced by ck_layer_forward_rmsnorm_swiglu().
| void ck_qkv_project_head_major_backward | ( | const float * | d_q, |
| const float * | d_k, | ||
| const float * | d_v, | ||
| const float * | input, | ||
| const float * | wq, | ||
| const float * | bq, | ||
| const float * | wk, | ||
| const float * | bk, | ||
| const float * | wv, | ||
| const float * | bv, | ||
| float * | d_input, | ||
| float * | d_wq, | ||
| float * | d_bq, | ||
| float * | d_wk, | ||
| float * | d_bk, | ||
| float * | d_wv, | ||
| float * | d_bv, | ||
| float * | scratch, | ||
| int | tokens, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | aligned_head_dim, | ||
| int | num_threads | ||
| ) |
Definition at line 856 of file ckernel_orchestration.c.
References ck_add_inplace(), and fc2_backward_kernel().
Referenced by ck_layer_backward_rmsnorm_swiglu().
| void ck_qkv_project_head_major_token | ( | const float * | input_row, |
| const float * | wq, | ||
| const float * | bq, | ||
| const float * | wk, | ||
| const float * | bk, | ||
| const float * | wv, | ||
| const float * | bv, | ||
| float * | q_token, | ||
| float * | k_token, | ||
| float * | v_token, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | aligned_head_dim | ||
| ) |
Definition at line 78 of file attention_decode_fused.c.
References gemm_blocked_serial().
Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().
| void ck_residual_add_backward | ( | const float * | d_out, |
| float * | d_a, | ||
| float * | d_b, | ||
| int | tokens, | ||
| int | aligned_embed_dim | ||
| ) |
Definition at line 151 of file ckernel_orchestration.c.
Referenced by ck_layer_backward_rmsnorm_swiglu().
| void ck_residual_add_token_major | ( | const float * | a, |
| const float * | b, | ||
| float * | out, | ||
| int | tokens, | ||
| int | aligned_embed_dim | ||
| ) |
Definition at line 139 of file ckernel_orchestration.c.
Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().