Mega-Fused Attention for AVX (256-bit) and AVX-512 (512-bit) More...
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include "ckernel_engine.h"
#include "ckernel_quant.h"
#include "ck_features.h"
Go to the source code of this file.
Macros | |
| #define | MEGA_KV_TILE 32 |
| #define | MEGA_Q_TILE 32 |
| #define | MEGA_REGS 16 /* 16 YMM registers */ |
| #define | MEGA_STACK_MAX 8192 |
| #define | MEGA_VLEN 8 /* 256 / 32 */ |
| #define | REG_K_TILE "YMM8-YMM11" /* 4 regs for K tile */ |
| #define | REG_O_ACCUM "Stack+L1" /* O in L1 cache */ |
| #define | REG_Q_ACCUM "YMM0-YMM7" /* 8 regs for Q tile */ |
| #define | REG_SOFTMAX "YMM0-YMM1" /* 2 regs for m, l */ |
| #define | REG_TEMP "YMM2-YMM3" /* 2 regs for temps */ |
| #define | REG_V_TILE "YMM12-YMM15" /* 4 regs for V tile */ |
Functions | |
| static float | ck_dot_f32 (const float *a, const float *b, int len) |
| void | mega_fuse_flash_attention_avx (float *o_out, const float *q, const float *kv_cache_k, const float *kv_cache_v, int num_heads, int num_kv_heads, int seq_len, int cache_capacity, int head_dim, int aligned_head_dim) |
| Flash attention with online softmax (AVX version) More... | |
| static void | mega_fuse_output_proj_residual (const float *attn_token, const float *wo, const float *bo, const float *residual, float *output, int embed_dim, int aligned_embed_dim, int num_heads, int head_dim, int aligned_head_dim) |
| void | mega_fuse_rmsnorm_qkv_avx (float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps) |
| Fused RMSNorm + QKV for decode (single token) More... | |
| void | mega_fuse_rope_inplace_avx (float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim) |
| Apply RoPE to Q and K (in-place, from L1) More... | |
| void | mega_fused_attention_decode (float *output, const float *input, const float *residual, const float *ln1_gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, const float *wo, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps) |
| Full mega-fused attention for decode. More... | |
Mega-Fused Attention for AVX (256-bit) and AVX-512 (512-bit)
After changes: make test && make llamacpp-parity-full
VIOLATION: Uses malloc for intermediate buffers and memcpy for layout. TODO: Refactor to use bump allocator workspace and strided access.
Holy grail fusion: RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual
AVX approach: keep intermediates in L1 cache (not registers). AVX-512 approach: keep intermediates in registers.
Both achieve the same goal: Eliminate DRAM traffic for intermediates.
Definition in file mega_fused_attention_avx.c.
| #define MEGA_KV_TILE 32 |
Definition at line 118 of file mega_fused_attention_avx.c.
| #define MEGA_Q_TILE 32 |
Definition at line 117 of file mega_fused_attention_avx.c.
| #define MEGA_REGS 16 /* 16 YMM registers */ |
Definition at line 116 of file mega_fused_attention_avx.c.
| #define MEGA_STACK_MAX 8192 |
Definition at line 119 of file mega_fused_attention_avx.c.
| #define MEGA_VLEN 8 /* 256 / 32 */ |
Definition at line 115 of file mega_fused_attention_avx.c.
| #define REG_K_TILE "YMM8-YMM11" /* 4 regs for K tile */ |
Definition at line 123 of file mega_fused_attention_avx.c.
| #define REG_O_ACCUM "Stack+L1" /* O in L1 cache */ |
Definition at line 125 of file mega_fused_attention_avx.c.
| #define REG_Q_ACCUM "YMM0-YMM7" /* 8 regs for Q tile */ |
Definition at line 122 of file mega_fused_attention_avx.c.
| #define REG_SOFTMAX "YMM0-YMM1" /* 2 regs for m, l */ |
Definition at line 126 of file mega_fused_attention_avx.c.
| #define REG_TEMP "YMM2-YMM3" /* 2 regs for temps */ |
Definition at line 127 of file mega_fused_attention_avx.c.
| #define REG_V_TILE "YMM12-YMM15" /* 4 regs for V tile */ |
Definition at line 124 of file mega_fused_attention_avx.c.
|
inline static |
Definition at line 54 of file mega_fused_attention_avx.c.
Referenced by mega_fuse_output_proj_residual(), and mega_fuse_rmsnorm_qkv_avx().
| void mega_fuse_flash_attention_avx | ( | float * | o_out, |
| const float * | q, | ||
| const float * | kv_cache_k, | ||
| const float * | kv_cache_v, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | seq_len, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
Flash attention with online softmax (AVX version)
Key insight: O, m, l stay in registers throughout! K/V tiles stream from L2 cache.
| o_out | Output [num_heads * aligned_head_dim] - in registers/L1 |
| q | Q tensor [num_heads * aligned_head_dim] - from L1 |
| kv_cache_k | KV cache K [num_kv_heads * cache_capacity * aligned_head_dim] |
| kv_cache_v | KV cache V [num_kv_heads * cache_capacity * aligned_head_dim] |
| num_heads | Number of heads |
| num_kv_heads | Number of KV heads |
| seq_len | Current sequence length |
| cache_capacity | KV cache capacity (head stride) |
| head_dim | Head dimension |
| aligned_head_dim | Aligned head dimension |
Definition at line 444 of file mega_fused_attention_avx.c.
References MEGA_KV_TILE.
Referenced by mega_fused_attention_decode().
|
static |
Definition at line 551 of file mega_fused_attention_avx.c.
References ck_dot_f32().
Referenced by mega_fused_attention_decode().
| void mega_fuse_rmsnorm_qkv_avx | ( | float * | q_out, |
| float * | k_out, | ||
| float * | v_out, | ||
| const float * | input, | ||
| const float * | gamma, | ||
| const float * | wq, | ||
| const float * | bq, | ||
| const float * | wk, | ||
| const float * | bk, | ||
| const float * | wv, | ||
| const float * | bv, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| float | eps | ||
| ) |
Fused RMSNorm + QKV for decode (single token)
Intermediates stay in L1/L2. Output buffers are head-major.
Definition at line 143 of file mega_fused_attention_avx.c.
References ck_dot_f32().
Referenced by mega_fused_attention_decode().
| void mega_fuse_rope_inplace_avx | ( | float * | q, |
| float * | k, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | pos, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
Apply RoPE to Q and K (in-place, from L1)
Q and K are already in L1 from QKV projection. Just apply rotation in-place.
Definition at line 311 of file mega_fused_attention_avx.c.
Referenced by mega_fused_attention_decode().
| void mega_fused_attention_decode | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const float * | ln1_gamma, | ||
| const float * | wq, | ||
| const float * | bq, | ||
| const float * | wk, | ||
| const float * | bk, | ||
| const float * | wv, | ||
| const float * | bv, | ||
| const float * | wo, | ||
| const float * | bo, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | pos, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | cache_capacity, | ||
| float | eps | ||
| ) |
Full mega-fused attention for decode.
Mega-fused attention for decode mode (single token)
RMSNorm → QKV → RoPE → Flash Attn → OutProj + Residual
Definition at line 589 of file mega_fused_attention_avx.c.
References kv_cache_write_head_major(), mega_fuse_flash_attention_avx(), mega_fuse_output_proj_residual(), mega_fuse_rmsnorm_qkv_avx(), mega_fuse_rope_inplace_avx(), and MEGA_STACK_MAX.