fused_kernels.h File Reference

Fused Kernel API for Cache-Aware Attention Fusion.

#include <stdint.h>


Functions

void fused_flash_attention_all_heads (float *o_out, const float *q_all, const float *kv_cache_k, const float *kv_cache_v, int num_heads, int num_kv_heads, int head_dim, int seq_len, int kv_tile_size)
 Fused Flash Attention for all heads (parallel dispatch)
 
void fused_flash_attention_head (float *o_out, const float *q, const float *kv_cache_k, const float *kv_cache_v, int kv_head_idx, int seq_len, int head_dim, int kv_tile_size)
 Fused Flash Attention for single head.
 
int fused_kernels_compute_kv_tile (int l1_size, int head_dim, int bytes_per_elem)
 Compute optimal KV tile size for flash attention.
 
void fused_kernels_report_stats (int hidden, int num_layers, int seq_len)
 Report memory savings from mega-fusion.
 
int fused_kernels_validate_constraints (int l1_size, int head_dim, int kv_tile_size, int bytes_per_elem)
 Validate cache constraints for fusion.
 
void fused_output_projection_residual (float *output, const float *o_all, const float *W_o, const float *b_o, const float *residual, int hidden, int num_heads, int head_dim)
 Fused output projection with residual add.
 
void fused_rmsnorm (const float *input, const float *gamma, const float *beta, float *output, int hidden, float eps)
 Fused RMSNorm - writes to pre-allocated buffer.
 
void fused_rmsnorm_qkv (const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, float *q_out, float *k_out, float *v_out, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps)
 Fused RMSNorm with fused QKV projection.
 
void fused_rope_inplace (float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int max_seq)
 Fused RoPE application (in-place on pre-allocated buffers)
 
void mega_fused_attention (float *output, const float *input, const float *residual, const float *W_qkv, const float *b_qkv, const float *W_o, const float *b_o, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int seq_len, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps)
 Complete mega-fused attention block.
 

Detailed Description

Fused Kernel API for Cache-Aware Attention Fusion.

Key design principles:

  1. Kernels take output buffer as parameter (no internal malloc)
  2. Buffers are designed to fit in L1/L2 cache
  3. Kernels can chain: rmsnorm → QKV → RoPE → Flash → OutProj
  4. Per-head parallelization with L1 cache constraints

Cache hierarchy (typical Xeon):

  • Registers: 32 ZMM × 64B = 2KB (AVX-512)
  • L1: 48KB per core
  • L2: 1-2MB per core
  • L3: 1.5MB × cores (shared)

Per-head working set must fit in L1: Q_h (128B) + K_tile (8KB) + V_tile (8KB) + O_h (256B) = ~16.8KB

Reference: docs/site/assets/per_head_fusion_math.svg

Definition in file fused_kernels.h.

Function Documentation

◆ fused_flash_attention_all_heads()

void fused_flash_attention_all_heads ( float *  o_out,
const float *  q_all,
const float *  kv_cache_k,
const float *  kv_cache_v,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  seq_len,
int  kv_tile_size 
)

Fused Flash Attention for all heads (parallel dispatch)

Dispatches per-head attention to parallel cores. Each head's working set fits in L1.

Parameters
  o_out         Output buffer [num_heads * head_dim]
  q_all         Q tensor for all heads [num_heads * head_dim]
  kv_cache_k    KV cache K [seq_len * num_kv_heads * head_dim]
  kv_cache_v    KV cache V [seq_len * num_kv_heads * head_dim]
  num_heads     Number of attention heads
  num_kv_heads  Number of KV heads
  head_dim      Head dimension
  seq_len       Current sequence length
  kv_tile_size  Tile size for KV streaming

◆ fused_flash_attention_head()

void fused_flash_attention_head ( float *  o_out,
const float *  q,
const float *  kv_cache_k,
const float *  kv_cache_v,
int  kv_head_idx,
int  seq_len,
int  head_dim,
int  kv_tile_size 
)

Fused Flash Attention for single head.

Online softmax with streaming KV tiles. O, m, l stay in registers throughout.

Parameters
  o_out         Output buffer [head_dim] - pre-allocated
  q             Q vector for this head [head_dim]
  kv_cache_k    KV cache K [seq_len * num_kv_heads * head_dim]
  kv_cache_v    KV cache V [seq_len * num_kv_heads * head_dim]
  kv_head_idx   Which KV head this Q head uses
  seq_len       Current sequence length
  head_dim      Head dimension
  kv_tile_size  Tile size for KV streaming

◆ fused_kernels_compute_kv_tile()

int fused_kernels_compute_kv_tile ( int  l1_size,
int  head_dim,
int  bytes_per_elem 
)

Compute optimal KV tile size for flash attention.

Formula: T_kv ≤ (S_L1 × 0.8 - 2×d×B) / (2×d×B)

Uses 80% of L1 to leave room for prefetch and OS.

Parameters
  l1_size         L1 cache size in bytes (typically 49152 for 48KB)
  head_dim        Head dimension
  bytes_per_elem  Element size (2 for FP16, 4 for FP32)
Returns
Optimal KV tile size (multiple of 64 for cache line alignment)

◆ fused_kernels_report_stats()

void fused_kernels_report_stats ( int  hidden,
int  num_layers,
int  seq_len 
)

Report memory savings from mega-fusion.

