#ifndef FUSED_KERNELS_H
#define FUSED_KERNELS_H
int fused_kernels_validate_constraints(int l1_size, int head_dim, int kv_tile_size, int bytes_per_elem)
Validate cache constraints for fusion.
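
A minimal sketch of what this check might look like, assuming the constraint is simply that one query row plus a K tile and a V tile of kv_tile_size rows (each head_dim wide) fit in L1; the real constraint set is not documented here.

/* Sketch only -- the actual constraints are an assumption. */
int fused_kernels_validate_constraints(int l1_size, int head_dim,
                                       int kv_tile_size, int bytes_per_elem)
{
    if (l1_size <= 0 || head_dim <= 0 || kv_tile_size <= 0 || bytes_per_elem <= 0)
        return 0;                                        /* malformed arguments */

    /* working set per tile: q row + K tile + V tile + score scratch */
    long q_bytes     = (long)head_dim * bytes_per_elem;
    long tile_bytes  = 2L * kv_tile_size * head_dim * bytes_per_elem;
    long score_bytes = (long)kv_tile_size * bytes_per_elem;

    return (q_bytes + tile_bytes + score_bytes) <= l1_size;  /* 1 = fits, 0 = too big */
}
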
void fused_rope_inplace(float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int max_seq)
Fused RoPE application (in-place on pre-allocated buffers).
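
A sketch under assumptions: rope_cos/rope_sin are [max_seq][head_dim/2] tables indexed by pos, q is laid out as [num_heads][head_dim] and k as [num_kv_heads][head_dim], and RoPE rotates adjacent pairs (2i, 2i+1) within each head. The interleaved pairing convention (vs. split-half) is an assumption.

#include <stddef.h>

void fused_rope_inplace(float *q, float *k,
                        const float *rope_cos, const float *rope_sin,
                        int pos, int num_heads, int num_kv_heads,
                        int head_dim, int max_seq)
{
    (void)max_seq;                       /* only sizes the tables in this sketch */
    const float *c = rope_cos + (size_t)pos * (head_dim / 2);
    const float *s = rope_sin + (size_t)pos * (head_dim / 2);

    for (int h = 0; h < num_heads + num_kv_heads; h++) {
        /* first num_heads heads come from q, the rest from k (GQA) */
        float *vec = (h < num_heads) ? q + (size_t)h * head_dim
                                     : k + (size_t)(h - num_heads) * head_dim;
        for (int i = 0; i < head_dim / 2; i++) {
            float x0 = vec[2 * i], x1 = vec[2 * i + 1];
            vec[2 * i]     = x0 * c[i] - x1 * s[i];
            vec[2 * i + 1] = x0 * s[i] + x1 * c[i];
        }
    }
}
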
void fused_output_projection_residual(float *output, const float *o_all, const float *W_o, const float *b_o, const float *residual, int hidden, int num_heads, int head_dim)
Fused output projection with residual add.
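
A sketch assuming W_o is row-major [hidden][num_heads * head_dim] and that b_o may be NULL; each output element is produced in a single pass so the projection result is never written without the residual already added.

#include <stddef.h>

void fused_output_projection_residual(float *output, const float *o_all,
                                      const float *W_o, const float *b_o,
                                      const float *residual,
                                      int hidden, int num_heads, int head_dim)
{
    int proj_dim = num_heads * head_dim;
    for (int i = 0; i < hidden; i++) {
        float acc = b_o ? b_o[i] : 0.0f;
        const float *row = W_o + (size_t)i * proj_dim;
        for (int j = 0; j < proj_dim; j++)
            acc += row[j] * o_all[j];
        output[i] = residual[i] + acc;   /* residual add fused into the same pass */
    }
}
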
void fused_flash_attention_all_heads(float *o_out, const float *q_all, const float *kv_cache_k, const float *kv_cache_v, int num_heads, int num_kv_heads, int head_dim, int seq_len, int kv_tile_size)
Fused Flash Attention across all heads (parallel dispatch).
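
A sketch of the per-head dispatch, assuming query head h reads KV head h / (num_heads / num_kv_heads) (grouped-query attention), that o_out and q_all are laid out as [num_heads][head_dim], and that fused_kernels.h is included for the single-head kernel. The OpenMP pragma stands in for whatever thread pool the real kernel uses.

#include <stddef.h>

void fused_flash_attention_all_heads(float *o_out, const float *q_all,
                                     const float *kv_cache_k, const float *kv_cache_v,
                                     int num_heads, int num_kv_heads,
                                     int head_dim, int seq_len, int kv_tile_size)
{
    int group = num_heads / num_kv_heads;        /* query heads per KV head */

    #pragma omp parallel for schedule(static)
    for (int h = 0; h < num_heads; h++) {
        fused_flash_attention_head(o_out + (size_t)h * head_dim,
                                   q_all + (size_t)h * head_dim,
                                   kv_cache_k, kv_cache_v,
                                   h / group,              /* shared KV head index */
                                   seq_len, head_dim, kv_tile_size);
    }
}
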
void fused_kernels_report_stats(int hidden, int num_layers, int seq_len)
Report memory savings from mega-fusion.
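
A sketch only; the real report format and accounting are not shown in the header. It assumes the "savings" are the per-token intermediates an unfused pipeline would round-trip through DRAM each layer (normed input, Q/K/V projections, one row of attention scores, concatenated head outputs).

#include <stdio.h>

void fused_kernels_report_stats(int hidden, int num_layers, int seq_len)
{
    /* rough per-layer accounting -- an assumption, not documented behaviour */
    double per_layer = (double)hidden        /* RMSNorm output            */
                     + 3.0 * hidden          /* Q, K, V (MHA upper bound) */
                     + (double)seq_len       /* attention score row       */
                     + (double)hidden;       /* concatenated head output  */
    double bytes = per_layer * num_layers * sizeof(float);

    printf("mega-fusion: ~%.2f KiB of per-token intermediates kept on-chip "
           "(%d layers, hidden=%d, seq_len=%d)\n",
           bytes / 1024.0, num_layers, hidden, seq_len);
}
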
void fused_rmsnorm_qkv(const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, float *q_out, float *k_out, float *v_out, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps)
Fused RMSNorm with fused QKV projection.
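
A sketch assuming: the norm has a scale only (gamma), W_qkv is row-major [(num_heads + 2 * num_kv_heads) * head_dim][hidden] with Q rows first, then K, then V, and b_qkv (possibly NULL) follows the same ordering. The normalization is applied on the fly inside the dot product, so the normalized vector never touches memory.

#include <math.h>
#include <stddef.h>

void fused_rmsnorm_qkv(const float *input, const float *gamma,
                       const float *W_qkv, const float *b_qkv,
                       float *q_out, float *k_out, float *v_out,
                       int hidden, int num_heads, int num_kv_heads,
                       int head_dim, float eps)
{
    /* 1. RMS statistic over the input vector */
    float ss = 0.0f;
    for (int i = 0; i < hidden; i++)
        ss += input[i] * input[i];
    float inv_rms = 1.0f / sqrtf(ss / hidden + eps);

    /* 2. One fused pass over the stacked QKV weight matrix */
    int q_rows  = num_heads * head_dim;
    int kv_rows = num_kv_heads * head_dim;
    int total   = q_rows + 2 * kv_rows;

    for (int r = 0; r < total; r++) {
        const float *w = W_qkv + (size_t)r * hidden;
        float acc = b_qkv ? b_qkv[r] : 0.0f;
        for (int i = 0; i < hidden; i++)
            acc += w[i] * (input[i] * inv_rms * gamma[i]);  /* norm fused into dot product */

        if (r < q_rows)                q_out[r] = acc;
        else if (r < q_rows + kv_rows) k_out[r - q_rows] = acc;
        else                           v_out[r - q_rows - kv_rows] = acc;
    }
}
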
void fused_flash_attention_head(float *o_out, const float *q, const float *kv_cache_k, const float *kv_cache_v, int kv_head_idx, int seq_len, int head_dim, int kv_tile_size)
Fused Flash Attention for a single head.
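
A sketch of one-head flash attention with an online softmax over KV tiles. KV_MAX_SEQ is a hypothetical per-head cache stride (the declared signature does not carry one, so the real layout must be fixed elsewhere); the cache is assumed to be [kv_head][KV_MAX_SEQ][head_dim] and o_out must hold head_dim floats.

#include <math.h>
#include <float.h>
#include <stddef.h>

#define KV_MAX_SEQ 4096   /* hypothetical stride, for the sketch only */

void fused_flash_attention_head(float *o_out, const float *q,
                                const float *kv_cache_k, const float *kv_cache_v,
                                int kv_head_idx, int seq_len,
                                int head_dim, int kv_tile_size)
{
    const float *K = kv_cache_k + (size_t)kv_head_idx * KV_MAX_SEQ * head_dim;
    const float *V = kv_cache_v + (size_t)kv_head_idx * KV_MAX_SEQ * head_dim;
    float scale = 1.0f / sqrtf((float)head_dim);

    float m = -FLT_MAX;                 /* running max of scores       */
    float denom = 0.0f;                 /* running softmax denominator */
    for (int d = 0; d < head_dim; d++) o_out[d] = 0.0f;

    for (int t0 = 0; t0 < seq_len; t0 += kv_tile_size) {
        int t1 = (t0 + kv_tile_size < seq_len) ? t0 + kv_tile_size : seq_len;
        for (int t = t0; t < t1; t++) {
            /* score = (q . K[t]) * scale */
            float s = 0.0f;
            for (int d = 0; d < head_dim; d++)
                s += q[d] * K[(size_t)t * head_dim + d];
            s *= scale;

            /* online softmax: rescale running state when the max grows */
            float m_new = (s > m) ? s : m;
            float corr  = expf(m - m_new);
            float w     = expf(s - m_new);
            denom = denom * corr + w;
            for (int d = 0; d < head_dim; d++)
                o_out[d] = o_out[d] * corr + w * V[(size_t)t * head_dim + d];
            m = m_new;
        }
    }

    for (int d = 0; d < head_dim; d++)
        o_out[d] /= denom;              /* normalize once at the end */
}
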
int fused_kernels_compute_kv_tile(int l1_size, int head_dim, int bytes_per_elem)
Compute optimal KV tile size for flash attention.
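
A sketch only: it assumes the tile is sized so one K tile plus one V tile (plus the resident query row) fits in roughly half of L1, leaving the rest for the accumulator and prefetch. The half-of-L1 budget and the rounding to a multiple of 8 are assumptions, not documented behaviour.

int fused_kernels_compute_kv_tile(int l1_size, int head_dim, int bytes_per_elem)
{
    long budget  = (long)l1_size / 2 - (long)head_dim * bytes_per_elem;  /* minus the q row */
    long per_row = 2L * head_dim * bytes_per_elem;                       /* one K row + one V row */

    int tile = (budget > 0) ? (int)(budget / per_row) : 0;
    if (tile >= 8) tile -= tile % 8;    /* round down to a multiple of 8 */
    if (tile < 1)  tile = 1;            /* always make forward progress  */
    return tile;
}
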
void fused_rmsnorm(const float *input, const float *gamma, const float *beta, float *output, int hidden, float eps)
Fused RMSNorm; writes to a pre-allocated buffer.
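
A short sketch of the standalone norm, assuming the usual RMSNorm with a scale (gamma) and, since the signature carries one, an optional additive shift (beta, possibly NULL).

#include <math.h>

void fused_rmsnorm(const float *input, const float *gamma, const float *beta,
                   float *output, int hidden, float eps)
{
    float ss = 0.0f;
    for (int i = 0; i < hidden; i++)
        ss += input[i] * input[i];
    float inv_rms = 1.0f / sqrtf(ss / hidden + eps);

    for (int i = 0; i < hidden; i++)
        output[i] = input[i] * inv_rms * gamma[i] + (beta ? beta[i] : 0.0f);
}
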
void mega_fused_attention(float *output, const float *input, const float *residual, const float *W_qkv, const float *b_qkv, const float *W_o, const float *b_o, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int seq_len, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps)
Complete mega-fused attention block.
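
A hypothetical single-token smoke test showing how the mega-fused block is wired up. The header name "fused_kernels.h", the buffer sizes, the KV-cache shape ([max_seq][num_kv_heads][head_dim]), the rope-table shape, and the acceptance of NULL biases are all assumptions; a real caller loads trained weights and keeps the cache alive across decode steps.

#include <stdlib.h>
#include "fused_kernels.h"   /* assumed header name */

int main(void)
{
    int hidden = 512, num_heads = 8, num_kv_heads = 2, head_dim = 64;
    int max_seq = 128, pos = 0, seq_len = pos + 1;

    float *input    = calloc(hidden, sizeof(float));
    float *residual = calloc(hidden, sizeof(float));
    float *output   = calloc(hidden, sizeof(float));

    float *W_qkv = calloc((size_t)(num_heads + 2 * num_kv_heads) * head_dim * hidden, sizeof(float));
    float *W_o   = calloc((size_t)hidden * num_heads * head_dim, sizeof(float));

    float *kv_cache_k = calloc((size_t)max_seq * num_kv_heads * head_dim, sizeof(float));
    float *kv_cache_v = calloc((size_t)max_seq * num_kv_heads * head_dim, sizeof(float));

    float *rope_cos = calloc((size_t)max_seq * head_dim / 2, sizeof(float));
    float *rope_sin = calloc((size_t)max_seq * head_dim / 2, sizeof(float));

    /* one decode step for the token at position pos */
    mega_fused_attention(output, input, residual,
                         W_qkv, /*b_qkv=*/NULL, W_o, /*b_o=*/NULL,
                         kv_cache_k, kv_cache_v, rope_cos, rope_sin,
                         pos, seq_len, hidden, num_heads, num_kv_heads,
                         head_dim, max_seq, 1e-6f);

    /* frees omitted for brevity in this sketch */
    return 0;
}

#endif /* FUSED_KERNELS_H */
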