← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention_avx.c
Go to the documentation of this file.
1 /**
2  * @file mega_fused_attention_avx.c
3  * @brief Mega-Fused Attention for AVX (256-bit) and AVX-512 (512-bit)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. NO memcpy for layout - use strided access, not copies
10  * 4. API must define: inputs, outputs, workspace, and memory layouts
11  * 5. Pure computation - deterministic, no side effects
12  *
13  * After changes: make test && make llamacpp-parity-full
14  *
15  * VIOLATION: Uses malloc for intermediate buffers and memcpy for layout.
16  * TODO: Refactor to use bump allocator workspace and strided access.
17  *
18  * Holy grail fusion: RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual
19  *
20  * AVX approach: Keep intermediates in L1 cache (not registers)
21  * AVX-512 approach: Keep intermediates in registers
22  *
23  * Both achieve the same goal: Eliminate DRAM traffic for intermediates.
24  */
25 
26 #include <stdint.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <math.h>
30 
31 #include "ckernel_engine.h"
32 #include "ckernel_quant.h"
33 #include "ck_features.h"
34 
35 #if defined(__AVX__) || defined(__AVX512F__)
36 #include <immintrin.h>
37 #endif
38 
39 /* Local helpers (keep this file self-contained). */
40 #if defined(__AVX__) && !defined(__AVX512F__)
static inline float ck_hsum256_ps(__m256 v)
{
    /* Horizontal sum of all 8 lanes. The reduction order is kept exactly
     * as before (fold halves, then movehdup pair-add, then movehl) so the
     * float rounding is bit-identical to the previous implementation. */
    __m128 quad = _mm_add_ps(_mm256_castps256_ps128(v),
                             _mm256_extractf128_ps(v, 1));
    __m128 dup = _mm_movehdup_ps(quad);
    __m128 pair = _mm_add_ps(quad, dup);
    pair = _mm_add_ss(pair, _mm_movehl_ps(dup, pair));
    return _mm_cvtss_f32(pair);
}
52 #endif
53 
static inline float ck_dot_f32(const float *a, const float *b, int len)
{
    /* Dot product of two float vectors; the widest available ISA is
     * selected at compile time, with a scalar tail in every branch. */
#if defined(__AVX512F__)
    __m512 acc = _mm512_setzero_ps();
    int idx = 0;
    for (; idx <= len - 16; idx += 16) {
        acc = _mm512_fmadd_ps(_mm512_loadu_ps(a + idx),
                              _mm512_loadu_ps(b + idx),
                              acc);
    }
    float total = _mm512_reduce_add_ps(acc);
    while (idx < len) {
        total += a[idx] * b[idx];
        ++idx;
    }
    return total;
#elif defined(__AVX__)
    __m256 acc = _mm256_setzero_ps();
    int idx = 0;
    for (; idx <= len - 8; idx += 8) {
        acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_loadu_ps(a + idx),
                                               _mm256_loadu_ps(b + idx)));
    }
    float total = ck_hsum256_ps(acc);
    while (idx < len) {
        total += a[idx] * b[idx];
        ++idx;
    }
    return total;
#else
    float total = 0.0f;
    for (int idx = 0; idx < len; ++idx) {
        total += a[idx] * b[idx];
    }
    return total;
