← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention.h
Go to the documentation of this file.
1 /**
2  * @file mega_fused_attention.h
3  * @brief Mega-Fused Attention Kernel
4  *
5  * Holy grail fusion: RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual
6  *
7  * All intermediates stay in registers/L1/L2. Single DRAM round-trip.
8  *
9  * Memory Reduction:
10  * Before: ~32KB intermediates per layer (stack/heap)
11  * After: ~8KB total (input + output only)
12  * Reduction: 4-5× per layer, ~100× for full model
13  *
14  * Performance Target:
15  * Move from memory-bound to compute-bound
16  * Expected speedup: 5-10× for attention-heavy workloads
17  */
18 
19 #ifndef MEGA_FUSED_ATTENTION_H
20 #define MEGA_FUSED_ATTENTION_H
21 
#include <stddef.h> /* size_t for *_scratch_size() APIs */
#include <stdint.h>

#include "ckernel_dtype.h"

26 /*============================================================================
27  * Configuration
28  *============================================================================*/
29 
30 /* Tile sizes for streaming through cache hierarchy */
31 #ifndef MEGA_FUSE_Q_TILE
32 #define MEGA_FUSE_Q_TILE 64
33 #endif
34 
35 #ifndef MEGA_FUSE_KV_TILE
36 #define MEGA_FUSE_KV_TILE 64
37 #endif
38 
39 /*============================================================================
40  * Mega-Fused Attention API
41  *============================================================================*/
42 
43 /**
44  * @brief Mega-fused attention for decode mode (single token)
45  *
46  * This is the "holy grail" - all operations fused, no intermediates to DRAM.
47  *
48  * @param output Output [aligned_embed_dim] (includes residual add)
49  * @param input Input [aligned_embed_dim]
50  * @param residual Residual input [aligned_embed_dim] (or NULL)
51  * @param ln1_gamma RMSNorm gamma [embed_dim]
52  * @param wq Q weights (quantized) [num_heads * aligned_head_dim * aligned_embed_dim]
53  * @param bq Q bias [num_heads * aligned_head_dim] (or NULL)
54  * @param wq_dt Q weight dtype (CK_DT_Q5_0/CK_DT_Q8_0/CK_DT_FP32)
55  * @param wk K weights (quantized) [num_kv_heads * aligned_head_dim * aligned_embed_dim]
56  * @param bk K bias [num_kv_heads * aligned_head_dim] (or NULL)
57  * @param wk_dt K weight dtype (CK_DT_Q5_0/CK_DT_Q8_0/CK_DT_FP32)
58  * @param wv V weights (quantized) [num_kv_heads * aligned_head_dim * aligned_embed_dim]
59  * @param bv V bias [num_kv_heads * aligned_head_dim] (or NULL)
60  * @param wv_dt V weight dtype (CK_DT_Q5_0/CK_DT_Q8_0/CK_DT_FP32)
61  * @param wo Output projection weights (quantized) [aligned_embed_dim * aligned_embed_dim]
62  * @param bo Output bias [aligned_embed_dim] (or NULL)
63  * @param wo_dt Output weight dtype (CK_DT_Q5_0/CK_DT_FP32)
64  * @param kv_cache_k KV cache for K [num_kv_heads * cache_capacity * aligned_head_dim]
65  * @param kv_cache_v KV cache for V [num_kv_heads * cache_capacity * aligned_head_dim]
66  * @param rope_cos RoPE cos [max_seq, head_dim/2]
67  * @param rope_sin RoPE sin [max_seq, head_dim/2]
68  * @param pos Current position in sequence
69  * @param embed_dim Model hidden dimension (unpadded)
70  * @param aligned_embed_dim Aligned hidden dimension
71  * @param num_heads Number of attention heads
72  * @param num_kv_heads Number of KV heads (for GQA)
73  * @param head_dim Head dimension (unpadded)
74  * @param aligned_head_dim Aligned head dimension
75  * @param cache_capacity KV cache capacity (stride in tokens)
76  * @param eps RMSNorm epsilon
77  * @param scratch Scratch buffer from mega_fused_attention_prefill_scratch_size()
78  */
80  float *output,
81  const float *input,
82  const float *residual,
83  const float *ln1_gamma,
84  const float *wq, const float *bq,
85  const float *wk, const float *bk,
86  const float *wv, const float *bv,
87  const float *wo, const float *bo,
88  float *kv_cache_k,
89  float *kv_cache_v,
90  const float *rope_cos,
91  const float *rope_sin,
92  int pos,
93  int embed_dim,
94  int aligned_embed_dim,
95  int num_heads,
96  int num_kv_heads,
97  int head_dim,
98  int aligned_head_dim,
99  int cache_capacity,
100  float eps
101 );
102 
103 /**
104  * @brief Mega-fused attention for prefill mode (multiple tokens)
105  *
106  * @param output Output [tokens, aligned_embed_dim] (includes residual add)
107  * @param input Input [tokens, aligned_embed_dim]
108  * @param residual Residual input [tokens, aligned_embed_dim] (or NULL)
109  * @param ln1_gamma RMSNorm gamma [embed_dim]
110  * @param wq Q weights [num_heads * aligned_head_dim * aligned_embed_dim]
111  * @param bq Q bias [num_heads * aligned_head_dim] (or NULL)
112  * @param wk K weights [num_kv_heads * aligned_head_dim * aligned_embed_dim]
113  * @param bk K bias [num_kv_heads * aligned_head_dim] (or NULL)
114  * @param wv V weights [num_kv_heads * aligned_head_dim * aligned_embed_dim]
115  * @param bv V bias [num_kv_heads * aligned_head_dim] (or NULL)
116  * @param wo Output projection weights [num_heads * aligned_embed_dim * aligned_head_dim]
117  * @param bo Output bias [aligned_embed_dim] (or NULL)
118  * @param kv_cache_k KV cache for K [num_kv_heads * cache_capacity * aligned_head_dim]
119  * @param kv_cache_v KV cache for V [num_kv_heads * cache_capacity * aligned_head_dim]
120  * @param rope_cos RoPE cos [max_seq, head_dim/2]
121  * @param rope_sin RoPE sin [max_seq, head_dim/2]
122  * @param start_pos Starting position in KV cache
123  * @param tokens Number of tokens to process
124  * @param cache_capacity KV cache capacity (stride in tokens)
125  * @param embed_dim Model hidden dimension (unpadded)
126  * @param aligned_embed_dim Aligned hidden dimension
127  * @param num_heads Number of attention heads
128  * @param num_kv_heads Number of KV heads
129  * @param head_dim Head dimension (unpadded)
130  * @param aligned_head_dim Aligned head dimension
131  * @param eps RMSNorm epsilon
132  */
134  float *output,
135  const float *input,
136  const float *residual,
137  const float *ln1_gamma,
138  const void *wq, const float *bq, CKDataType wq_dt,
139  const void *wk, const float *bk, CKDataType wk_dt,
140  const void *wv, const float *bv, CKDataType wv_dt,
141  const void *wo, const float *bo, CKDataType wo_dt,
142  float *kv_cache_k,
143  float *kv_cache_v,
144  const float *rope_cos,
145  const float *rope_sin,
146  int start_pos,
147  int tokens,
148  int cache_capacity,
149  int embed_dim,
150  int aligned_embed_dim,
151  int num_heads,
152  int num_kv_heads,
153  int head_dim,
154  int aligned_head_dim,
155  float eps,
156  void *scratch
157 );
158 
159 /**
160  * @brief Mega-fused prefill attention kernel (Q8_0 out-proj)
161  *
162  * Same layout and scratch requirements as mega_fused_attention_prefill.
163  */
165  float *output,
166  const float *input,
167  const float *residual,
168  const float *ln1_gamma,
169  const void *wq, const float *bq, CKDataType wq_dt,
170  const void *wk, const float *bk, CKDataType wk_dt,
171  const void *wv, const float *bv, CKDataType wv_dt,
172  const void *wo, const float *bo, CKDataType wo_dt,
173  float *kv_cache_k,
174  float *kv_cache_v,
175  const float *rope_cos,
176  const float *rope_sin,
177  int start_pos,
178  int tokens,
179  int cache_capacity,
180  int embed_dim,
181  int aligned_embed_dim,
182  int num_heads,
183  int num_kv_heads,
184  int head_dim,
185  int aligned_head_dim,
186  float eps,
187  void *scratch
188 );
189 
/** @brief Get scratch buffer size (bytes) for mega_fused_attention_prefill */
size_t mega_fused_attention_prefill_scratch_size(
    int tokens,
    int aligned_embed_dim,
    int num_heads,
    int aligned_head_dim);
