← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention_avx.c File Reference

Mega-Fused Attention for AVX (256-bit) and AVX-512 (512-bit) More...

#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include "ckernel_engine.h"
#include "ckernel_quant.h"
#include "ck_features.h"

Go to the source code of this file.

Macros

#define MEGA_KV_TILE   32
 
#define MEGA_Q_TILE   32
 
#define MEGA_REGS   16 /* 16 YMM registers */
 
#define MEGA_STACK_MAX   8192
 
#define MEGA_VLEN   8 /* 256 / 32 */
 
#define REG_K_TILE   "YMM8-YMM11" /* 4 regs for K tile */
 
#define REG_O_ACCUM   "Stack+L1" /* O in L1 cache */
 
#define REG_Q_ACCUM   "YMM0-YMM7" /* 8 regs for Q tile */
 
#define REG_SOFTMAX   "YMM0-YMM1" /* 2 regs for m, l */
 
#define REG_TEMP   "YMM2-YMM3" /* 2 regs for temps */
 
#define REG_V_TILE   "YMM12-YMM15" /* 4 regs for V tile */
 

Functions

static float ck_dot_f32 (const float *a, const float *b, int len)
 
void mega_fuse_flash_attention_avx (float *o_out, const float *q, const float *kv_cache_k, const float *kv_cache_v, int num_heads, int num_kv_heads, int seq_len, int cache_capacity, int head_dim, int aligned_head_dim)
 Flash attention with online softmax (AVX version) More...
 
static void mega_fuse_output_proj_residual (const float *attn_token, const float *wo, const float *bo, const float *residual, float *output, int embed_dim, int aligned_embed_dim, int num_heads, int head_dim, int aligned_head_dim)
 
void mega_fuse_rmsnorm_qkv_avx (float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps)
 Fused RMSNorm + QKV for decode (single token) More...
 
void mega_fuse_rope_inplace_avx (float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim)
 Apply RoPE to Q and K (in-place, from L1) More...
 
void mega_fused_attention_decode (float *output, const float *input, const float *residual, const float *ln1_gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, const float *wo, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps)
 Full mega-fused attention for decode. More...
 

Detailed Description

Mega-Fused Attention for AVX (256-bit) and AVX-512 (512-bit)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. NO memcpy for layout - use strided access, not copies
  4. API must define: inputs, outputs, workspace, and memory layouts
  5. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

VIOLATION: Uses malloc for intermediate buffers and memcpy for layout. TODO: Refactor to use bump allocator workspace and strided access.

Holy grail fusion: RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual

AVX approach: keep intermediates in L1 cache (not registers). AVX-512 approach: keep intermediates in registers.

Both achieve the same goal: Eliminate DRAM traffic for intermediates.

Definition in file mega_fused_attention_avx.c.

Macro Definition Documentation

◆ MEGA_KV_TILE

#define MEGA_KV_TILE   32

Definition at line 118 of file mega_fused_attention_avx.c.

◆ MEGA_Q_TILE

#define MEGA_Q_TILE   32

Definition at line 117 of file mega_fused_attention_avx.c.

◆ MEGA_REGS

#define MEGA_REGS   16 /* 16 YMM registers */

Definition at line 116 of file mega_fused_attention_avx.c.

◆ MEGA_STACK_MAX

#define MEGA_STACK_MAX   8192

Definition at line 119 of file mega_fused_attention_avx.c.

◆ MEGA_VLEN

#define MEGA_VLEN   8 /* 256 / 32 */

Definition at line 115 of file mega_fused_attention_avx.c.

◆ REG_K_TILE

#define REG_K_TILE   "YMM8-YMM11" /* 4 regs for K tile */

Definition at line 123 of file mega_fused_attention_avx.c.

◆ REG_O_ACCUM

#define REG_O_ACCUM   "Stack+L1" /* O in L1 cache */

Definition at line 125 of file mega_fused_attention_avx.c.

◆ REG_Q_ACCUM

#define REG_Q_ACCUM   "YMM0-YMM7" /* 8 regs for Q tile */

Definition at line 122 of file mega_fused_attention_avx.c.

◆ REG_SOFTMAX

#define REG_SOFTMAX   "YMM0-YMM1" /* 2 regs for m, l */

Definition at line 126 of file mega_fused_attention_avx.c.

◆ REG_TEMP