#endif
}
90 
91 /*============================================================================
92  * Configuration - AVX vs AVX-512
93  *============================================================================*/
94 
95 #if defined(__AVX512F__)
96 
97 /* AVX-512: Can keep more in registers */
98 #define MEGA_VLEN 16 /* 512 / 32 */
99 #define MEGA_REGS 32 /* 32 ZMM registers */
100 #define MEGA_Q_TILE 64
101 #define MEGA_KV_TILE 64
102 #define MEGA_STACK_MAX 8192
103 
104 /* Register allocation for AVX-512 */
105 #define REG_Q_ACCUM "ZMM0-ZMM11" /* 12 regs for Q tile */
106 #define REG_K_TILE "ZMM12-ZMM17" /* 6 regs for K tile */
107 #define REG_V_TILE "ZMM18-ZMM23" /* 6 regs for V tile */
108 #define REG_O_ACCUM "ZMM24-ZMM27" /* 4 regs for O tile */
109 #define REG_SOFTMAX "ZMM28-ZMM29" /* 2 regs for m, l */
110 #define REG_TEMP "ZMM30-ZMM31" /* 2 regs for temps */
111 
112 #else
113 
114 /* AVX: Smaller tiles to fit in L1 cache */
115 #define MEGA_VLEN 8 /* 256 / 32 */
116 #define MEGA_REGS 16 /* 16 YMM registers */
117 #define MEGA_Q_TILE 32
118 #define MEGA_KV_TILE 32
119 #define MEGA_STACK_MAX 8192
120 
121 /* Register allocation for AVX - use L1 cache for larger working set */
122 #define REG_Q_ACCUM "YMM0-YMM7" /* 8 regs for Q tile */
123 #define REG_K_TILE "YMM8-YMM11" /* 4 regs for K tile */
124 #define REG_V_TILE "YMM12-YMM15" /* 4 regs for V tile */
125 #define REG_O_ACCUM "Stack+L1" /* O in L1 cache */
126 #define REG_SOFTMAX "YMM0-YMM1" /* 2 regs for m, l */
127 #define REG_TEMP "YMM2-YMM3" /* 2 regs for temps */
128 
129 #endif
130 
131 /*============================================================================
132  * Phase 1: RMSNorm + QKV Fusion (AVX version)
133  *
134  * Keep ln1_row in stack buffer, not DRAM.
135  * Q/K/V go directly to next operation.
136  *============================================================================*/
137 
/**
 * @brief Fused RMSNorm + QKV projection for decode (single token).
 *
 * Computes rstd = 1/sqrt(mean(x^2) + eps), scales the input row by rstd
 * (and by gamma when gamma is non-NULL), then projects the normalized row
 * through the per-head Q, K and V weights. The normalized row is kept in a
 * stack buffer so the intermediate stays in L1/L2 instead of DRAM.
 *
 * Layouts (all head-major):
 *   wq/wk/wv: per head, aligned_head_dim rows of aligned_embed_dim floats
 *   bq/bk/bv: per head, aligned_head_dim floats (optional, may be NULL)
 *   q_out:        [num_heads * aligned_head_dim]
 *   k_out/v_out:  [num_kv_heads * aligned_head_dim]
 * Padding lanes [head_dim, aligned_head_dim) of every head are zeroed, as
 * are lanes [embed_dim, aligned_embed_dim) of the normalized row, so dot
 * products over the aligned length stay exact.
 *
 * NOTE(review): signature restored from the documented prototype (it was
 * lost in extraction); the body is unchanged. ln1_row is a VLA sized by
 * aligned_embed_dim — assumes callers keep embed_dim stack-friendly
 * (TODO per file header: move to a bump-allocator workspace).
 */
