/**
 * @file fused_kernels.h
 * @brief Fused Kernel API for Cache-Aware Attention Fusion
 *
 * Key design principles:
 * 1. Kernels take output buffer as parameter (no internal malloc)
 * 2. Buffers are designed to fit in L1/L2 cache
 * 3. Kernels can chain: rmsnorm → QKV → RoPE → Flash → OutProj
 * 4. Per-head parallelization with L1 cache constraints
 *
 * Cache hierarchy (typical Xeon):
 * - Registers: 32 ZMM × 64B = 2KB (AVX-512)
 * - L1: 48KB per core
 * - L2: 1-2MB per core
 * - L3: 1.5MB × cores (shared)
 *
 * Per-head working set must fit in L1:
 * Q_h (128B) + K_tile (8KB) + V_tile (8KB) + O_h (256B) = ~16.8KB
 *
 * Reference: docs/site/assets/per_head_fusion_math.svg
 */
22 
23 #ifndef FUSED_KERNELS_H
24 #define FUSED_KERNELS_H
25 
26 #include <stdint.h>
27 
28 /*============================================================================
29  * Cache-Aware Configuration
30  *============================================================================*/
31 
32 /**
33  * @brief Compute optimal KV tile size for flash attention
34  *
35  * Formula: T_kv ≤ (S_L1 × 0.8 - 2×d×B) / (2×d×B)
36  *
37  * Uses 80% of L1 to leave room for prefetch and OS.
38  *
39  * @param l1_size L1 cache size in bytes (typically 49152 for 48KB)
40  * @param head_dim Head dimension
41  * @param bytes_per_elem Element size (2 for FP16, 4 for FP32)
42  * @return Optimal KV tile size (multiple of 64 for cache line alignment)
43  */
45  int l1_size,
46  int head_dim,
47  int bytes_per_elem
48 );
49 
/*============================================================================
 * Fused RMSNorm
 *============================================================================*/

/**
 * @brief Fused RMSNorm - writes to pre-allocated buffer
 *
 * Takes input, applies RMSNorm, writes to output buffer.
 * Buffer should be in L1/L2 for best performance.
 *
 * @param input Input tensor [hidden]
 * @param gamma Gamma parameter [hidden]
 * @param beta Beta parameter [hidden] or NULL (RMSNorm doesn't use beta typically)
 * @param output Output buffer [hidden] - pre-allocated, caller owns
 * @param hidden Hidden dimension
 * @param eps Epsilon for numerical stability
 */
void fused_rmsnorm(
    const float *input,
    const float *gamma,
    const float *beta,
    float *output,
    int hidden,
    float eps
);
75 
/**
 * @brief Fused RMSNorm with fused QKV projection
 *
 * Computes RMSNorm and immediately projects to Q, K, V.
 * No intermediate buffer - Q/K/V written directly to output buffers.
 *
 * @param input Input tensor [hidden]
 * @param gamma RMSNorm gamma [hidden]
 * @param W_qkv QKV weight matrix [3*hidden, hidden]
 * @param b_qkv QKV bias [3*hidden] or NULL
 * @param q_out Output buffer for Q [num_heads * head_dim]
 * @param k_out Output buffer for K [num_kv_heads * head_dim]
 * @param v_out Output buffer for V [num_kv_heads * head_dim]
 * @param hidden Hidden dimension
 * @param num_heads Number of attention heads
 * @param num_kv_heads Number of KV heads
 * @param head_dim Head dimension
 * @param eps RMSNorm epsilon
 */
void fused_rmsnorm_qkv(
    const float *input,
    const float *gamma,
    const float *W_qkv,
    const float *b_qkv,
    float *q_out,
    float *k_out,
    float *v_out,
    int hidden,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    float eps
);
109 
/*============================================================================
 * Fused RoPE
 *============================================================================*/

/**
 * @brief Fused RoPE application (in-place on pre-allocated buffers)
 *
 * Applies RoPE rotation to Q and K in their buffers.
 * Buffers stay in cache - no extra memory allocation.
 *
 * @param q Q tensor [num_heads * head_dim] - modified in place
 * @param k K tensor [num_kv_heads * head_dim] - modified in place
 * @param rope_cos RoPE cos table [max_seq, head_dim/2]
 * @param rope_sin RoPE sin table [max_seq, head_dim/2]
 * @param pos Current position in sequence
 * @param num_heads Number of Q heads
 * @param num_kv_heads Number of K/V heads
 * @param head_dim Head dimension
 * @param max_seq Maximum sequence length
 */
void fused_rope_inplace(
    float *q,
    float *k,
    const float *rope_cos,
    const float *rope_sin,
    int pos,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    int max_seq
);
141 
/*============================================================================
 * Fused Flash Attention
 *============================================================================*/