195 
/** @brief Get scratch buffer size (bytes) for mega_fused_attention_prefill_q8_0 */
size_t mega_fused_attention_prefill_q8_0_scratch_size(
    int tokens,
    int aligned_embed_dim,
    int num_heads,
    int aligned_head_dim);
201 
202 /**
203  * @brief Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill
204  *
205  * Uses head-major attention output and quantized out-proj (Q5_0/Q8_0 weights).
206  */
208  float *output,
209  const float *attn_out,
210  const float *residual,
211  const float *ln2_gamma,
212  const void *wo, const float *bo, CKDataType wo_dt,
213  const void *w1, const float *b1, CKDataType w1_dt,
214  const void *w2, const float *b2, CKDataType w2_dt,
215  int tokens,
216  int embed_dim,
217  int aligned_embed_dim,
218  int num_heads,
219  int aligned_head_dim,
220  int intermediate_dim,
221  int aligned_intermediate_dim,
222  float eps,
223  void *scratch
224 );
225 
/** @brief Get scratch buffer size (bytes) for mega_fused_outproj_mlp_prefill */
size_t mega_fused_outproj_mlp_prefill_scratch_size(
    int tokens,
    int aligned_embed_dim,
    int num_heads,
    int aligned_head_dim,
    int aligned_intermediate_dim);