#define REG_TEMP   "YMM2-YMM3" /* 2 regs for temps */

Definition at line 127 of file mega_fused_attention_avx.c.

◆ REG_V_TILE

#define REG_V_TILE   "YMM12-YMM15" /* 4 regs for V tile */

Definition at line 124 of file mega_fused_attention_avx.c.

Function Documentation

◆ ck_dot_f32()

/**
 * Dot product of two float vectors, SIMD-accelerated when available.
 *
 * @param a   first vector
 * @param b   second vector
 * @param len number of elements to multiply-accumulate
 * @return    sum over i in [0, len) of a[i] * b[i]
 */
static inline float ck_dot_f32(const float *a, const float *b, int len)
{
#if defined(__AVX512F__)
    __m512 vacc = _mm512_setzero_ps();
    int idx = 0;
    for (; idx <= len - 16; idx += 16) {
        vacc = _mm512_fmadd_ps(_mm512_loadu_ps(a + idx),
                               _mm512_loadu_ps(b + idx), vacc);
    }
    float total = _mm512_reduce_add_ps(vacc);
    /* Scalar tail for the remaining (< 16) elements. */
    while (idx < len) {
        total += a[idx] * b[idx];
        ++idx;
    }
    return total;
#elif defined(__AVX__)
    __m256 vacc = _mm256_setzero_ps();
    int idx = 0;
    for (; idx <= len - 8; idx += 8) {
        __m256 prod = _mm256_mul_ps(_mm256_loadu_ps(a + idx),
                                    _mm256_loadu_ps(b + idx));
        vacc = _mm256_add_ps(vacc, prod);
    }
    float total = ck_hsum256_ps(vacc);
    /* Scalar tail for the remaining (< 8) elements. */
    while (idx < len) {
        total += a[idx] * b[idx];
        ++idx;
    }
    return total;
#else
    /* Portable scalar fallback. */
    float total = 0.0f;
    for (int idx = 0; idx < len; ++idx) {
        total += a[idx] * b[idx];
    }
    return total;
#endif
}

Referenced by mega_fuse_output_proj_residual(), and mega_fuse_rmsnorm_qkv_avx().

◆ mega_fuse_flash_attention_avx()

void mega_fuse_flash_attention_avx ( float *  o_out,
const float *  q,
const float *  kv_cache_k,
const float *  kv_cache_v,
int  num_heads,
int  num_kv_heads,
int  seq_len,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim 
)

Flash attention with online softmax (AVX version)

Key insight: O, m, l stay in registers throughout! K/V tiles stream from L2 cache.

Parameters
o_outOutput [num_heads * aligned_head_dim] - in registers/L1
qQ tensor [num_heads * aligned_head_dim] - from L1
kv_cache_kKV cache K [num_kv_heads * cache_capacity * aligned_head_dim]
kv_cache_vKV cache V [num_kv_heads * cache_capacity * aligned_head_dim]
num_headsNumber of heads
num_kv_headsNumber of KV heads
seq_lenCurrent sequence length
cache_capacityKV cache capacity (head stride)
head_dimHead dimension
aligned_head_dimAligned head dimension

Definition at line 444 of file mega_fused_attention_avx.c.