/**
 * @brief Fused Flash Attention for single head
 *
 * Online softmax with streaming KV tiles.
 * O, m, l stay in registers throughout.
 *
 * @param o_out Output buffer [head_dim] - pre-allocated
 * @param q Q vector for this head [head_dim]
 * @param kv_cache_k KV cache K [seq_len * num_kv_heads * head_dim]
 * @param kv_cache_v KV cache V [seq_len * num_kv_heads * head_dim]
 * @param kv_head_idx Which KV head this Q head uses
 * @param seq_len Current sequence length
 * @param head_dim Head dimension
 * @param kv_tile_size Tile size for KV streaming
 */
void fused_flash_attention_head(
    float *o_out,
    const float *q,
    const float *kv_cache_k,
    const float *kv_cache_v,
    int kv_head_idx,
    int seq_len,
    int head_dim,
    int kv_tile_size
);
171 
/**
 * @brief Fused Flash Attention for all heads (parallel dispatch)
 *
 * Dispatches per-head attention to parallel cores.
 * Each head's working set fits in L1.
 *
 * @param o_out Output buffer [num_heads * head_dim]
 * @param q_all Q tensor for all heads [num_heads * head_dim]
 * @param kv_cache_k KV cache K [seq_len * num_kv_heads * head_dim]
 * @param kv_cache_v KV cache V [seq_len * num_kv_heads * head_dim]
 * @param num_heads Number of attention heads
 * @param num_kv_heads Number of KV heads
 * @param head_dim Head dimension
 * @param seq_len Current sequence length
 * @param kv_tile_size Tile size for KV streaming
 */
void fused_flash_attention_all_heads(
    float *o_out,
    const float *q_all,
    const float *kv_cache_k,
    const float *kv_cache_v,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    int seq_len,
    int kv_tile_size
);
199 
/*============================================================================
 * Fused Output Projection
 *============================================================================*/

/**
 * @brief Fused output projection with residual add
 *
 * Computes O @ W_o + residual, all in registers/L1.
 * Final store includes residual add.
 *
 * @param output Output buffer [hidden] - final DRAM write
 * @param o_all Concatenated O from all heads [hidden]
 * @param W_o Output projection weights [hidden, hidden]
 * @param b_o Output bias [hidden] or NULL
 * @param residual Residual input [hidden]
 * @param hidden Hidden dimension
 * @param num_heads Number of attention heads
 * @param head_dim Head dimension
 */
void fused_output_projection_residual(
    float *output,
    const float *o_all,
    const float *W_o,
    const float *b_o,
    const float *residual,
    int hidden,
    int num_heads,
    int head_dim
);
229 
/*============================================================================
 * Complete Mega-Fused Attention
 *============================================================================*/

/**
 * @brief Complete mega-fused attention block
 *
 * RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual
 *
 * All intermediates in L1/L2/registers.
 * Single DRAM round-trip: 4KB in + 4KB out.
 *
 * @param output Output tensor [hidden] - single DRAM write
 * @param input Input tensor [hidden] - single DRAM read
 * @param residual Residual input [hidden]
 * @param W_qkv QKV weights [3*hidden, hidden]
 * @param b_qkv QKV bias [3*hidden] or NULL
 * @param W_o Output projection [hidden, hidden]
 * @param b_o Output bias [hidden] or NULL
 * @param kv_cache_k KV cache K [seq, hidden] - updated in place
 * @param kv_cache_v KV cache V [seq, hidden] - updated in place
 * @param rope_cos RoPE cos [max_seq, head_dim/2]
 * @param rope_sin RoPE sin [max_seq, head_dim/2]
 * @param pos Current position
 * @param seq_len Current sequence length
 * @param hidden Hidden dimension
 * @param num_heads Number of heads
 * @param num_kv_heads Number of KV heads
 * @param head_dim Head dimension
 * @param max_seq Maximum sequence length
 * @param eps RMSNorm epsilon
 */
void mega_fused_attention(
    float *output,
    const float *input,
    const float *residual,
    const float *W_qkv,
    const float *b_qkv,
    const float *W_o,
    const float *b_o,
    float *kv_cache_k,
    float *kv_cache_v,
    const float *rope_cos,
    const float *rope_sin,
    int pos,
    int seq_len,
    int hidden,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    int max_seq,
    float eps
);
283 
/*============================================================================
 * Statistics and Validation
 *============================================================================*/

/**
 * @brief Report memory savings from mega-fusion
 *
 * @param hidden Hidden dimension
 * @param num_layers Number of layers
 * @param seq_len Sequence length
 */
void fused_kernels_report_stats(
    int hidden,
    int num_layers,
    int seq_len
);
300 
301 /**
302  * @brief Validate cache constraints for fusion
303  *
304  * @param l1_size L1 cache size
305  * @param head_dim Head dimension
306  * @param kv_tile_size KV tile size
307  * @param bytes_per_elem Element size
308  * @return 0 if valid, -1 if working set exceeds cache
309  */
311  int l1_size,
312  int head_dim,
313  int kv_tile_size,
314  int bytes_per_elem
315 );
316 
317 #endif /* FUSED_KERNELS_H */