void mega_fuse_rmsnorm_qkv_avx(
    float *q_out,
    float *k_out,
    float *v_out,
    const float *input,
    const float *gamma,
    const float *wq,
    const float *bq,
    const float *wk,
    const float *bk,
    const float *wv,
    const float *bv,
    int embed_dim,
    int aligned_embed_dim,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    int aligned_head_dim,
    float eps)
{
    if (!q_out || !k_out || !v_out || !input || !wq || !wk || !wv) {
        return;
    }
    if (embed_dim <= 0 || aligned_embed_dim <= 0 || head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }

    /* Normalized row lives on the stack; never round-trips to DRAM. */
    float ln1_row[aligned_embed_dim];
    float sum_sq = 0.0f;

    /* Pass 1: sum of squares for the RMS denominator. */
#if defined(__AVX512F__)
    __m512 sum_vec = _mm512_setzero_ps();
    int i = 0;
    for (; i + 16 <= embed_dim; i += 16) {
        __m512 xv = _mm512_loadu_ps(input + i);
        sum_vec = _mm512_fmadd_ps(xv, xv, sum_vec);
    }
    sum_sq = _mm512_reduce_add_ps(sum_vec);
    for (; i < embed_dim; ++i) {
        sum_sq += input[i] * input[i];
    }
#elif defined(__AVX__)
    __m256 sum_vec = _mm256_setzero_ps();
    int i = 0;
    for (; i + 8 <= embed_dim; i += 8) {
        __m256 xv = _mm256_loadu_ps(input + i);
        sum_vec = _mm256_add_ps(sum_vec, _mm256_mul_ps(xv, xv));
    }
    sum_sq = ck_hsum256_ps(sum_vec);
    for (; i < embed_dim; ++i) {
        sum_sq += input[i] * input[i];
    }
#else
    for (int i = 0; i < embed_dim; ++i) {
        sum_sq += input[i] * input[i];
    }
#endif

    float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);

    /* Pass 2: write the normalized (and optionally gamma-scaled) row. */
#if defined(__AVX512F__)
    if (gamma) {
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        int j = 0;
        for (; j + 16 <= embed_dim; j += 16) {
            __m512 xv = _mm512_loadu_ps(input + j);
            __m512 gv = _mm512_loadu_ps(gamma + j);
            __m512 yv = _mm512_mul_ps(_mm512_mul_ps(xv, rstd_vec), gv);
            _mm512_storeu_ps(ln1_row + j, yv);
        }
        for (; j < embed_dim; ++j) {
            ln1_row[j] = input[j] * rstd * gamma[j];
        }
    } else {
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        int j = 0;
        for (; j + 16 <= embed_dim; j += 16) {
            __m512 xv = _mm512_loadu_ps(input + j);
            __m512 yv = _mm512_mul_ps(xv, rstd_vec);
            _mm512_storeu_ps(ln1_row + j, yv);
        }
        for (; j < embed_dim; ++j) {
            ln1_row[j] = input[j] * rstd;
        }
    }
#elif defined(__AVX__)
    if (gamma) {
        __m256 rstd_vec = _mm256_set1_ps(rstd);
        int j = 0;
        for (; j + 8 <= embed_dim; j += 8) {
            __m256 xv = _mm256_loadu_ps(input + j);
            __m256 gv = _mm256_loadu_ps(gamma + j);
            __m256 yv = _mm256_mul_ps(_mm256_mul_ps(xv, rstd_vec), gv);
            _mm256_storeu_ps(ln1_row + j, yv);
        }
        for (; j < embed_dim; ++j) {
            ln1_row[j] = input[j] * rstd * gamma[j];
        }
    } else {
        __m256 rstd_vec = _mm256_set1_ps(rstd);
        int j = 0;
        for (; j + 8 <= embed_dim; j += 8) {
            __m256 xv = _mm256_loadu_ps(input + j);
            __m256 yv = _mm256_mul_ps(xv, rstd_vec);
            _mm256_storeu_ps(ln1_row + j, yv);
        }
        for (; j < embed_dim; ++j) {
            ln1_row[j] = input[j] * rstd;
        }
    }
#else
    for (int j = 0; j < embed_dim; ++j) {
        ln1_row[j] = input[j] * rstd * (gamma ? gamma[j] : 1.0f);
    }
