Fused attention decode kernel (legacy v6/v6.5) More...
Go to the source code of this file.
Functions | |
| void | ck_attention_project_head_major_decode_token (const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| static void | ck_attention_project_head_major_decode_token_residual (const float *attn_token, const float *wo, const float *bo, const float *residual_in, float *proj_out, float *residual_out, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| static float | ck_dot_f32 (const float *a, const float *b, int len) |
| void | ck_layer_forward_rmsnorm_swiglu_decode_fused_attn (const CKLayerForwardParams *p, int token_index, int cache_capacity) |
| static void | ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl (const CKLayerForwardParams *p, int token_index, int cache_capacity, int fuse_mlp) |
| void | ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp (const CKLayerForwardParams *p, int token_index, int cache_capacity) |
| void | ck_qkv_project_head_major_token (const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim) |
Fused attention decode kernel (legacy v6/v6.5)
After making changes, validate with: make test && make llamacpp-parity-full
LEGACY: This file is from v6/v6.5 and kept for backward compatibility.
Definition in file attention_decode_fused.c.
| void ck_attention_project_head_major_decode_token | ( | const float * | attn_token, |
| const float * | wo, | ||
| const float * | bo, | ||
| float * | out_token, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim | ||
| ) |
Definition at line 115 of file attention_decode_fused.c.
References ck_dot_f32().
Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_quant().
|
static |
Definition at line 145 of file attention_decode_fused.c.
References ck_dot_f32().
Referenced by ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().
|
inline static
Definition at line 41 of file attention_decode_fused.c.
Referenced by ck_attention_project_head_major_decode_token(), and ck_attention_project_head_major_decode_token_residual().
| void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn | ( | const CKLayerForwardParams * | p, |
| int | token_index, | ||
| int | cache_capacity | ||
| ) |
Definition at line 343 of file attention_decode_fused.c.
References ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().
|
static |
Definition at line 181 of file attention_decode_fused.c.
References CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, attention_forward_decode_head_major_gqa_regular(), CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_project_head_major_decode_token_residual(), ck_mlp_swiglu_forward(), ck_mlp_swiglu_forward_fully_fused_token(), ck_qkv_project_head_major_token(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, kv_cache_write_head_major(), CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.
Referenced by ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp().
| void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp | ( | const CKLayerForwardParams * | p, |
| int | token_index, | ||
| int | cache_capacity | ||
| ) |
Definition at line 353 of file attention_decode_fused.c.
References ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().
| void ck_qkv_project_head_major_token | ( | const float * | input_row, |
| const float * | wq, | ||
| const float * | bq, | ||
| const float * | wk, | ||
| const float * | bk, | ||
| const float * | wv, | ||
| const float * | bv, | ||
| float * | q_token, | ||
| float * | k_token, | ||
| float * | v_token, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | aligned_head_dim | ||
| ) |
Definition at line 78 of file attention_decode_fused.c.
References gemm_blocked_serial().
Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().