455 {
456  const int hd = head_dim;
457  const float scale = 1.0f / sqrtf((float)hd);
458  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
459 
460  for (int h = 0; h < num_heads; h++) {
461  const float *q_h = q + (size_t)h * (size_t)aligned_head_dim;
462  const int kv_idx = h % num_kv_heads;
463  const float *k_cache = kv_cache_k + (size_t)kv_idx * head_stride;
464  const float *v_cache = kv_cache_v + (size_t)kv_idx * head_stride;
465 
466  /* O, m, l in registers for this head */
467  float o_h[aligned_head_dim]; /* in L1 */
468  float m = -INFINITY; /* running max */
469  float l = 0.0f; /* running sum */
470 
471  /* Initialize O to zeros */
472  memset(o_h, 0, (size_t)aligned_head_dim * sizeof(float));
473 
474  /* Iterate over KV cache tiles */
475  for (int t = 0; t < seq_len; t += MEGA_KV_TILE) {
476  int tile_end = t + MEGA_KV_TILE;
477  if (tile_end > seq_len) tile_end = seq_len;
478  int tile_size = tile_end - t;
479 
480  /* Load K tile from L2 cache */
481  float k_tile[MEGA_KV_TILE * hd];
482  for (int i = 0; i < tile_size; i++) {
483  memcpy(k_tile + (size_t)i * (size_t)hd,
484  k_cache + (size_t)(t + i) * (size_t)aligned_head_dim,
485  (size_t)hd * sizeof(float));
486  }
487 
488  /* S_ij = Q @ K_tile.T / sqrt(d) - in registers */
489  float s_row[MEGA_KV_TILE];
490  for (int j = 0; j < tile_size; j++) {
491  s_row[j] = 0.0f;
492  for (int i = 0; i < hd; i++) {
493  s_row[j] += q_h[i] * k_tile[j * hd + i];
494  }
495  s_row[j] *= scale;
496  }
497 
498  /* Online softmax update */
499  float m_new = m;
500  for (int j = 0; j < tile_size; j++) {
501  if (s_row[j] > m_new) m_new = s_row[j];
502  }
503 
504  float l_new = 0.0f;
505  for (int j = 0; j < tile_size; j++) {
506  float p = expf(s_row[j] - m_new);
507  s_row[j] = p;
508  l_new += p;
509  }
510 
511  /* Scale O by exp(m - m_new) and add P @ V */
512  float exp_m_diff = expf(m - m_new);
513  for (int i = 0; i < hd; i++) {
514  o_h[i] *= exp_m_diff;
515  }
516 
517  /* Load V tile and accumulate */
518  for (int j = 0; j < tile_size; j++) {
519  float p = s_row[j];
520  for (int i = 0; i < hd; i++) {
521  o_h[i] += p * v_cache[(size_t)(t + j) * (size_t)aligned_head_dim + (size_t)i];
522  }
523  }
524 
525  l = l * exp_m_diff + l_new;
526  m = m_new;
527  }
528 
529  /* Normalize by l */
530  for (int i = 0; i < hd; i++) {
531  o_h[i] /= l;
532  }
533  for (int i = hd; i < aligned_head_dim; ++i) {
534  o_h[i] = 0.0f;
535  }
536 
537  /* Store O - still in L1, goes to output projection */
538  memcpy(o_out + (size_t)h * (size_t)aligned_head_dim,
539  o_h,
540  (size_t)aligned_head_dim * sizeof(float));
541  }
542 }
#define MEGA_KV_TILE

References MEGA_KV_TILE.

Referenced by mega_fused_attention_decode().

◆ mega_fuse_output_proj_residual()

/*
 * Output projection over the per-head attention output, plus optional
 * bias and residual add. Writes embed_dim results to `output` and
 * zero-fills the alignment tail [embed_dim, aligned_embed_dim).
 *
 * Weight layout: wo is head-major, [num_heads][aligned_embed_dim][aligned_head_dim];
 * each output column j takes one dot of length head_dim per head.
 */
static void mega_fuse_output_proj_residual(const float *attn_token,
                                           const float *wo,
                                           const float *bo,
                                           const float *residual,
                                           float *output,
                                           int embed_dim,
                                           int aligned_embed_dim,
                                           int num_heads,
                                           int head_dim,
                                           int aligned_head_dim)
{
    if (attn_token == NULL || wo == NULL || output == NULL) {
        return;
    }

    const size_t w_per_head = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;

    for (int col = 0; col < embed_dim; ++col) {
        float acc = (bo != NULL) ? bo[col] : 0.0f;
        for (int head = 0; head < num_heads; ++head) {
            const float *head_out =
                attn_token + (size_t)head * (size_t)aligned_head_dim;
            const float *w_row = wo + (size_t)head * w_per_head
                                    + (size_t)col * (size_t)aligned_head_dim;
            acc += ck_dot_f32(head_out, w_row, head_dim);
        }
        output[col] = acc + (residual != NULL ? residual[col] : 0.0f);
    }

    /* Keep the padded tail clean for downstream consumers. */
    for (int col = embed_dim; col < aligned_embed_dim; ++col) {
        output[col] = 0.0f;
    }
}
static float ck_dot_f32(const float *a, const float *b, int len)

References ck_dot_f32().

Referenced by mega_fused_attention_decode().

◆ mega_fuse_rmsnorm_qkv_avx()

