Mega-Fused Attention Kernel. More...
Go to the source code of this file.
Macros | |
| #define | MEGA_FUSE_KV_TILE 64 |
| #define | MEGA_FUSE_Q_TILE 64 |
Functions | |
| void | mega_fuse_get_optimal_tiles (int *q_tile, int *kv_tile, int head_dim) |
| Get optimal tile sizes for current CPU. More... | |
| void | mega_fuse_report_stats (int hidden, int num_layers, int seq_len) |
| Report memory savings from mega-fusion. More... | |
| void | mega_fuse_rmsnorm_qkv (float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps) |
| Phase 1: Fused RMSNorm + QKV (intermediates in registers) More... | |
| void | mega_fuse_rmsnorm_qkv_rope (float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, const float *rope_cos, const float *rope_sin, int pos, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps) |
| Phase 2: Fused RMSNorm + QKV + RoPE. More... | |
| void | mega_fused_attention_decode (float *output, const float *input, const float *residual, const float *ln1_gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, const float *wo, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps) |
| Mega-fused attention for decode mode (single token) More... | |
| void | mega_fused_attention_prefill (float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch) |
| Mega-fused attention for prefill mode (multiple tokens) More... | |
| void | mega_fused_attention_prefill_q8_0 (float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch) |
| Mega-fused prefill attention kernel (Q8_0 out-proj) More... | |
| size_t | mega_fused_attention_prefill_q8_0_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| Get scratch buffer size for mega_fused_attention_prefill_q8_0. More... | |
| size_t | mega_fused_attention_prefill_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| Get scratch buffer size for mega_fused_attention_prefill. More... | |
| void | mega_fused_outproj_mlp_prefill (float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, CKDataType wo_dt, const void *w1, const float *b1, CKDataType w1_dt, const void *w2, const float *b2, CKDataType w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch) |
| Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill. More... | |
| size_t | mega_fused_outproj_mlp_prefill_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim) |
| Get scratch buffer size for mega_fused_outproj_mlp_prefill. More... | |
Mega-Fused Attention Kernel.
Holy grail fusion: RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual
All intermediates stay in registers/L1/L2. Single DRAM round-trip.
Memory Reduction — Before: ~32KB intermediates per layer (stack/heap); After: ~8KB total (input + output only); Reduction: 4-5× per layer, ~100× for the full model.
Performance Target: move from memory-bound to compute-bound. Expected speedup: 5-10× for attention-heavy workloads.
Definition in file mega_fused_attention.h.
| #define MEGA_FUSE_KV_TILE 64 |
Definition at line 36 of file mega_fused_attention.h.
| #define MEGA_FUSE_Q_TILE 64 |
Definition at line 32 of file mega_fused_attention.h.
| void mega_fuse_get_optimal_tiles | ( | int * | q_tile, |
| int * | kv_tile, | ||
| int | head_dim | ||
| ) |
Get optimal tile sizes for current CPU.
| void mega_fuse_report_stats | ( | int | hidden, |
| int | num_layers, | ||
| int | seq_len | ||
| ) |
Report memory savings from mega-fusion.
| void mega_fuse_rmsnorm_qkv | ( | float * | q_out, |
| float * | k_out, | ||
| float * | v_out, | ||
| const float * | input, | ||
| const float * | gamma, | ||
| const float * | W_qkv, | ||
| const float * | b_qkv, | ||
| int | hidden, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| float | eps | ||
| ) |
Phase 1: Fused RMSNorm + QKV (intermediates in registers)
Simpler step: Just fuse RMSNorm with QKV projection. Q/K/V stay in stack buffers, not DRAM.
| void mega_fuse_rmsnorm_qkv_rope | ( | float * | q_out, |
| float * | k_out, | ||
| float * | v_out, | ||
| const float * | input, | ||
| const float * | gamma, | ||
| const float * | W_qkv, | ||
| const float * | b_qkv, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | pos, | ||
| int | hidden, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | max_seq, | ||
| float | eps | ||
| ) |
Phase 2: Fused RMSNorm + QKV + RoPE.
Q/K stay in output buffers, RoPE applied in-place.
| void mega_fused_attention_decode | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const float * | ln1_gamma, | ||
| const float * | wq, | ||
| const float * | bq, | ||
| const float * | wk, | ||
| const float * | bk, | ||
| const float * | wv, | ||
| const float * | bv, | ||
| const float * | wo, | ||
| const float * | bo, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | pos, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | cache_capacity, | ||
| float | eps | ||
| ) |
Mega-fused attention for decode mode (single token)
This is the "holy grail" - all operations fused, no intermediates to DRAM.
| output | Output [aligned_embed_dim] (includes residual add) |
| input | Input [aligned_embed_dim] |
| residual | Residual input [aligned_embed_dim] (or NULL) |
| ln1_gamma | RMSNorm gamma [embed_dim] |
| wq | Q weights (FP32) [num_heads * aligned_head_dim * aligned_embed_dim] |
| bq | Q bias [num_heads * aligned_head_dim] (or NULL) |
| wk | K weights (FP32) [num_kv_heads * aligned_head_dim * aligned_embed_dim] |
| bk | K bias [num_kv_heads * aligned_head_dim] (or NULL) |
| wv | V weights (FP32) [num_kv_heads * aligned_head_dim * aligned_embed_dim] |
| bv | V bias [num_kv_heads * aligned_head_dim] (or NULL) |
| wo | Output projection weights (FP32) [aligned_embed_dim * aligned_embed_dim] |
| bo | Output bias [aligned_embed_dim] (or NULL) |
| kv_cache_k | KV cache for K [num_kv_heads * cache_capacity * aligned_head_dim] |
| kv_cache_v | KV cache for V [num_kv_heads * cache_capacity * aligned_head_dim] |
| rope_cos | RoPE cos [max_seq, head_dim/2] |
| rope_sin | RoPE sin [max_seq, head_dim/2] |
| pos | Current position in sequence |
| embed_dim | Model hidden dimension (unpadded) |
| aligned_embed_dim | Aligned hidden dimension |
| num_heads | Number of attention heads |
| num_kv_heads | Number of KV heads (for GQA) |
| head_dim | Head dimension (unpadded) |
| aligned_head_dim | Aligned head dimension |
| cache_capacity | KV cache capacity (stride in tokens) |
| eps | RMSNorm epsilon |
Mega-fused attention for decode mode (single token)
RMSNorm → QKV → RoPE → Flash Attn → OutProj + Residual
Definition at line 589 of file mega_fused_attention_avx.c.
References kv_cache_write_head_major(), mega_fuse_flash_attention_avx(), mega_fuse_output_proj_residual(), mega_fuse_rmsnorm_qkv_avx(), mega_fuse_rope_inplace_avx(), and MEGA_STACK_MAX.
| void mega_fused_attention_prefill | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const float * | ln1_gamma, | ||
| const void * | wq, | ||
| const float * | bq, | ||
| CKDataType | wq_dt, | ||
| const void * | wk, | ||
| const float * | bk, | ||
| CKDataType | wk_dt, | ||
| const void * | wv, | ||
| const float * | bv, | ||
| CKDataType | wv_dt, | ||
| const void * | wo, | ||
| const float * | bo, | ||
| CKDataType | wo_dt, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | start_pos, | ||
| int | tokens, | ||
| int | cache_capacity, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Mega-fused attention for prefill mode (multiple tokens)
| output | Output [tokens, aligned_embed_dim] (includes residual add) |
| input | Input [tokens, aligned_embed_dim] |
| residual | Residual input [tokens, aligned_embed_dim] (or NULL) |
| ln1_gamma | RMSNorm gamma [embed_dim] |
| wq | Q weights [num_heads * aligned_head_dim * aligned_embed_dim] |
| bq | Q bias [num_heads * aligned_head_dim] (or NULL) |
| wq_dt | Q weight dtype (CK_DT_Q5_0/CK_DT_Q8_0/CK_DT_FP32) |
| wk | K weights [num_kv_heads * aligned_head_dim * aligned_embed_dim] |
| bk | K bias [num_kv_heads * aligned_head_dim] (or NULL) |
| wk_dt | K weight dtype (CK_DT_Q5_0/CK_DT_Q8_0/CK_DT_FP32) |
| wv | V weights [num_kv_heads * aligned_head_dim * aligned_embed_dim] |
| bv | V bias [num_kv_heads * aligned_head_dim] (or NULL) |
| wv_dt | V weight dtype (CK_DT_Q5_0/CK_DT_Q8_0/CK_DT_FP32) |
| wo | Output projection weights [num_heads * aligned_embed_dim * aligned_head_dim] |
| bo | Output bias [aligned_embed_dim] (or NULL) |
| wo_dt | Output weight dtype (CK_DT_Q5_0/CK_DT_FP32) |
| kv_cache_k | KV cache for K [num_kv_heads * cache_capacity * aligned_head_dim] |
| kv_cache_v | KV cache for V [num_kv_heads * cache_capacity * aligned_head_dim] |
| rope_cos | RoPE cos [max_seq, head_dim/2] |
| rope_sin | RoPE sin [max_seq, head_dim/2] |
| start_pos | Starting position in KV cache |
| tokens | Number of tokens to process |
| cache_capacity | KV cache capacity (stride in tokens) |
| embed_dim | Model hidden dimension (unpadded) |
| aligned_embed_dim | Aligned hidden dimension |
| num_heads | Number of attention heads |
| num_kv_heads | Number of KV heads |
| head_dim | Head dimension (unpadded) |
| aligned_head_dim | Aligned head dimension |
| eps | RMSNorm epsilon |
| scratch | Scratch buffer from mega_fused_attention_prefill_scratch_size() |
Definition at line 160 of file mega_fused_attention_prefill.c.
References align_up_size(), attention_flash_decode(), attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_FP32, CK_DT_Q5_0, CK_DT_Q8_0, ck_gemm_nt_head_major_q5_0(), ck_gemm_nt_head_major_q8_0(), ck_gemm_nt_quant(), ck_q8_0_outproj_enabled(), ck_residual_add_token_major(), flatten_head_major(), fused_rmsnorm_qkv_prefill_head_major(), fused_rmsnorm_qkv_prefill_head_major_quant(), fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(), out_proj_head_major_q5_0_q8_0(), QK5_0, QK8_0, quantize_attn_out_head_major_q8_0(), and rope_forward_qk_strided().
| void mega_fused_attention_prefill_q8_0 | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const float * | ln1_gamma, | ||
| const void * | wq, | ||
| const float * | bq, | ||
| CKDataType | wq_dt, | ||
| const void * | wk, | ||
| const float * | bk, | ||
| CKDataType | wk_dt, | ||
| const void * | wv, | ||
| const float * | bv, | ||
| CKDataType | wv_dt, | ||
| const void * | wo, | ||
| const float * | bo, | ||
| CKDataType | wo_dt, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | start_pos, | ||
| int | tokens, | ||
| int | cache_capacity, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Mega-fused prefill attention kernel (Q8_0 out-proj)
Same layout and scratch requirements as mega_fused_attention_prefill.
Definition at line 105 of file mega_fused_attention_prefill_q8_0.c.
References align_up_size(), attention_flash_decode(), attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_FP32, CK_DT_Q8_0, ck_residual_add_token_major(), fused_rmsnorm_qkv_prefill_head_major(), fused_rmsnorm_qkv_prefill_head_major_quant(), fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(), out_proj_head_major_q8_0_q8_0(), QK8_0, quantize_attn_out_head_major_q8_0(), and rope_forward_qk_strided().
| size_t mega_fused_attention_prefill_q8_0_scratch_size | ( | int | tokens, |
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim | ||
| ) |
Get scratch buffer size for mega_fused_attention_prefill_q8_0.
Definition at line 84 of file mega_fused_attention_prefill_q8_0.c.
References align_up_size(), and fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size().
| size_t mega_fused_attention_prefill_scratch_size | ( | int | tokens, |
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim | ||
| ) |
Get scratch buffer size for mega_fused_attention_prefill.
Definition at line 139 of file mega_fused_attention_prefill.c.
References align_up_size(), and fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size().
| void mega_fused_outproj_mlp_prefill | ( | float * | output, |
| const float * | attn_out, | ||
| const float * | residual, | ||
| const float * | ln2_gamma, | ||
| const void * | wo, | ||
| const float * | bo, | ||
| CKDataType | wo_dt, | ||
| const void * | w1, | ||
| const float * | b1, | ||
| CKDataType | w1_dt, | ||
| const void * | w2, | ||
| const float * | b2, | ||
| CKDataType | w2_dt, | ||
| int | tokens, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim, | ||
| int | intermediate_dim, | ||
| int | aligned_intermediate_dim, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill.
Uses head-major attention output and quantized out-proj (Q5_0/Q8_0 weights).
Definition at line 184 of file mega_fused_outproj_mlp_prefill.c.
References add_inplace_f32(), align_up_size(), CK_DT_Q4_K, CK_DT_Q5_0, CK_DT_Q6_K, CK_DT_Q8_0, ck_dtype_row_bytes(), fused_mlp_swiglu_prefill_w1w2_quant(), out_proj_head_major_q5_0_q8_0(), out_proj_head_major_q8_0_q8_0(), QK_K, quantize_attn_out_head_major_q8_0(), and rmsnorm_forward().
| size_t mega_fused_outproj_mlp_prefill_scratch_size | ( | int | tokens, |
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim, | ||
| int | aligned_intermediate_dim | ||
| ) |
Get scratch buffer size for mega_fused_outproj_mlp_prefill.
Definition at line 159 of file mega_fused_outproj_mlp_prefill.c.
References align_up_size(), CK_DT_Q8_0, ck_dtype_row_bytes(), and fused_mlp_swiglu_prefill_w1w2_quant_scratch_size().