#endif

    /* Zero padding so dot products over aligned_embed_dim stay exact. */
    for (int j = embed_dim; j < aligned_embed_dim; ++j) {
        ln1_row[j] = 0.0f;
    }

    const size_t head_w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;

    /* Q projection: one dot product per output element, per head. */
    for (int h = 0; h < num_heads; ++h) {
        const float *wq_h = wq + (size_t)h * head_w_stride;
        const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *q_h = q_out + (size_t)h * (size_t)aligned_head_dim;
        for (int d = 0; d < head_dim; ++d) {
            const float *row = wq_h + (size_t)d * (size_t)aligned_embed_dim;
            float sum = ck_dot_f32(ln1_row, row, aligned_embed_dim);
            q_h[d] = sum + (bq_h ? bq_h[d] : 0.0f);
        }
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            q_h[d] = 0.0f;
        }
    }

    /* K/V projections (may be fewer heads than Q under GQA/MQA). */
    for (int h = 0; h < num_kv_heads; ++h) {
        const float *wk_h = wk + (size_t)h * head_w_stride;
        const float *wv_h = wv + (size_t)h * head_w_stride;
        const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k_out + (size_t)h * (size_t)aligned_head_dim;
        float *v_h = v_out + (size_t)h * (size_t)aligned_head_dim;
        for (int d = 0; d < head_dim; ++d) {
            const float *wk_row = wk_h + (size_t)d * (size_t)aligned_embed_dim;
            const float *wv_row = wv_h + (size_t)d * (size_t)aligned_embed_dim;
            float k_sum = ck_dot_f32(ln1_row, wk_row, aligned_embed_dim);
            float v_sum = ck_dot_f32(ln1_row, wv_row, aligned_embed_dim);
            k_h[d] = k_sum + (bk_h ? bk_h[d] : 0.0f);
            v_h[d] = v_sum + (bv_h ? bv_h[d] : 0.0f);
        }
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            k_h[d] = 0.0f;
            v_h[d] = 0.0f;
        }
    }
}
300 
301 /*============================================================================
302  * Phase 2: RoPE In-Place (Q/K still hot in L1/L2)
303  *============================================================================*/
304 
/**
 * @brief Apply rotary position embedding (RoPE) to Q and K in place.
 *
 * Q and K are expected to still be hot in L1 from the QKV projection; the
 * rotation is applied in place. Uses the "split halves" pairing: element i
 * rotates with element i + head_dim/2:
 *   x[i]        = x[i]*cos - x[i+half]*sin
 *   x[i+half]   = x[i]*sin + x[i+half]*cos
 * rope_cos/rope_sin are tables indexed as [pos * (head_dim/2) + i].
 *
 * Requires an even head_dim; returns without touching anything otherwise.
 * Padding lanes [head_dim, aligned_head_dim) are (re)zeroed.
 *
 * NOTE(review): signature restored from the documented prototype (lost in
 * extraction). Vector locals renamed from cos/sin to avoid shadowing the
 * libm functions; behavior unchanged.
 */
