#ifndef MEGA_FUSED_ATTENTION_H
#define MEGA_FUSED_ATTENTION_H

/* Tile sizes used by the fused attention kernels; both defaults can be
   overridden at build time. */
#ifndef MEGA_FUSE_Q_TILE
#define MEGA_FUSE_Q_TILE 64
#endif

#ifndef MEGA_FUSE_KV_TILE
#define MEGA_FUSE_KV_TILE 64
#endif
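Because both tile macros are wrapped in #ifndef, a build can tune them without editing the header. A minimal sketch (the values 32 and 128 are arbitrary examples, not recommendations):

    /* Override the default tile shape before the header is first
     * included (equivalently: -DMEGA_FUSE_Q_TILE=32 on the compiler
     * command line). */
    #define MEGA_FUSE_Q_TILE  32
    #define MEGA_FUSE_KV_TILE 128
    #include "mega_fused_attention.h"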
CKDataType
Supported data types in C-Kernel-Engine; a value of this type accompanies each const void * weight pointer (the *_dt parameters below) to tag the weight buffer's element format.
size_t mega_fused_attention_prefill_q8_0_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Get scratch buffer size for mega_fused_attention_prefill_q8_0.
size_t mega_fused_attention_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Get scratch buffer size for mega_fused_attention_prefill.
void mega_fused_attention_prefill(float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
Mega-fused attention for prefill mode (multiple tokens).
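A hedged calling sketch (not from the header): one prefill pass over a prompt. The weight, bias, RoPE-table, and KV-cache buffers are taken as parameters and assumed already populated; CK_DT_F32 is an assumed CKDataType enumerator name for fp32 weights, and passing input as residual assumes a standard pre-norm block.

    #include <stdlib.h>
    #include "mega_fused_attention.h"

    static void prefill_block(float *output, const float *input,
                              const float *ln1_gamma,
                              const float *wq, const float *bq,
                              const float *wk, const float *bk,
                              const float *wv, const float *bv,
                              const float *wo, const float *bo,
                              float *kv_cache_k, float *kv_cache_v,
                              const float *rope_cos, const float *rope_sin,
                              int tokens, int cache_capacity,
                              int embed_dim, int aligned_embed_dim,
                              int num_heads, int num_kv_heads,
                              int head_dim, int aligned_head_dim)
    {
        /* Query the kernel's scratch requirement, then run the pass. */
        size_t n = mega_fused_attention_prefill_scratch_size(
            tokens, aligned_embed_dim, num_heads, aligned_head_dim);
        void *scratch = malloc(n);
        mega_fused_attention_prefill(
            output, input,
            input,               /* residual: assumed pre-norm skip path */
            ln1_gamma,
            wq, bq, CK_DT_F32,   /* CK_DT_F32: assumed enumerator name */
            wk, bk, CK_DT_F32,
            wv, bv, CK_DT_F32,
            wo, bo, CK_DT_F32,
            kv_cache_k, kv_cache_v, rope_cos, rope_sin,
            /*start_pos=*/0, tokens, cache_capacity,
            embed_dim, aligned_embed_dim,
            num_heads, num_kv_heads, head_dim, aligned_head_dim,
            /*eps=*/1e-5f, scratch);
        free(scratch);
    }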
void mega_fuse_report_stats(int hidden, int num_layers, int seq_len)
Report memory savings from mega-fusion.
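A one-line sketch; the arguments are illustrative dimensions for a 7B-class decoder, not values the header prescribes:

    mega_fuse_report_stats(/*hidden=*/4096, /*num_layers=*/32, /*seq_len=*/2048);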
void mega_fuse_rmsnorm_qkv_rope(float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, const float *rope_cos, const float *rope_sin, int pos, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps)
Phase 2: Fused RMSNorm + QKV + RoPE.
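A minimal sketch of one single-token projection step. The output sizes (num_heads*head_dim floats for Q, num_kv_heads*head_dim for K and V) are assumptions inferred from the parameter names, and all shapes are illustrative; x, gamma, W_qkv, b_qkv, rope_cos, and rope_sin are assumed already populated.

    /* Hypothetical shapes: hidden=256, 8 query heads, 2 KV heads,
     * head_dim=32, token position 17. */
    float q[8 * 32], k[2 * 32], v[2 * 32];
    mega_fuse_rmsnorm_qkv_rope(q, k, v, x, gamma, W_qkv, b_qkv,
                               rope_cos, rope_sin, /*pos=*/17,
                               /*hidden=*/256, /*num_heads=*/8,
                               /*num_kv_heads=*/2, /*head_dim=*/32,
                               /*max_seq=*/2048, /*eps=*/1e-6f);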
void mega_fuse_get_optimal_tiles(int *q_tile, int *kv_tile, int head_dim)
Get optimal tile sizes for current CPU.
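A short sketch querying the runtime-tuned tile shape (the MEGA_FUSE_*_TILE macros above serve as the compile-time fallbacks); head_dim=128 is an example value:

    int q_tile, kv_tile;
    mega_fuse_get_optimal_tiles(&q_tile, &kv_tile, /*head_dim=*/128);
    /* q_tile and kv_tile now hold sizes tuned for the host CPU. */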
void mega_fused_attention_decode(float *output, const float *input, const float *residual, const float *ln1_gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, const float *wo, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps)
Mega-fused attention for decode mode (single token).
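A hedged decode-loop sketch. It assumes the same buffers and dimension variables as the prefill sketch above are in scope, plus hypothetical x/y activation buffers and prompt_len/total_len counters; passing x as residual again assumes a standard pre-norm block. Note the signature takes no scratch pointer, so no scratch buffer is needed in decode mode.

    for (int pos = prompt_len; pos < total_len; ++pos) {
        /* x holds the current token's hidden state; y receives the
         * attention-block output for this step. */
        mega_fused_attention_decode(
            y, x, x /* residual */, ln1_gamma,
            wq, bq, wk, bk, wv, bv, wo, bo,
            kv_cache_k, kv_cache_v, rope_cos, rope_sin,
            pos, embed_dim, aligned_embed_dim,
            num_heads, num_kv_heads, head_dim, aligned_head_dim,
            cache_capacity, /*eps=*/1e-5f);
        /* ... MLP block, sampling, and the next token's embedding ... */
    }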
void mega_fused_outproj_mlp_prefill(float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, CKDataType wo_dt, const void *w1, const float *b1, CKDataType w1_dt, const void *w2, const float *b2, CKDataType w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch)
Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill.
size_t mega_fused_outproj_mlp_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
Get scratch buffer size for mega_fused_outproj_mlp_prefill.
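The scratch-size helper pairs with the kernel the same way as in the attention prefill sketch; CK_DT_F32 is again an assumed enumerator name, and the surrounding buffers and dimensions are assumed in scope:

    size_t n = mega_fused_outproj_mlp_prefill_scratch_size(
        tokens, aligned_embed_dim, num_heads, aligned_head_dim,
        aligned_intermediate_dim);
    void *scratch = malloc(n);
    mega_fused_outproj_mlp_prefill(
        output, attn_out, residual, ln2_gamma,
        wo, bo, CK_DT_F32, w1, b1, CK_DT_F32, w2, b2, CK_DT_F32,
        tokens, embed_dim, aligned_embed_dim,
        num_heads, aligned_head_dim,
        intermediate_dim, aligned_intermediate_dim,
        /*eps=*/1e-5f, scratch);
    free(scratch);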
void mega_fused_attention_prefill_q8_0(float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
Mega-fused prefill attention kernel (Q8_0 out-proj).
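The calling convention matches mega_fused_attention_prefill; per the description, the out-projection weight is Q8_0-quantized. In this sketch, CK_DT_F32 and CK_DT_Q8_0 are assumed enumerator names and wo_q8 stands for a hypothetical Q8_0-format weight buffer:

    size_t n = mega_fused_attention_prefill_q8_0_scratch_size(
        tokens, aligned_embed_dim, num_heads, aligned_head_dim);
    void *scratch = malloc(n);
    mega_fused_attention_prefill_q8_0(
        output, input, input /* residual */, ln1_gamma,
        wq, bq, CK_DT_F32, wk, bk, CK_DT_F32, wv, bv, CK_DT_F32,
        wo_q8, bo, CK_DT_Q8_0,   /* Q8_0-tagged out-projection weight */
        kv_cache_k, kv_cache_v, rope_cos, rope_sin,
        /*start_pos=*/0, tokens, cache_capacity,
        embed_dim, aligned_embed_dim,
        num_heads, num_kv_heads, head_dim, aligned_head_dim,
        /*eps=*/1e-5f, scratch);
    free(scratch);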
void mega_fuse_rmsnorm_qkv(float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps)
Phase 1: Fused RMSNorm + QKV (intermediates in registers).
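Phase 1 takes the same arguments as the RoPE variant above minus the rotary tables, position, and max_seq; a one-call sketch with the same assumed shapes:

    mega_fuse_rmsnorm_qkv(q, k, v, x, gamma, W_qkv, b_qkv,
                          /*hidden=*/256, /*num_heads=*/8,
                          /*num_kv_heads=*/2, /*head_dim=*/32,
                          /*eps=*/1e-6f);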