/**
 * Fused RMSNorm + QKV projection for decode (single token).
 *
 * The normalized token row stays in a stack buffer (L1/L2); Q/K/V
 * outputs are written head-major with zero-filled alignment padding
 * [head_dim, aligned_head_dim) per head.
 *
 * Weight layout per head: [aligned_head_dim][aligned_embed_dim];
 * biases (bq/bk/bv) are optional and head-major when present, as is
 * gamma (RMSNorm scale).
 */
void mega_fuse_rmsnorm_qkv_avx(float *q_out, float *k_out, float *v_out,
                               const float *input, const float *gamma,
                               const float *wq, const float *bq,
                               const float *wk, const float *bk,
                               const float *wv, const float *bv,
                               int embed_dim, int aligned_embed_dim,
                               int num_heads, int num_kv_heads,
                               int head_dim, int aligned_head_dim,
                               float eps)
{
    if (q_out == NULL || k_out == NULL || v_out == NULL ||
        input == NULL || wq == NULL || wk == NULL || wv == NULL) {
        return;
    }
    if (embed_dim <= 0 || aligned_embed_dim <= 0 ||
        head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }

    float normed[aligned_embed_dim]; /* normalized row, stays in L1 */
    float ssq = 0.0f;

    /* Sum of squares of the input row. */
#if defined(__AVX512F__)
    {
        __m512 acc = _mm512_setzero_ps();
        int idx = 0;
        for (; idx + 16 <= embed_dim; idx += 16) {
            __m512 x = _mm512_loadu_ps(input + idx);
            acc = _mm512_fmadd_ps(x, x, acc);
        }
        ssq = _mm512_reduce_add_ps(acc);
        for (; idx < embed_dim; ++idx) {
            ssq += input[idx] * input[idx];
        }
    }
#elif defined(__AVX__)
    {
        __m256 acc = _mm256_setzero_ps();
        int idx = 0;
        for (; idx + 8 <= embed_dim; idx += 8) {
            __m256 x = _mm256_loadu_ps(input + idx);
            acc = _mm256_add_ps(acc, _mm256_mul_ps(x, x));
        }
        ssq = ck_hsum256_ps(acc);
        for (; idx < embed_dim; ++idx) {
            ssq += input[idx] * input[idx];
        }
    }
#else
    for (int idx = 0; idx < embed_dim; ++idx) {
        ssq += input[idx] * input[idx];
    }
#endif

    const float inv_rms = 1.0f / sqrtf(ssq / (float)embed_dim + eps);

    /* Normalize (and scale by gamma when present) into normed[]. */
#if defined(__AVX512F__)
    {
        const __m512 vr = _mm512_set1_ps(inv_rms);
        int idx = 0;
        if (gamma != NULL) {
            for (; idx + 16 <= embed_dim; idx += 16) {
                __m512 x = _mm512_loadu_ps(input + idx);
                __m512 g = _mm512_loadu_ps(gamma + idx);
                _mm512_storeu_ps(normed + idx,
                                 _mm512_mul_ps(_mm512_mul_ps(x, vr), g));
            }
            for (; idx < embed_dim; ++idx) {
                normed[idx] = input[idx] * inv_rms * gamma[idx];
            }
        } else {
            for (; idx + 16 <= embed_dim; idx += 16) {
                __m512 x = _mm512_loadu_ps(input + idx);
                _mm512_storeu_ps(normed + idx, _mm512_mul_ps(x, vr));
            }
            for (; idx < embed_dim; ++idx) {
                normed[idx] = input[idx] * inv_rms;
            }
        }
    }
#elif defined(__AVX__)
    {
        const __m256 vr = _mm256_set1_ps(inv_rms);
        int idx = 0;
        if (gamma != NULL) {
            for (; idx + 8 <= embed_dim; idx += 8) {
                __m256 x = _mm256_loadu_ps(input + idx);
                __m256 g = _mm256_loadu_ps(gamma + idx);
                _mm256_storeu_ps(normed + idx,
                                 _mm256_mul_ps(_mm256_mul_ps(x, vr), g));
            }
            for (; idx < embed_dim; ++idx) {
                normed[idx] = input[idx] * inv_rms * gamma[idx];
            }
        } else {
            for (; idx + 8 <= embed_dim; idx += 8) {
                __m256 x = _mm256_loadu_ps(input + idx);
                _mm256_storeu_ps(normed + idx, _mm256_mul_ps(x, vr));
            }
            for (; idx < embed_dim; ++idx) {
                normed[idx] = input[idx] * inv_rms;
            }
        }
    }
#else
    for (int idx = 0; idx < embed_dim; ++idx) {
        normed[idx] = input[idx] * inv_rms * (gamma != NULL ? gamma[idx] : 1.0f);
    }
#endif

    /* Zero the padded tail so full-width dots below stay exact. */
    for (int idx = embed_dim; idx < aligned_embed_dim; ++idx) {
        normed[idx] = 0.0f;
    }

    const size_t w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;

    /* Q projection: head-major output, one dot per output element. */
    for (int h = 0; h < num_heads; ++h) {
        const float *w_head = wq + (size_t)h * w_stride;
        const float *b_head =
            (bq != NULL) ? bq + (size_t)h * (size_t)aligned_head_dim : NULL;
        float *dst = q_out + (size_t)h * (size_t)aligned_head_dim;
        for (int d = 0; d < head_dim; ++d) {
            const float *w_row = w_head + (size_t)d * (size_t)aligned_embed_dim;
            float acc = ck_dot_f32(normed, w_row, aligned_embed_dim);
            dst[d] = acc + (b_head != NULL ? b_head[d] : 0.0f);
        }
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            dst[d] = 0.0f;
        }
    }

    /* K and V projections share the loop over KV heads. */
    for (int h = 0; h < num_kv_heads; ++h) {
        const float *wk_head = wk + (size_t)h * w_stride;
        const float *wv_head = wv + (size_t)h * w_stride;
        const float *bk_head =
            (bk != NULL) ? bk + (size_t)h * (size_t)aligned_head_dim : NULL;
        const float *bv_head =
            (bv != NULL) ? bv + (size_t)h * (size_t)aligned_head_dim : NULL;
        float *k_dst = k_out + (size_t)h * (size_t)aligned_head_dim;
        float *v_dst = v_out + (size_t)h * (size_t)aligned_head_dim;
        for (int d = 0; d < head_dim; ++d) {
            const float *wk_row = wk_head + (size_t)d * (size_t)aligned_embed_dim;
            const float *wv_row = wv_head + (size_t)d * (size_t)aligned_embed_dim;
            float k_acc = ck_dot_f32(normed, wk_row, aligned_embed_dim);
            float v_acc = ck_dot_f32(normed, wv_row, aligned_embed_dim);
            k_dst[d] = k_acc + (bk_head != NULL ? bk_head[d] : 0.0f);
            v_dst[d] = v_acc + (bv_head != NULL ? bv_head[d] : 0.0f);
        }
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            k_dst[d] = 0.0f;
            v_dst[d] = 0.0f;
        }
    }
}