void mega_fuse_rope_inplace_avx(
    float *q, /* [num_heads * aligned_head_dim] - in L1 */
    float *k, /* [num_kv_heads * aligned_head_dim] - in L1 */
    const float *rope_cos,
    const float *rope_sin,
    int pos,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    int aligned_head_dim)
{
    if (!q || !k || !rope_cos || !rope_sin || head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }
    if ((head_dim & 1) != 0) {
        return; /* rotation needs matched half-pairs */
    }

    int half = head_dim / 2;
    const float *cos_ptr = rope_cos + (size_t)pos * (size_t)half;
    const float *sin_ptr = rope_sin + (size_t)pos * (size_t)half;

    for (int h = 0; h < num_heads; ++h) {
        float *q_h = q + (size_t)h * (size_t)aligned_head_dim;
        int i = 0;
#if defined(__AVX512F__)
        for (; i + 16 <= half; i += 16) {
            __m512 q0 = _mm512_loadu_ps(q_h + i);
            __m512 q1 = _mm512_loadu_ps(q_h + i + half);
            __m512 vcos = _mm512_loadu_ps(cos_ptr + i);
            __m512 vsin = _mm512_loadu_ps(sin_ptr + i);

            __m512 q_rot0 = _mm512_sub_ps(_mm512_mul_ps(q0, vcos), _mm512_mul_ps(q1, vsin));
            __m512 q_rot1 = _mm512_add_ps(_mm512_mul_ps(q0, vsin), _mm512_mul_ps(q1, vcos));

            _mm512_storeu_ps(q_h + i, q_rot0);
            _mm512_storeu_ps(q_h + i + half, q_rot1);
        }
#elif defined(__AVX__)
        for (; i + 8 <= half; i += 8) {
            __m256 q0 = _mm256_loadu_ps(q_h + i);
            __m256 q1 = _mm256_loadu_ps(q_h + i + half);
            __m256 vcos = _mm256_loadu_ps(cos_ptr + i);
            __m256 vsin = _mm256_loadu_ps(sin_ptr + i);

            __m256 q_rot0 = _mm256_sub_ps(_mm256_mul_ps(q0, vcos), _mm256_mul_ps(q1, vsin));
            __m256 q_rot1 = _mm256_add_ps(_mm256_mul_ps(q0, vsin), _mm256_mul_ps(q1, vcos));

            _mm256_storeu_ps(q_h + i, q_rot0);
            _mm256_storeu_ps(q_h + i + half, q_rot1);
        }
#endif
        /* Scalar tail (and full path on non-AVX builds). */
        for (; i < half; ++i) {
            float q0 = q_h[i];
            float q1 = q_h[i + half];
            float c = cos_ptr[i];
            float s = sin_ptr[i];
            q_h[i] = q0 * c - q1 * s;
            q_h[i + half] = q0 * s + q1 * c;
        }
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            q_h[d] = 0.0f; /* keep padding lanes clean */
        }
    }

    for (int h = 0; h < num_kv_heads; ++h) {
        float *k_h = k + (size_t)h * (size_t)aligned_head_dim;
        int i = 0;
#if defined(__AVX512F__)
        for (; i + 16 <= half; i += 16) {
            __m512 k0 = _mm512_loadu_ps(k_h + i);
            __m512 k1 = _mm512_loadu_ps(k_h + i + half);
            __m512 vcos = _mm512_loadu_ps(cos_ptr + i);
            __m512 vsin = _mm512_loadu_ps(sin_ptr + i);

            __m512 k_rot0 = _mm512_sub_ps(_mm512_mul_ps(k0, vcos), _mm512_mul_ps(k1, vsin));
            __m512 k_rot1 = _mm512_add_ps(_mm512_mul_ps(k0, vsin), _mm512_mul_ps(k1, vcos));

            _mm512_storeu_ps(k_h + i, k_rot0);
            _mm512_storeu_ps(k_h + i + half, k_rot1);
        }
#elif defined(__AVX__)
        for (; i + 8 <= half; i += 8) {
            __m256 k0 = _mm256_loadu_ps(k_h + i);
            __m256 k1 = _mm256_loadu_ps(k_h + i + half);
            __m256 vcos = _mm256_loadu_ps(cos_ptr + i);
            __m256 vsin = _mm256_loadu_ps(sin_ptr + i);

            __m256 k_rot0 = _mm256_sub_ps(_mm256_mul_ps(k0, vcos), _mm256_mul_ps(k1, vsin));
            __m256 k_rot1 = _mm256_add_ps(_mm256_mul_ps(k0, vsin), _mm256_mul_ps(k1, vcos));

            _mm256_storeu_ps(k_h + i, k_rot0);
            _mm256_storeu_ps(k_h + i + half, k_rot1);
        }
#endif
        for (; i < half; ++i) {
            float k0 = k_h[i];
            float k1 = k_h[i + half];
            float c = cos_ptr[i];
            float s = sin_ptr[i];
            k_h[i] = k0 * c - k1 * s;
            k_h[i + half] = k0 * s + k1 * c;
        }
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            k_h[d] = 0.0f;
        }
    }
}
419 
420 /*============================================================================
421  * Phase 3: Flash Attention with Online Softmax
422  *
423  * O, m, l stay in registers across all KV tiles.
424  * K/V stream from KV cache (in L2).
425  *============================================================================*/
426 
427 /**
428  * @brief Flash attention with online softmax (AVX version)
429  *
430  * Key insight: O, m, l stay in registers throughout!
431  * K/V tiles stream from L2 cache.
432  *
433  * @param o_out Output [num_heads * aligned_head_dim] - in registers/L1
434  * @param q Q tensor [num_heads * aligned_head_dim] - from L1
435  * @param kv_cache_k KV cache K [num_kv_heads * cache_capacity * aligned_head_dim]
436  * @param kv_cache_v KV cache V [num_kv_heads * cache_capacity * aligned_head_dim]
437  * @param num_heads Number of heads
438  * @param num_kv_heads Number of KV heads
439  * @param seq_len Current sequence length
440  * @param cache_capacity KV cache capacity (head stride)
441  * @param head_dim Head dimension
442  * @param aligned_head_dim Aligned head dimension
443  */
445  float *o_out,
446  const float *q,
447  const float *kv_cache_k,
448  const float *kv_cache_v,
449  int num_heads,
450  int num_kv_heads,
451  int seq_len,
452  int cache_capacity,
453  int head_dim,
454  int aligned_head_dim)
455 {
456  const int hd = head_dim;
457  const float scale = 1.0f / sqrtf((float)hd);
458  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
459 
460  for (int h = 0; h < num_heads; h++) {
461  const float *q_h = q + (size_t)h * (size_t)aligned_head_dim;
462  const int kv_idx = h % num_kv_heads;
463  const float *k_cache = kv_cache_k + (size_t)kv_idx * head_stride;
464  const float *v_cache = kv_cache_v + (size_t)kv_idx * head_stride;
465 
466  /* O, m, l in registers for this head */
467  float o_h[aligned_head_dim]; /* in L1 */
468  float m = -INFINITY; /* running max */
469  float l = 0.0f; /* running sum */
470 
471  /* Initialize O to zeros */
472  memset(o_h, 0, (size_t)aligned_head_dim * sizeof(float));
473 
474  /* Iterate over KV cache tiles */
475  for (int t = 0; t < seq_len; t += MEGA_KV_TILE) {
476  int tile_end = t + MEGA_KV_TILE;
477  if (tile_end > seq_len) tile_end = seq_len;
478  int tile_size = tile_end - t;
479 
480  /* Load K tile from L2 cache */
481  float k_tile[MEGA_KV_TILE * hd];
482  for (int i = 0; i < tile_size; i++) {
483  memcpy(k_tile + (size_t)i * (size_t)hd,
484  k_cache + (size_t)(t + i) * (size_t)aligned_head_dim,
485  (size_t)hd * sizeof(float));
486  }
487 
488  /* S_ij = Q @ K_tile.T / sqrt(d) - in registers */
489  float s_row[MEGA_KV_TILE];
490  for (int j = 0; j < tile_size; j++) {
491  s_row[j] = 0.0f;
492  for (int i = 0; i < hd; i++) {
493  s_row[j] += q_h[i] * k_tile[j * hd + i];
494  }
495  s_row[j] *= scale;
496  }
497 
498  /* Online softmax update */
499  float m_new = m;
500  for (int j = 0; j < tile_size; j++) {
501  if (s_row[j] > m_new) m_new = s_row[j];
502  }
503 
504  float l_new = 0.0f;
505  for (int j = 0; j < tile_size; j++) {
506  float p = expf(s_row[j] - m_new);
507  s_row[j] = p;
508  l_new += p;
509  }
510 
511  /* Scale O by exp(m - m_new) and add P @ V */
512  float exp_m_diff = expf(m - m_new);
513  for (int i = 0; i < hd; i++) {
514  o_h[i] *= exp_m_diff;
515  }
516 
517  /* Load V tile and accumulate */
518  for (int j = 0; j < tile_size; j++) {
519  float p = s_row[j];
520  for (int i = 0; i < hd; i++) {
521  o_h[i] += p * v_cache[(size_t)(t + j) * (size_t)aligned_head_dim + (size_t)i];
522  }
523  }
524 
525  l = l * exp_m_diff + l_new;
526  m = m_new;
527  }
528 
529  /* Normalize by l */
530  for (int i = 0; i < hd; i++) {
531  o_h[i] /= l;
532  }
533  for (int i = hd; i < aligned_head_dim; ++i) {
534  o_h[i] = 0.0f;
535  }
536 
537  /* Store O - still in L1, goes to output projection */
538  memcpy(o_out + (size_t)h * (size_t)aligned_head_dim,
539  o_h,
540  (size_t)aligned_head_dim * sizeof(float));
541  }
542 }
543 
544 /*============================================================================
545  * Phase 4: Full Mega-Fused Attention Decode
546  *
547  * RMSNorm → QKV → RoPE → Flash Attn → OutProj + Residual
548  * All intermediates in L1/L2, single DRAM round-trip.
549  *============================================================================*/
550 
/**
 * @brief Output projection plus optional bias and residual add.
 *
 * output[j] = bo[j] + sum_h ( attn_token_h . wo[h][j][:] ) + residual[j]
 *
 * Layouts:
 *   attn_token: head-major, [num_heads * aligned_head_dim]
 *   wo:         per head, aligned_embed_dim rows of aligned_head_dim floats
 *               (row j holds the weights producing output element j)
 *   bo/residual: optional, may be NULL
 * Padding lanes [embed_dim, aligned_embed_dim) of output are zeroed.
 *
 * NOTE(review): signature restored from the documented prototype (lost in
 * extraction); body unchanged.
 */