232 
/**
 * @brief Phase 1: Fused RMSNorm + QKV (intermediates in registers)
 *
 * Simpler step: Just fuse RMSNorm with QKV projection.
 * Q/K/V stay in stack buffers, not DRAM.
 *
 * @param q_out Q output [num_heads * head_dim]
 * @param k_out K output [num_kv_heads * head_dim]
 * @param v_out V output [num_kv_heads * head_dim]
 * @param input Input activations [hidden]
 * @param gamma RMSNorm gamma [hidden]
 * @param W_qkv Packed QKV projection weights (FP32)
 * @param b_qkv Packed QKV bias (presumably NULL-able like other bias params
 *              in this header - TODO confirm)
 * @param hidden Hidden dimension
 * @param num_heads Number of attention heads
 * @param num_kv_heads Number of KV heads
 * @param head_dim Head dimension
 * @param eps RMSNorm epsilon
 */
void mega_fuse_rmsnorm_qkv(
    float *q_out,
    float *k_out,
    float *v_out,
    const float *input,
    const float *gamma,
    const float *W_qkv,
    const float *b_qkv,
    int hidden,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    float eps
);
253 
/**
 * @brief Phase 2: Fused RMSNorm + QKV + RoPE
 *
 * Q/K stay in output buffers, RoPE applied in-place.
 *
 * @param q_out Q output [num_heads * head_dim] (RoPE applied)
 * @param k_out K output [num_kv_heads * head_dim] (RoPE applied)
 * @param v_out V output [num_kv_heads * head_dim]
 * @param input Input activations [hidden]
 * @param gamma RMSNorm gamma [hidden]
 * @param W_qkv Packed QKV projection weights (FP32)
 * @param b_qkv Packed QKV bias
 * @param rope_cos RoPE cos table [max_seq, head_dim/2]
 * @param rope_sin RoPE sin table [max_seq, head_dim/2]
 * @param pos Current position in sequence
 * @param hidden Hidden dimension
 * @param num_heads Number of attention heads
 * @param num_kv_heads Number of KV heads
 * @param head_dim Head dimension
 * @param max_seq Maximum sequence length (rows in the RoPE tables)
 * @param eps RMSNorm epsilon
 */
void mega_fuse_rmsnorm_qkv_rope(
    float *q_out,
    float *k_out,
    float *v_out,
    const float *input,
    const float *gamma,
    const float *W_qkv,
    const float *b_qkv,
    const float *rope_cos,
    const float *rope_sin,
    int pos,
    int hidden,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    int max_seq,
    float eps
);
277 
/**
 * @brief Get optimal tile sizes for current CPU
 *
 * @param q_tile Output: Q tile size
 * @param kv_tile Output: KV tile size
 * @param head_dim Head dimension
 */
void mega_fuse_get_optimal_tiles(
    int *q_tile,
    int *kv_tile,
    int head_dim
);
286 
/**
 * @brief Report memory savings from mega-fusion
 *
 * @param hidden Model hidden dimension
 * @param num_layers Number of transformer layers
 * @param seq_len Sequence length used for the estimate
 */
void mega_fuse_report_stats(
    int hidden,
    int num_layers,
    int seq_len
);
295 
296 #endif /* MEGA_FUSED_ATTENTION_H */
CKDataType
Supported data types in C-Kernel-Engine.
Definition: ckernel_dtype.h:27
size_t mega_fused_attention_prefill_q8_0_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Get scratch buffer size for mega_fused_attention_prefill_q8_0.
size_t mega_fused_attention_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Get scratch buffer size for mega_fused_attention_prefill.
void mega_fused_attention_prefill(float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
Mega-fused attention for prefill mode (multiple tokens)
void mega_fuse_report_stats(int hidden, int num_layers, int seq_len)
Report memory savings from mega-fusion.
void mega_fuse_rmsnorm_qkv_rope(float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, const float *rope_cos, const float *rope_sin, int pos, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps)
Phase 2: Fused RMSNorm + QKV + RoPE.
void mega_fuse_get_optimal_tiles(int *q_tile, int *kv_tile, int head_dim)
Get optimal tile sizes for current CPU.
void mega_fused_attention_decode(float *output, const float *input, const float *residual, const float *ln1_gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, const float *wo, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps)
Mega-fused attention for decode mode (single token)
void mega_fused_outproj_mlp_prefill(float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, CKDataType wo_dt, const void *w1, const float *b1, CKDataType w1_dt, const void *w2, const float *b2, CKDataType w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch)
Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill.
size_t mega_fused_outproj_mlp_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
Get scratch buffer size for mega_fused_outproj_mlp_prefill.
void mega_fused_attention_prefill_q8_0(float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
Mega-fused prefill attention kernel (Q8_0 out-proj)
void mega_fuse_rmsnorm_qkv(float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps)
Phase 1: Fused RMSNorm + QKV (intermediates in registers)