References ck_dot_f32().

Referenced by mega_fused_attention_decode().

◆ mega_fuse_rope_inplace_avx()

/*
 * Rotate one group of heads in place (rotate-half RoPE): for each head,
 * the pair (x[i], x[i + half]) is rotated by the angle whose cos/sin
 * are cos_ptr[i]/sin_ptr[i]; the alignment padding
 * [head_dim, aligned_head_dim) of every head is zeroed afterwards.
 * Shared by the Q and K paths below (previously two duplicated loops).
 */
static void ck_rope_rotate_heads(float *x, int n_heads, int half,
                                 int head_dim, int aligned_head_dim,
                                 const float *cos_ptr, const float *sin_ptr)
{
    for (int h = 0; h < n_heads; ++h) {
        float *xh = x + (size_t)h * (size_t)aligned_head_dim;
        int i = 0;
#if defined(__AVX512F__)
        for (; i + 16 <= half; i += 16) {
            __m512 x0 = _mm512_loadu_ps(xh + i);
            __m512 x1 = _mm512_loadu_ps(xh + i + half);
            __m512 c = _mm512_loadu_ps(cos_ptr + i);
            __m512 s = _mm512_loadu_ps(sin_ptr + i);

            __m512 rot0 = _mm512_sub_ps(_mm512_mul_ps(x0, c), _mm512_mul_ps(x1, s));
            __m512 rot1 = _mm512_add_ps(_mm512_mul_ps(x0, s), _mm512_mul_ps(x1, c));

            _mm512_storeu_ps(xh + i, rot0);
            _mm512_storeu_ps(xh + i + half, rot1);
        }
#elif defined(__AVX__)
        for (; i + 8 <= half; i += 8) {
            __m256 x0 = _mm256_loadu_ps(xh + i);
            __m256 x1 = _mm256_loadu_ps(xh + i + half);
            __m256 c = _mm256_loadu_ps(cos_ptr + i);
            __m256 s = _mm256_loadu_ps(sin_ptr + i);

            __m256 rot0 = _mm256_sub_ps(_mm256_mul_ps(x0, c), _mm256_mul_ps(x1, s));
            __m256 rot1 = _mm256_add_ps(_mm256_mul_ps(x0, s), _mm256_mul_ps(x1, c));

            _mm256_storeu_ps(xh + i, rot0);
            _mm256_storeu_ps(xh + i + half, rot1);
        }
#endif
        /* Scalar tail (also the whole loop when no AVX is compiled in). */
        for (; i < half; ++i) {
            float a = xh[i];
            float b = xh[i + half];
            float c = cos_ptr[i];
            float s = sin_ptr[i];
            xh[i] = a * c - b * s;
            xh[i + half] = a * s + b * c;
        }
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            xh[d] = 0.0f;
        }
    }
}