static void mega_fuse_output_proj_residual(
    const float *attn_token,
    const float *wo,
    const float *bo,
    const float *residual,
    float *output,
    int embed_dim,
    int aligned_embed_dim,
    int num_heads,
    int head_dim,
    int aligned_head_dim)
{
    if (!attn_token || !wo || !output) {
        return;
    }

    const size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;

    for (int j = 0; j < embed_dim; ++j) {
        float sum = bo ? bo[j] : 0.0f;
        for (int h = 0; h < num_heads; ++h) {
            const float *o_h = attn_token + (size_t)h * (size_t)aligned_head_dim;
            const float *wo_row = wo + (size_t)h * head_weight_stride + (size_t)j * (size_t)aligned_head_dim;
            sum += ck_dot_f32(o_h, wo_row, head_dim);
        }
        output[j] = sum + (residual ? residual[j] : 0.0f);
    }

    for (int j = embed_dim; j < aligned_embed_dim; ++j) {
        output[j] = 0.0f;
    }
}
583 
584 /**
585  * @brief Full mega-fused attention for decode
586  *
587  * RMSNorm → QKV → RoPE → Flash Attn → OutProj + Residual
588  */
590  float *output, /* [aligned_embed_dim] */
591  const float *input, /* [aligned_embed_dim] */
592  const float *residual, /* [aligned_embed_dim] */
593  const float *ln1_gamma,
594  const float *wq, const float *bq,
595  const float *wk, const float *bk,
596  const float *wv, const float *bv,
597  const float *wo, const float *bo,
598  float *kv_cache_k,
599  float *kv_cache_v,
600  const float *rope_cos,
601  const float *rope_sin,
602  int pos,
603  int embed_dim,
604  int aligned_embed_dim,
605  int num_heads,
606  int num_kv_heads,
607  int head_dim,
608  int aligned_head_dim,
609  int cache_capacity,
610  float eps)
611 {
612  if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
613  !kv_cache_k || !kv_cache_v) {
614  return;
615  }
616  if (embed_dim <= 0 || aligned_embed_dim <= 0 || head_dim <= 0 || aligned_head_dim <= 0 ||
617  num_heads <= 0 || num_kv_heads <= 0 || cache_capacity <= 0) {
618  return;
619  }
620  if (pos < 0 || pos >= cache_capacity) {
621  return;
622  }
623  if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
624  return;
625  }
626 
627  const size_t q_elems = (size_t)num_heads * (size_t)aligned_head_dim;
628  const size_t kv_elems = (size_t)num_kv_heads * (size_t)aligned_head_dim;
629 
630  float q_stack[MEGA_STACK_MAX];
631  float k_stack[MEGA_STACK_MAX];
632  float v_stack[MEGA_STACK_MAX];
633  float o_stack[MEGA_STACK_MAX];
634 
635  float *q = q_stack;
636  float *k = k_stack;
637  float *v = v_stack;
638  float *o = o_stack;
639 
640  int free_q = 0;
641  int free_k = 0;
642  int free_v = 0;
643  int free_o = 0;
644 
645  if (q_elems > MEGA_STACK_MAX) {
646  q = (float *)malloc(q_elems * sizeof(float));
647  if (!q) {
648  return;
649  }
650  free_q = 1;
651  }
652  if (kv_elems > MEGA_STACK_MAX) {
653  k = (float *)malloc(kv_elems * sizeof(float));
654  if (!k) {
655  if (free_q) free(q);
656  return;
657  }
658  v = (float *)malloc(kv_elems * sizeof(float));
659  if (!v) {
660  if (free_q) free(q);
661  free(k);
662  return;
663  }
664  free_k = 1;
665  free_v = 1;
666  }
667  if (q_elems > MEGA_STACK_MAX) {
668  o = (float *)malloc(q_elems * sizeof(float));
669  if (!o) {
670  if (free_q) free(q);
671  if (free_k) free(k);
672  if (free_v) free(v);
673  return;
674  }
675  free_o = 1;
676  }
677 
678  mega_fuse_rmsnorm_qkv_avx(q, k, v, input, ln1_gamma,
679  wq, bq, wk, bk, wv, bv,
680  embed_dim, aligned_embed_dim,
681  num_heads, num_kv_heads,
682  head_dim, aligned_head_dim, eps);
683 
684  if (rope_cos && rope_sin) {
685  mega_fuse_rope_inplace_avx(q, k, rope_cos, rope_sin, pos,
686  num_heads, num_kv_heads,
687  head_dim, aligned_head_dim);
688  }
689 
691  kv_cache_k, kv_cache_v,
692  num_kv_heads, pos,
693  cache_capacity,
694  head_dim, aligned_head_dim);
695 
696  mega_fuse_flash_attention_avx(o, q, kv_cache_k, kv_cache_v,
697  num_heads, num_kv_heads,
698  pos + 1, cache_capacity,
699  head_dim, aligned_head_dim);
700 
701  mega_fuse_output_proj_residual(o, wo, bo, residual, output,
702  embed_dim, aligned_embed_dim,
703  num_heads, head_dim, aligned_head_dim);
704 
705  if (free_q) free(q);
706  if (free_k) free(k);
707  if (free_v) free(v);
708  if (free_o) free(o);
709 }
CPU feature detection and dispatch macros.
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
Quantization block structures for weight-only quantization.
static void mega_fuse_output_proj_residual(const float *attn_token, const float *wo, const float *bo, const float *residual, float *output, int embed_dim, int aligned_embed_dim, int num_heads, int head_dim, int aligned_head_dim)
void mega_fuse_flash_attention_avx(float *o_out, const float *q, const float *kv_cache_k, const float *kv_cache_v, int num_heads, int num_kv_heads, int seq_len, int cache_capacity, int head_dim, int aligned_head_dim)
Flash attention with online softmax (AVX version)
static float ck_dot_f32(const float *a, const float *b, int len)
void mega_fuse_rmsnorm_qkv_avx(float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps)
Fused RMSNorm + QKV for decode (single token)
void mega_fuse_rope_inplace_avx(float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim)
Apply RoPE to Q and K (in-place, from L1)
void mega_fused_attention_decode(float *output, const float *input, const float *residual, const float *ln1_gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, const float *wo, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps)
Full mega-fused attention for decode.
#define MEGA_STACK_MAX
#define MEGA_KV_TILE