Parameters
  hidden      Hidden dimension
  num_layers  Number of layers
  seq_len     Sequence length

◆ fused_kernels_validate_constraints()

int fused_kernels_validate_constraints ( int  l1_size,
int  head_dim,
int  kv_tile_size,
int  bytes_per_elem 
)

Validate cache constraints for fusion.

Parameters
  l1_size         L1 cache size
  head_dim        Head dimension
  kv_tile_size    KV tile size
  bytes_per_elem  Element size
Returns
0 if valid, -1 if working set exceeds cache

◆ fused_output_projection_residual()

void fused_output_projection_residual ( float *  output,
const float *  o_all,
const float *  W_o,
const float *  b_o,
const float *  residual,
int  hidden,
int  num_heads,
int  head_dim 
)

Fused output projection with residual add.

Computes O @ W_o (+ b_o) + residual, all in registers/L1. The final store fuses in the residual add.

Parameters
  output     Output buffer [hidden] - final DRAM write
  o_all      Concatenated O from all heads [hidden]
  W_o        Output projection weights [hidden, hidden]
  b_o        Output bias [hidden] or NULL
  residual   Residual input [hidden]
  hidden     Hidden dimension
  num_heads  Number of attention heads
  head_dim   Head dimension

◆ fused_rmsnorm()

void fused_rmsnorm ( const float *  input,
const float *  gamma,
const float *  beta,
float *  output,
int  hidden,
float  eps 
)

Fused RMSNorm - writes to pre-allocated buffer.

Takes input, applies RMSNorm, writes to output buffer. Buffer should be in L1/L2 for best performance.

Parameters
  input   Input tensor [hidden]
  gamma   Gamma parameter [hidden]
  beta    Beta parameter [hidden] or NULL (RMSNorm typically does not use beta)
  output  Output buffer [hidden] - pre-allocated, caller owns
  hidden  Hidden dimension
  eps     Epsilon for numerical stability

◆ fused_rmsnorm_qkv()

void fused_rmsnorm_qkv ( const float *  input,
const float *  gamma,
const float *  W_qkv,
const float *  b_qkv,
float *  q_out,
float *  k_out,
float *  v_out,
int  hidden,
int  num_heads,
int  num_kv_heads,
int  head_dim,
float  eps 
)

Fused RMSNorm with fused QKV projection.

Computes RMSNorm and immediately projects to Q, K, V. No intermediate buffer - Q/K/V written directly to output buffers.

Parameters
  input         Input tensor [hidden]
  gamma         RMSNorm gamma [hidden]
  W_qkv         QKV weight matrix [3*hidden, hidden]
  b_qkv         QKV bias [3*hidden] or NULL
  q_out         Output buffer for Q [num_heads * head_dim]
  k_out         Output buffer for K [num_kv_heads * head_dim]
  v_out         Output buffer for V [num_kv_heads * head_dim]
  hidden        Hidden dimension
  num_heads     Number of attention heads
  num_kv_heads  Number of KV heads
  head_dim      Head dimension
  eps           RMSNorm epsilon

◆ fused_rope_inplace()

void fused_rope_inplace ( float *  q,
float *  k,
const float *  rope_cos,
const float *  rope_sin,
int  pos,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  max_seq 
)

Fused RoPE application (in-place on pre-allocated buffers)

Applies RoPE rotation to Q and K in their buffers. Buffers stay in cache - no extra memory allocation.

Parameters
  q             Q tensor [num_heads * head_dim] - modified in place
  k             K tensor [num_kv_heads * head_dim] - modified in place
  rope_cos      RoPE cos table [max_seq, head_dim/2]
  rope_sin      RoPE sin table [max_seq, head_dim/2]
  pos           Current position in sequence
  num_heads     Number of Q heads
  num_kv_heads  Number of K/V heads
  head_dim      Head dimension
  max_seq       Maximum sequence length

◆ mega_fused_attention()

void mega_fused_attention ( float *  output,
const float *  input,
const float *  residual,
const float *  W_qkv,
const float *  b_qkv,
const float *  W_o,
const float *  b_o,
float *  kv_cache_k,
float *  kv_cache_v,
const float *  rope_cos,
const float *  rope_sin,
int  pos,
int  seq_len,
int  hidden,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  max_seq,
float  eps 
)

Complete mega-fused attention block.

RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual

All intermediates in L1/L2/registers. Single DRAM round-trip: 4KB in + 4KB out.

Parameters
  output        Output tensor [hidden] - single DRAM write
  input         Input tensor [hidden] - single DRAM read
  residual      Residual input [hidden]
  W_qkv         QKV weights [3*hidden, hidden]
  b_qkv         QKV bias [3*hidden] or NULL
  W_o           Output projection [hidden, hidden]
  b_o           Output bias [hidden] or NULL
  kv_cache_k    KV cache K [seq, hidden] - updated in place
  kv_cache_v    KV cache V [seq, hidden] - updated in place
  rope_cos      RoPE cos [max_seq, head_dim/2]
  rope_sin      RoPE sin [max_seq, head_dim/2]
  pos           Current position
  seq_len       Current sequence length
  hidden        Hidden dimension
  num_heads     Number of heads
  num_kv_heads  Number of KV heads
  head_dim      Head dimension
  max_seq       Maximum sequence length
  eps           RMSNorm epsilon