/**
 * Apply RoPE to Q and K (in-place, from L1).
 *
 * Q and K are already in L1 from the QKV projection; only the rotation
 * is applied here. Tables are indexed by position: cos/sin rows of
 * length head_dim/2 at offset pos * (head_dim/2).
 *
 * @param q                Q heads [num_heads * aligned_head_dim], mutated
 * @param k                K heads [num_kv_heads * aligned_head_dim], mutated
 * @param rope_cos         cos table [.. positions .. * head_dim/2]
 * @param rope_sin         sin table [.. positions .. * head_dim/2]
 * @param pos              token position (table row)
 * @param num_heads        number of query heads
 * @param num_kv_heads     number of KV heads
 * @param head_dim         head dimension (must be even)
 * @param aligned_head_dim aligned head dimension (>= head_dim)
 */
void mega_fuse_rope_inplace_avx(float *q, float *k,
                                const float *rope_cos, const float *rope_sin,
                                int pos, int num_heads, int num_kv_heads,
                                int head_dim, int aligned_head_dim)
{
    if (!q || !k || !rope_cos || !rope_sin ||
        head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }
    if ((head_dim & 1) != 0) {
        return; /* rotate-half RoPE requires an even head_dim */
    }

    const int half = head_dim / 2;
    const float *cos_ptr = rope_cos + (size_t)pos * (size_t)half;
    const float *sin_ptr = rope_sin + (size_t)pos * (size_t)half;

    ck_rope_rotate_heads(q, num_heads, half, head_dim, aligned_head_dim,
                         cos_ptr, sin_ptr);
    ck_rope_rotate_heads(k, num_kv_heads, half, head_dim, aligned_head_dim,
                         cos_ptr, sin_ptr);
}

Referenced by mega_fused_attention_decode().

◆ mega_fused_attention_decode()

#ifndef MEGA_STACK_MAX
#define MEGA_STACK_MAX 8192
#endif

/**
 * Full mega-fused attention for decode mode (single token):
 * RMSNorm -> QKV -> RoPE -> KV-cache append -> flash attention ->
 * output projection + residual.
 *
 * KERNEL-RULE VIOLATION (pre-existing, documented in the file header):
 * falls back to malloc when a per-token buffer exceeds MEGA_STACK_MAX
 * floats.  TODO: take a bump-allocator workspace instead.  Error paths
 * now use goto-based cleanup; free(NULL) is a no-op so all heap
 * pointers are freed unconditionally.
 *
 * @param output  projected + residual-added token [aligned_embed_dim]
 * @param input   token embedding [aligned_embed_dim]
 * @param pos     position to write into the KV cache; must be in
 *                [0, cache_capacity)
 * Remaining params are weights/biases/tables and layout dims matching
 * the callee contracts (see the individual kernels).
 */
