Fused Kernel API for Cache-Aware Attention Fusion. More...
#include <stdint.h>
Go to the source code of this file.
Functions
| void | fused_flash_attention_all_heads (float *o_out, const float *q_all, const float *kv_cache_k, const float *kv_cache_v, int num_heads, int num_kv_heads, int head_dim, int seq_len, int kv_tile_size) |
| Fused Flash Attention for all heads (parallel dispatch) More... | |
| void | fused_flash_attention_head (float *o_out, const float *q, const float *kv_cache_k, const float *kv_cache_v, int kv_head_idx, int seq_len, int head_dim, int kv_tile_size) |
| Fused Flash Attention for single head. More... | |
| int | fused_kernels_compute_kv_tile (int l1_size, int head_dim, int bytes_per_elem) |
| Compute optimal KV tile size for flash attention. More... | |
| void | fused_kernels_report_stats (int hidden, int num_layers, int seq_len) |
| Report memory savings from mega-fusion. More... | |
| int | fused_kernels_validate_constraints (int l1_size, int head_dim, int kv_tile_size, int bytes_per_elem) |
| Validate cache constraints for fusion. More... | |
| void | fused_output_projection_residual (float *output, const float *o_all, const float *W_o, const float *b_o, const float *residual, int hidden, int num_heads, int head_dim) |
| Fused output projection with residual add. More... | |
| void | fused_rmsnorm (const float *input, const float *gamma, const float *beta, float *output, int hidden, float eps) |
| Fused RMSNorm - writes to pre-allocated buffer. More... | |
| void | fused_rmsnorm_qkv (const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, float *q_out, float *k_out, float *v_out, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps) |
| Fused RMSNorm with fused QKV projection. More... | |
| void | fused_rope_inplace (float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int max_seq) |
| Fused RoPE application (in-place on pre-allocated buffers) More... | |
| void | mega_fused_attention (float *output, const float *input, const float *residual, const float *W_qkv, const float *b_qkv, const float *W_o, const float *b_o, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int seq_len, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps) |
| Complete mega-fused attention block. More... | |
Fused Kernel API for Cache-Aware Attention Fusion.
Key design principles:
Cache hierarchy (typical Xeon):
Per-head working set must fit in L1: Q_h (128B) + K_tile (8KB) + V_tile (8KB) + O_h (256B) plus softmax running state = ~16.8KB total
Reference: docs/site/assets/per_head_fusion_math.svg
Definition in file fused_kernels.h.
| void fused_flash_attention_all_heads | ( | float * | o_out, |
| const float * | q_all, | ||
| const float * | kv_cache_k, | ||
| const float * | kv_cache_v, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | seq_len, | ||
| int | kv_tile_size | ||
| ) |
Fused Flash Attention for all heads (parallel dispatch)
Dispatches per-head attention to parallel cores. Each head's working set fits in L1.
| o_out | Output buffer [num_heads * head_dim] |
| q_all | Q tensor for all heads [num_heads * head_dim] |
| kv_cache_k | KV cache K [seq_len * num_kv_heads * head_dim] |
| kv_cache_v | KV cache V [seq_len * num_kv_heads * head_dim] |
| num_heads | Number of attention heads |
| num_kv_heads | Number of KV heads |
| head_dim | Head dimension |
| seq_len | Current sequence length |
| kv_tile_size | Tile size for KV streaming |
| void fused_flash_attention_head | ( | float * | o_out, |
| const float * | q, | ||
| const float * | kv_cache_k, | ||
| const float * | kv_cache_v, | ||
| int | kv_head_idx, | ||
| int | seq_len, | ||
| int | head_dim, | ||
| int | kv_tile_size | ||
| ) |
Fused Flash Attention for single head.
Online softmax with streaming KV tiles. O, m, l stay in registers throughout.
| o_out | Output buffer [head_dim] - pre-allocated |
| q | Q vector for this head [head_dim] |
| kv_cache_k | KV cache K [seq_len * num_kv_heads * head_dim] |
| kv_cache_v | KV cache V [seq_len * num_kv_heads * head_dim] |
| kv_head_idx | Which KV head this Q head uses |
| seq_len | Current sequence length |
| head_dim | Head dimension |
| kv_tile_size | Tile size for KV streaming |
| int fused_kernels_compute_kv_tile | ( | int | l1_size, |
| int | head_dim, | ||
| int | bytes_per_elem | ||
| ) |
Compute optimal KV tile size for flash attention.
Formula: T_kv ≤ (S_L1 × 0.8 - 2×d×B) / (2×d×B)
Uses 80% of L1 to leave room for prefetch and OS.
| l1_size | L1 cache size in bytes (typically 49152 for 48KB) |
| head_dim | Head dimension |
| bytes_per_elem | Element size (2 for FP16, 4 for FP32) |
| void fused_kernels_report_stats | ( | int | hidden, |
| int | num_layers, | ||
| int | seq_len | ||
| ) |
Report memory savings from mega-fusion.
| hidden | Hidden dimension |
| num_layers | Number of layers |
| seq_len | Sequence length |
| int fused_kernels_validate_constraints | ( | int | l1_size, |
| int | head_dim, | ||
| int | kv_tile_size, | ||
| int | bytes_per_elem | ||
| ) |
Validate cache constraints for fusion.
| l1_size | L1 cache size |
| head_dim | Head dimension |
| kv_tile_size | KV tile size |
| bytes_per_elem | Element size |
| void fused_output_projection_residual | ( | float * | output, |
| const float * | o_all, | ||
| const float * | W_o, | ||
| const float * | b_o, | ||
| const float * | residual, | ||
| int | hidden, | ||
| int | num_heads, | ||
| int | head_dim | ||
| ) |
Fused output projection with residual add.
Computes O @ W_o + residual, all in registers/L1. Final store includes residual add.
| output | Output buffer [hidden] - final DRAM write |
| o_all | Concatenated O from all heads [hidden] |
| W_o | Output projection weights [hidden, hidden] |
| b_o | Output bias [hidden] or NULL |
| residual | Residual input [hidden] |
| hidden | Hidden dimension |
| num_heads | Number of attention heads |
| head_dim | Head dimension |
| void fused_rmsnorm | ( | const float * | input, |
| const float * | gamma, | ||
| const float * | beta, | ||
| float * | output, | ||
| int | hidden, | ||
| float | eps | ||
| ) |
Fused RMSNorm - writes to pre-allocated buffer.
Takes input, applies RMSNorm, writes to output buffer. Buffer should be in L1/L2 for best performance.
| input | Input tensor [hidden] |
| gamma | Gamma parameter [hidden] |
| beta | Beta parameter [hidden] or NULL (RMSNorm doesn't use beta typically) |
| output | Output buffer [hidden] - pre-allocated, caller owns |
| hidden | Hidden dimension |
| eps | Epsilon for numerical stability |
| void fused_rmsnorm_qkv | ( | const float * | input, |
| const float * | gamma, | ||
| const float * | W_qkv, | ||
| const float * | b_qkv, | ||
| float * | q_out, | ||
| float * | k_out, | ||
| float * | v_out, | ||
| int | hidden, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| float | eps | ||
| ) |
Fused RMSNorm with fused QKV projection.
Computes RMSNorm and immediately projects to Q, K, V. No intermediate buffer - Q/K/V written directly to output buffers.
| input | Input tensor [hidden] |
| gamma | RMSNorm gamma [hidden] |
| W_qkv | QKV weight matrix [3*hidden, hidden] |
| b_qkv | QKV bias [3*hidden] or NULL |
| q_out | Output buffer for Q [num_heads * head_dim] |
| k_out | Output buffer for K [num_kv_heads * head_dim] |
| v_out | Output buffer for V [num_kv_heads * head_dim] |
| hidden | Hidden dimension |
| num_heads | Number of attention heads |
| num_kv_heads | Number of KV heads |
| head_dim | Head dimension |
| eps | RMSNorm epsilon |
| void fused_rope_inplace | ( | float * | q, |
| float * | k, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | pos, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | max_seq | ||
| ) |
Fused RoPE application (in-place on pre-allocated buffers)
Applies RoPE rotation to Q and K in their buffers. Buffers stay in cache - no extra memory allocation.
| q | Q tensor [num_heads * head_dim] - modified in place |
| k | K tensor [num_kv_heads * head_dim] - modified in place |
| rope_cos | RoPE cos table [max_seq, head_dim/2] |
| rope_sin | RoPE sin table [max_seq, head_dim/2] |
| pos | Current position in sequence |
| num_heads | Number of Q heads |
| num_kv_heads | Number of K/V heads |
| head_dim | Head dimension |
| max_seq | Maximum sequence length |
| void mega_fused_attention | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const float * | W_qkv, | ||
| const float * | b_qkv, | ||
| const float * | W_o, | ||
| const float * | b_o, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | pos, | ||
| int | seq_len, | ||
| int | hidden, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | max_seq, | ||
| float | eps | ||
| ) |
Complete mega-fused attention block.
RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual
All intermediates in L1/L2/registers. Single DRAM round-trip: 4KB in + 4KB out.
| output | Output tensor [hidden] - single DRAM write |
| input | Input tensor [hidden] - single DRAM read |
| residual | Residual input [hidden] |
| W_qkv | QKV weights [3*hidden, hidden] |
| b_qkv | QKV bias [3*hidden] or NULL |
| W_o | Output projection [hidden, hidden] |
| b_o | Output bias [hidden] or NULL |
| kv_cache_k | KV cache K [seq, hidden] - updated in place |
| kv_cache_v | KV cache V [seq, hidden] - updated in place |
| rope_cos | RoPE cos [max_seq, head_dim/2] |
| rope_sin | RoPE sin [max_seq, head_dim/2] |
| pos | Current position |
| seq_len | Current sequence length |
| hidden | Hidden dimension |
| num_heads | Number of heads |
| num_kv_heads | Number of KV heads |
| head_dim | Head dimension |
| max_seq | Maximum sequence length |
| eps | RMSNorm epsilon |