void mega_fused_attention_decode(float *output, const float *input,
                                 const float *residual,
                                 const float *ln1_gamma,
                                 const float *wq, const float *bq,
                                 const float *wk, const float *bk,
                                 const float *wv, const float *bv,
                                 const float *wo, const float *bo,
                                 float *kv_cache_k, float *kv_cache_v,
                                 const float *rope_cos, const float *rope_sin,
                                 int pos, int embed_dim, int aligned_embed_dim,
                                 int num_heads, int num_kv_heads,
                                 int head_dim, int aligned_head_dim,
                                 int cache_capacity, float eps)
{
    if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
        !kv_cache_k || !kv_cache_v) {
        return;
    }
    if (embed_dim <= 0 || aligned_embed_dim <= 0 || head_dim <= 0 ||
        aligned_head_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0 ||
        cache_capacity <= 0) {
        return;
    }
    if (pos < 0 || pos >= cache_capacity) {
        return;
    }
    if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
        return;
    }

    const size_t q_elems = (size_t)num_heads * (size_t)aligned_head_dim;
    const size_t kv_elems = (size_t)num_kv_heads * (size_t)aligned_head_dim;

    /* Stack buffers for the common case; heap fallback for large models. */
    float q_stack[MEGA_STACK_MAX];
    float k_stack[MEGA_STACK_MAX];
    float v_stack[MEGA_STACK_MAX];
    float o_stack[MEGA_STACK_MAX];

    float *q = q_stack;
    float *k = k_stack;
    float *v = v_stack;
    float *o = o_stack;

    /* Heap pointers stay NULL when the stack buffers suffice. */
    float *q_heap = NULL;
    float *k_heap = NULL;
    float *v_heap = NULL;
    float *o_heap = NULL;

    if (q_elems > MEGA_STACK_MAX) {
        q_heap = malloc(q_elems * sizeof *q_heap);
        o_heap = malloc(q_elems * sizeof *o_heap);
        if (!q_heap || !o_heap) {
            goto cleanup;
        }
        q = q_heap;
        o = o_heap;
    }
    if (kv_elems > MEGA_STACK_MAX) {
        k_heap = malloc(kv_elems * sizeof *k_heap);
        v_heap = malloc(kv_elems * sizeof *v_heap);
        if (!k_heap || !v_heap) {
            goto cleanup;
        }
        k = k_heap;
        v = v_heap;
    }

    mega_fuse_rmsnorm_qkv_avx(q, k, v, input, ln1_gamma,
                              wq, bq, wk, bk, wv, bv,
                              embed_dim, aligned_embed_dim,
                              num_heads, num_kv_heads,
                              head_dim, aligned_head_dim, eps);

    if (rope_cos && rope_sin) {
        mega_fuse_rope_inplace_avx(q, k, rope_cos, rope_sin, pos,
                                   num_heads, num_kv_heads,
                                   head_dim, aligned_head_dim);
    }

    /* Append this token's K/V to the cache (head-major layout).
     * NOTE(review): the first line of this call was lost in extraction;
     * arguments reconstructed from the kv_cache_write_head_major
     * prototype — confirm against the original source. */
    kv_cache_write_head_major(k, v,
                              kv_cache_k, kv_cache_v,
                              num_kv_heads, pos,
                              cache_capacity,
                              head_dim, aligned_head_dim);

    /* Attend over positions [0, pos] — hence seq_len = pos + 1. */
    mega_fuse_flash_attention_avx(o, q, kv_cache_k, kv_cache_v,
                                  num_heads, num_kv_heads,
                                  pos + 1, cache_capacity,
                                  head_dim, aligned_head_dim);

    mega_fuse_output_proj_residual(o, wo, bo, residual, output,
                                   embed_dim, aligned_embed_dim,
                                   num_heads, head_dim, aligned_head_dim);

cleanup:
    free(q_heap);
    free(k_heap);
    free(v_heap);
    free(o_heap);
}
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
static void mega_fuse_output_proj_residual(const float *attn_token, const float *wo, const float *bo, const float *residual, float *output, int embed_dim, int aligned_embed_dim, int num_heads, int head_dim, int aligned_head_dim)
void mega_fuse_flash_attention_avx(float *o_out, const float *q, const float *kv_cache_k, const float *kv_cache_v, int num_heads, int num_kv_heads, int seq_len, int cache_capacity, int head_dim, int aligned_head_dim)
Flash attention with online softmax (AVX version)
void mega_fuse_rmsnorm_qkv_avx(float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps)
Fused RMSNorm + QKV for decode (single token)
void mega_fuse_rope_inplace_avx(float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim)
Apply RoPE to Q and K (in-place, from L1)
#define MEGA_STACK_MAX

References kv_cache_write_head_major(), mega_fuse_flash_attention_avx(), mega_fuse_output_proj_residual(), mega_fuse_rmsnorm_qkv_avx(), mega_fuse_rope_inplace_avx(), and MEGA_STACK_MAX.