← Back to C-Kernel-Engine Docs Doxygen Source Documentation
attention_flash_true.c
Go to the documentation of this file.
1 /**
2  * @file attention_flash_true.c
3  * @brief Flash-style attention (online softmax, causal, streaming)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Layout:
15  * Q/K/V/Out: [T, H, D_h] contiguous
16  *
17  * Causal alignment:
18  * Queries are assumed to correspond to the last T_q positions in the KV cache.
19  * This makes T_q == T_k behave like standard causal prefill, and T_q == 1
20  * behave like decode over a full KV cache.
21  *
22  * Notes:
23  * - This is O(T_k) per query head; it avoids materializing the score matrix.
24  * - SIMD paths are provided for AVX-512 and AVX.
25  */
26 
27 #include <math.h>
28 #include <stddef.h>
29 #include <stdint.h>
30 
31 #if defined(__AVX512F__)
32 #include <immintrin.h>
33 #endif
34 
35 #if defined(__AVX__) && !defined(__AVX512F__)
36 #include <immintrin.h>
37 #endif
38 
39 #ifndef CK_FLASH_ATTN_TILE_K
40 #define CK_FLASH_ATTN_TILE_K 32
41 #endif
42 
43 #ifndef CK_FLASH_ATTN_FAST_EXP
44 #define CK_FLASH_ATTN_FAST_EXP 0
45 #endif
46 
47 static inline float ck_fast_expf(float x) {
48  const float max_val = 88.0f;
49  const float min_val = -88.0f;
50  if (x > max_val) {
51  x = max_val;
52  } else if (x < min_val) {
53  x = min_val;
54  }
55 
56  const float log2e = 1.4426950408889634f;
57  float z = x * log2e;
58  float zf = nearbyintf(z);
59  float f = z - zf;
60 
61  const float c0 = 1.0f;
62  const float c1 = 0.6931471805599453f;
63  const float c2 = 0.2402265069591007f;
64  const float c3 = 0.05550410866482158f;
65  const float c4 = 0.009618129107628478f;
66 
67  float poly = ((c4 * f + c3) * f + c2) * f + c1;
68  poly = poly * f + c0;
69 
70  int32_t zi = (int32_t)zf + 127;
71  uint32_t bits = (uint32_t)zi << 23;
72  union {
73  uint32_t i;
74  float f;
75  } u;
76  u.i = bits;
77  return poly * u.f;
78 }
79 
/**
 * @brief exp() used by the attention kernels.
 *
 * Compile-time switch: with CK_FLASH_ATTN_FAST_EXP enabled this calls the
 * fast polynomial approximation; otherwise it calls libm expf(), which is
 * the reference behavior for the parity checks mentioned in the file header.
 */
static inline float ck_expf(float x) {
#if CK_FLASH_ATTN_FAST_EXP
    return ck_fast_expf(x);
#else
    return expf(x);
#endif
}
87 
/**
 * @brief Choose the KV tile length for a given head dimension.
 *
 * Larger heads get proportionally smaller tiles (half the budget above
 * D_h = 64, a quarter above D_h = 128) so the per-tile score buffer and
 * working set stay small. The result is clamped to
 * [1, CK_FLASH_ATTN_TILE_K], with a floor of 8 whenever the configured
 * budget is at least 8.
 */
static inline int ck_flash_attn_tile_k(int D_h) {
    int divisor = 1;
    if (D_h > 128) {
        divisor = 4;
    } else if (D_h > 64) {
        divisor = 2;
    }

    int tile = CK_FLASH_ATTN_TILE_K / divisor;
    if (tile < 8 && CK_FLASH_ATTN_TILE_K >= 8) {
        tile = 8;
    }
    if (tile > CK_FLASH_ATTN_TILE_K) {
        tile = CK_FLASH_ATTN_TILE_K;
    }
    return (tile < 1) ? 1 : tile;
}
107 
    /* Public wrapper over the internal tile-size heuristic, so the
     * orchestrator/codegen layer can query the KV tile length. */
    return ck_flash_attn_tile_k(D_h);
}
111 
    /* Report the fast-exp vector width compiled into this binary:
     * 512 (AVX-512 path), 256 (AVX path), or 0 when
     * CK_FLASH_ATTN_FAST_EXP is disabled or no SIMD path exists. */
#if CK_FLASH_ATTN_FAST_EXP
#if defined(__AVX512F__)
    return 512;
#elif defined(__AVX__)
    return 256;
#else
    return 0;
#endif
#else
    return 0;
#endif
}
125 
/**
 * @brief Last KV index (inclusive) visible to query row t_q under the
 *        causal mask.
 *
 * Queries are aligned to the tail of the KV cache: query 0 sits at
 * absolute position T_k - T_q (or 0 when T_q > T_k), so each query may
 * attend to its own position and everything before it. The result is
 * clamped to the valid index range [0, T_k - 1].
 */
static inline int max_k_for_query(int t_q, int T_q, int T_k) {
    const int first_q_pos = (T_k > T_q) ? (T_k - T_q) : 0;
    const int last_visible = first_q_pos + t_q;
    return (last_visible < T_k) ? last_visible : T_k - 1;
}
134 
135 /* ============================================================================
136  * SCALAR REFERENCE IMPLEMENTATION
137  * ============================================================================ */
138 
139 /**
140  * @brief Scalar flash-style attention (online softmax)
141  */
    float *out,        /* [T_q, H, D_h] attention output */
    const float *q,    /* [T_q, H, D_h] queries */
    const float *k,    /* [T_k, H, D_h] keys */
    const float *v,    /* [T_k, H, D_h] values */
    int T_q,           /* number of query tokens */
    int T_k,           /* KV cache length */
    int H,             /* number of heads */
    int D_h,           /* head dimension */
    float scale)       /* score scale, documented as 1/sqrt(D_h) */
{
    const int total = T_q * H;
    /* Elements between consecutive tokens in the [T, H, D_h] layout. */
    const size_t stride = (size_t)H * (size_t)D_h;
    const int tile_k = ck_flash_attn_tile_k(D_h);

    /* One independent online-softmax pass per (query token, head) pair. */
    for (int idx = 0; idx < total; ++idx) {
        const int t_q = idx / H;
        const int h = idx - t_q * H;
        /* Causal bound: last KV position this query may attend to. */
        const int max_k = max_k_for_query(t_q, T_q, T_k);

        const float *q_head = q + (size_t)t_q * stride + (size_t)h * (size_t)D_h;
        float *out_head = out + (size_t)t_q * stride + (size_t)h * (size_t)D_h;
        const float *k_base = k + (size_t)h * (size_t)D_h;
        const float *v_base = v + (size_t)h * (size_t)D_h;

        /* out_head accumulates the unnormalized weighted V sum. */
        for (int d = 0; d < D_h; ++d) {
            out_head[d] = 0.0f;
        }

        /* Online-softmax state: m = running max score,
         * s = running sum of exp(score - m). */
        float m = -INFINITY;
        float s = 0.0f;

        float scores[CK_FLASH_ATTN_TILE_K];

        /* Stream over KV positions tile by tile; the full score row is
         * never materialized. */
        for (int t_k0 = 0; t_k0 <= max_k; t_k0 += tile_k) {
            int blk_len = max_k - t_k0 + 1;
            if (blk_len > tile_k) {
                blk_len = tile_k;
            }

            /* Pass 1: scaled Q.K scores for this tile, and the tile max. */
            float m_block = -INFINITY;
            for (int bi = 0; bi < blk_len; ++bi) {
                const int t_k = t_k0 + bi;
                const float *k_head = k_base + (size_t)t_k * stride;

                float dot = 0.0f;
                for (int d = 0; d < D_h; ++d) {
                    dot += q_head[d] * k_head[d];
                }

                float score = dot * scale;
                scores[bi] = score;
                if (score > m_block) {
                    m_block = score;
                }
            }

            /* New global max: rescale the previous sum and accumulator so
             * all weights stay relative to the same reference max. */
            if (m_block > m) {
                float scale_old = (m == -INFINITY) ? 0.0f : ck_expf(m - m_block);
                s *= scale_old;
                for (int d = 0; d < D_h; ++d) {
                    out_head[d] *= scale_old;
                }
                m = m_block;
            }

            /* Pass 2: accumulate exp-weighted V rows for this tile. */
            for (int bi = 0; bi < blk_len; ++bi) {
                const int t_k = t_k0 + bi;
                const float *v_head = v_base + (size_t)t_k * stride;
                float w = ck_expf(scores[bi] - m);
                s += w;
                for (int d = 0; d < D_h; ++d) {
                    out_head[d] += w * v_head[d];
                }
            }
        }

        /* Normalize by the softmax denominator; s == 0 only if no KV
         * position was visible, in which case the output is zeroed. */
        if (s > 0.0f) {
            float inv_s = 1.0f / s;
            for (int d = 0; d < D_h; ++d) {
                out_head[d] *= inv_s;
            }
        } else {
            for (int d = 0; d < D_h; ++d) {
                out_head[d] = 0.0f;
            }
        }
    }
}
231 
232 #if defined(__AVX512F__)
233 
234 /* ============================================================================
235  * AVX-512 IMPLEMENTATION (16 floats per vector)
236  * ============================================================================ */
237 
238 #if CK_FLASH_ATTN_FAST_EXP
/**
 * @brief 16-lane fast expf: exp(x) = 2^round(x*log2e) * 2^frac.
 *
 * Vector counterpart of ck_fast_expf(); same clamp range, same degree-4
 * polynomial, same exponent-bit construction of the 2^integer factor.
 */
static inline __m512 ck_fast_exp512_ps(__m512 x) {
    /* Clamp so the biased exponent built below stays representable. */
    const __m512 max_val = _mm512_set1_ps(88.0f);
    const __m512 min_val = _mm512_set1_ps(-88.0f);
    x = _mm512_min_ps(x, max_val);
    x = _mm512_max_ps(x, min_val);

    /* Split x * log2(e) into a whole part zf and fraction f in [-0.5, 0.5]. */
    const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
    __m512 z = _mm512_mul_ps(x, log2e);
    __m512 zf = _mm512_roundscale_ps(z, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m512 f = _mm512_sub_ps(z, zf);

    /* Degree-4 polynomial for 2^f; coefficients are ln(2)^n / n!. */
    const __m512 c0 = _mm512_set1_ps(1.0f);
    const __m512 c1 = _mm512_set1_ps(0.6931471805599453f);
    const __m512 c2 = _mm512_set1_ps(0.2402265069591007f);
    const __m512 c3 = _mm512_set1_ps(0.05550410866482158f);
    const __m512 c4 = _mm512_set1_ps(0.009618129107628478f);

#if defined(__FMA__)
    __m512 poly = _mm512_fmadd_ps(c4, f, c3);
    poly = _mm512_fmadd_ps(poly, f, c2);
    poly = _mm512_fmadd_ps(poly, f, c1);
    poly = _mm512_fmadd_ps(poly, f, c0);
#else
    __m512 poly = _mm512_add_ps(_mm512_mul_ps(c4, f), c3);
    poly = _mm512_add_ps(_mm512_mul_ps(poly, f), c2);
    poly = _mm512_add_ps(_mm512_mul_ps(poly, f), c1);
    poly = _mm512_add_ps(_mm512_mul_ps(poly, f), c0);
#endif

    /* Build 2^zf per lane: bias the integer exponent by 127 and shift it
     * into the float exponent field. */
    __m512i zi = _mm512_cvtps_epi32(zf);
    zi = _mm512_add_epi32(zi, _mm512_set1_epi32(127));
    zi = _mm512_slli_epi32(zi, 23);
    __m512 pow2 = _mm512_castsi512_ps(zi);
    return _mm512_mul_ps(poly, pow2);
}
274 #endif
275 
/**
 * @brief AVX-512 dot product of two float vectors of length D_h.
 *
 * Two independent 16-lane accumulators cover the 32-element main stride
 * (hiding FMA latency); a single-accumulator 16-element stride and a
 * scalar loop handle the remainder. Loads are unaligned-safe.
 */
static inline float ck_dot_f32_avx512(const float *q, const float *k, int D_h) {
    __m512 acc_even = _mm512_setzero_ps();
    __m512 acc_odd = _mm512_setzero_ps();

    int i = 0;
    while (i + 32 <= D_h) {
        acc_even = _mm512_fmadd_ps(_mm512_loadu_ps(q + i),
                                   _mm512_loadu_ps(k + i), acc_even);
        acc_odd = _mm512_fmadd_ps(_mm512_loadu_ps(q + i + 16),
                                  _mm512_loadu_ps(k + i + 16), acc_odd);
        i += 32;
    }
    while (i + 16 <= D_h) {
        acc_even = _mm512_fmadd_ps(_mm512_loadu_ps(q + i),
                                   _mm512_loadu_ps(k + i), acc_even);
        i += 16;
    }

    /* Combine both accumulators, reduce across lanes, then add the tail. */
    float total = _mm512_reduce_add_ps(_mm512_add_ps(acc_even, acc_odd));
    while (i < D_h) {
        total += q[i] * k[i];
        ++i;
    }
    return total;
}
302 
/**
 * @brief AVX-512 flash attention (online softmax, causal).
 *
 * Same algorithm as the scalar reference: per (query, head) pair, stream
 * KV tiles, track the running max/sum, and keep an unnormalized weighted
 * V accumulator. The D_h-dimension loops (zeroing, rescaling, weighted-V
 * accumulation, normalization) are vectorized 16 floats at a time, and
 * score dot products go through ck_dot_f32_avx512(). With
 * CK_FLASH_ATTN_FAST_EXP the tile weights are exponentiated 16 at a time.
 */
static void attention_flash_decode_avx512(
    float *out,        /* [T_q, H, D_h] attention output */
    const float *q,    /* [T_q, H, D_h] queries */
    const float *k,    /* [T_k, H, D_h] keys */
    const float *v,    /* [T_k, H, D_h] values */
    int T_q,
    int T_k,
    int H,
    int D_h,
    float scale)
{
    const int total = T_q * H;
    /* Elements between consecutive tokens in the [T, H, D_h] layout. */
    const size_t stride = (size_t)H * (size_t)D_h;
    const int tile_k = ck_flash_attn_tile_k(D_h);

    for (int idx = 0; idx < total; ++idx) {
        const int t_q = idx / H;
        const int h = idx - t_q * H;
        /* Causal bound: last KV position this query may attend to. */
        const int max_k = max_k_for_query(t_q, T_q, T_k);

        const float *q_head = q + (size_t)t_q * stride + (size_t)h * (size_t)D_h;
        float *out_head = out + (size_t)t_q * stride + (size_t)h * (size_t)D_h;
        const float *k_base = k + (size_t)h * (size_t)D_h;
        const float *v_base = v + (size_t)h * (size_t)D_h;

        /* Zero the accumulator: 16 lanes at a time, scalar tail. */
        int d = 0;
        for (; d + 16 <= D_h; d += 16) {
            _mm512_storeu_ps(out_head + d, _mm512_setzero_ps());
        }
        for (; d < D_h; ++d) {
            out_head[d] = 0.0f;
        }

        /* Online-softmax state: running max m, running sum s. */
        float m = -INFINITY;
        float s = 0.0f;

        float scores[CK_FLASH_ATTN_TILE_K];

        for (int t_k0 = 0; t_k0 <= max_k; t_k0 += tile_k) {
            int blk_len = max_k - t_k0 + 1;
            if (blk_len > tile_k) {
                blk_len = tile_k;
            }

            /* Pass 1: scaled Q.K scores for this tile, and the tile max. */
            float m_block = -INFINITY;
            for (int bi = 0; bi < blk_len; ++bi) {
                const int t_k = t_k0 + bi;
                const float *k_head = k_base + (size_t)t_k * stride;

                float dot = ck_dot_f32_avx512(q_head, k_head, D_h);

                float score = dot * scale;
                scores[bi] = score;
                if (score > m_block) {
                    m_block = score;
                }
            }

            /* New global max: rescale the previous sum and accumulator. */
            if (m_block > m) {
                float scale_old = (m == -INFINITY) ? 0.0f : ck_expf(m - m_block);
                s *= scale_old;
                __m512 scale_old_vec = _mm512_set1_ps(scale_old);
                d = 0;
                for (; d + 16 <= D_h; d += 16) {
                    __m512 out_v = _mm512_loadu_ps(out_head + d);
                    _mm512_storeu_ps(out_head + d, _mm512_mul_ps(out_v, scale_old_vec));
                }
                for (; d < D_h; ++d) {
                    out_head[d] *= scale_old;
                }
                m = m_block;
            }

#if CK_FLASH_ATTN_FAST_EXP
            /* Fast path: exponentiate the tile's scores in place, 16 at
             * a time, so pass 2 reads ready-made weights. */
            int bi_vec = 0;
            __m512 m_vec = _mm512_set1_ps(m);
            for (; bi_vec + 16 <= blk_len; bi_vec += 16) {
                __m512 s_vec = _mm512_loadu_ps(scores + bi_vec);
                s_vec = _mm512_sub_ps(s_vec, m_vec);
                __m512 w_vec = ck_fast_exp512_ps(s_vec);
                _mm512_storeu_ps(scores + bi_vec, w_vec);
            }
            for (; bi_vec < blk_len; ++bi_vec) {
                scores[bi_vec] = ck_fast_expf(scores[bi_vec] - m);
            }
#endif

            /* Pass 2: accumulate exp-weighted V rows for this tile. */
            for (int bi = 0; bi < blk_len; ++bi) {
                const int t_k = t_k0 + bi;
                const float *v_head = v_base + (size_t)t_k * stride;
#if CK_FLASH_ATTN_FAST_EXP
                float w = scores[bi]; /* already exponentiated above */
#else
                float w = ck_expf(scores[bi] - m);
#endif
                s += w;

                __m512 w_vec = _mm512_set1_ps(w);
                d = 0;
                for (; d + 16 <= D_h; d += 16) {
                    __m512 out_v = _mm512_loadu_ps(out_head + d);
                    __m512 v_v = _mm512_loadu_ps(v_head + d);
                    out_v = _mm512_fmadd_ps(w_vec, v_v, out_v);
                    _mm512_storeu_ps(out_head + d, out_v);
                }
                for (; d < D_h; ++d) {
                    out_head[d] += w * v_head[d];
                }
            }
        }

        /* Normalize by the softmax denominator. */
        if (s > 0.0f) {
            float inv_s = 1.0f / s;
            __m512 inv_s_vec = _mm512_set1_ps(inv_s);
            d = 0;
            for (; d + 16 <= D_h; d += 16) {
                __m512 out_v = _mm512_loadu_ps(out_head + d);
                _mm512_storeu_ps(out_head + d, _mm512_mul_ps(out_v, inv_s_vec));
            }
            for (; d < D_h; ++d) {
                out_head[d] *= inv_s;
            }
        } else {
            for (int d0 = 0; d0 < D_h; ++d0) {
                out_head[d0] = 0.0f;
            }
        }
    }
}
435 
436 #endif // __AVX512F__
437 
438 #if defined(__AVX__) && !defined(__AVX512F__)
439 
440 /* ============================================================================
441  * AVX IMPLEMENTATION (8 floats per vector)
442  * ============================================================================ */
443 
444 #if CK_FLASH_ATTN_FAST_EXP
/**
 * @brief Build 2^n per lane from a vector of whole-number floats.
 *
 * AVX1 has no 256-bit integer shift, so each 128-bit half is converted
 * to int32, biased by 127, shifted into the float exponent field with
 * SSE2, and the halves are stitched back together.
 */
static inline __m256 ck_pow2_256_ps(__m256 zf) {
    const __m128i bias = _mm_set1_epi32(127);

    __m128i lo = _mm_cvtps_epi32(_mm256_castps256_ps128(zf));
    __m128i hi = _mm_cvtps_epi32(_mm256_extractf128_ps(zf, 1));
    lo = _mm_slli_epi32(_mm_add_epi32(lo, bias), 23);
    hi = _mm_slli_epi32(_mm_add_epi32(hi, bias), 23);

    __m256 result = _mm256_castps128_ps256(_mm_castsi128_ps(lo));
    return _mm256_insertf128_ps(result, _mm_castsi128_ps(hi), 1);
}
461 
/**
 * @brief 8-lane fast expf: exp(x) = 2^round(x*log2e) * 2^frac.
 *
 * AVX counterpart of ck_fast_expf(); delegates the 2^integer factor to
 * ck_pow2_256_ps() because AVX1 lacks 256-bit integer shifts.
 */
static inline __m256 ck_fast_exp256_ps(__m256 x) {
    /* Clamp so the biased exponent stays representable. */
    const __m256 max_val = _mm256_set1_ps(88.0f);
    const __m256 min_val = _mm256_set1_ps(-88.0f);
    x = _mm256_min_ps(x, max_val);
    x = _mm256_max_ps(x, min_val);

    /* Split x * log2(e) into a whole part zf and fraction f in [-0.5, 0.5]. */
    const __m256 log2e = _mm256_set1_ps(1.4426950408889634f);
    __m256 z = _mm256_mul_ps(x, log2e);
    __m256 zf = _mm256_round_ps(z, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m256 f = _mm256_sub_ps(z, zf);

    /* Degree-4 polynomial for 2^f; coefficients are ln(2)^n / n!. */
    const __m256 c0 = _mm256_set1_ps(1.0f);
    const __m256 c1 = _mm256_set1_ps(0.6931471805599453f);
    const __m256 c2 = _mm256_set1_ps(0.2402265069591007f);
    const __m256 c3 = _mm256_set1_ps(0.05550410866482158f);
    const __m256 c4 = _mm256_set1_ps(0.009618129107628478f);

#if defined(__FMA__)
    __m256 poly = _mm256_fmadd_ps(c4, f, c3);
    poly = _mm256_fmadd_ps(poly, f, c2);
    poly = _mm256_fmadd_ps(poly, f, c1);
    poly = _mm256_fmadd_ps(poly, f, c0);
#else
    __m256 poly = _mm256_add_ps(_mm256_mul_ps(c4, f), c3);
    poly = _mm256_add_ps(_mm256_mul_ps(poly, f), c2);
    poly = _mm256_add_ps(_mm256_mul_ps(poly, f), c1);
    poly = _mm256_add_ps(_mm256_mul_ps(poly, f), c0);
#endif

    __m256 pow2 = ck_pow2_256_ps(zf);
    return _mm256_mul_ps(poly, pow2);
}
494 #endif
495 
/**
 * @brief Horizontal sum of all 8 lanes of an AVX vector.
 *
 * Folds the upper half onto the lower half, then reduces the remaining
 * 4 lanes with shuffle + add; the scalar result is in lane 0.
 */
static inline float hsum256_ps(__m256 v) {
    __m128 quad = _mm_add_ps(_mm256_castps256_ps128(v),
                             _mm256_extractf128_ps(v, 1));
    quad = _mm_add_ps(quad, _mm_movehdup_ps(quad));
    quad = _mm_add_ss(quad, _mm_movehl_ps(quad, quad));
    return _mm_cvtss_f32(quad);
}
506 
/**
 * @brief AVX dot product of two float vectors of length D_h.
 *
 * Two independent 8-lane accumulators cover the 16-element main stride,
 * a single-accumulator 8-element stride handles the next chunk, and the
 * tail is done in scalar code. Uses FMA when available, otherwise
 * separate multiply and add.
 */
static inline float ck_dot_f32_avx(const float *q, const float *k, int D_h) {
    __m256 acc_a = _mm256_setzero_ps();
    __m256 acc_b = _mm256_setzero_ps();

    int i = 0;
    while (i + 16 <= D_h) {
        __m256 qa = _mm256_loadu_ps(q + i);
        __m256 ka = _mm256_loadu_ps(k + i);
        __m256 qb = _mm256_loadu_ps(q + i + 8);
        __m256 kb = _mm256_loadu_ps(k + i + 8);
#if defined(__FMA__)
        acc_a = _mm256_fmadd_ps(qa, ka, acc_a);
        acc_b = _mm256_fmadd_ps(qb, kb, acc_b);
#else
        acc_a = _mm256_add_ps(acc_a, _mm256_mul_ps(qa, ka));
        acc_b = _mm256_add_ps(acc_b, _mm256_mul_ps(qb, kb));
#endif
        i += 16;
    }
    while (i + 8 <= D_h) {
        __m256 qa = _mm256_loadu_ps(q + i);
        __m256 ka = _mm256_loadu_ps(k + i);
#if defined(__FMA__)
        acc_a = _mm256_fmadd_ps(qa, ka, acc_a);
#else
        acc_a = _mm256_add_ps(acc_a, _mm256_mul_ps(qa, ka));
#endif
        i += 8;
    }

    /* Combine both accumulators, reduce across lanes, then add the tail. */
    float total = hsum256_ps(_mm256_add_ps(acc_a, acc_b));
    while (i < D_h) {
        total += q[i] * k[i];
        ++i;
    }
    return total;
}
542 
/**
 * @brief AVX flash attention (online softmax, causal).
 *
 * Same algorithm as the scalar reference, with the D_h-dimension loops
 * vectorized 8 floats at a time and score dot products computed by
 * ck_dot_f32_avx(). With CK_FLASH_ATTN_FAST_EXP enabled the tile
 * weights are exponentiated 8 at a time via ck_fast_exp256_ps().
 */
static void attention_flash_decode_avx(
    float *out,        /* [T_q, H, D_h] attention output */
    const float *q,    /* [T_q, H, D_h] queries */
    const float *k,    /* [T_k, H, D_h] keys */
    const float *v,    /* [T_k, H, D_h] values */
    int T_q,
    int T_k,
    int H,
    int D_h,
    float scale)
{
    const int total = T_q * H;
    /* Elements between consecutive tokens in the [T, H, D_h] layout. */
    const size_t stride = (size_t)H * (size_t)D_h;
    const int tile_k = ck_flash_attn_tile_k(D_h);

    for (int idx = 0; idx < total; ++idx) {
        const int t_q = idx / H;
        const int h = idx - t_q * H;
        /* Causal bound: last KV position this query may attend to. */
        const int max_k = max_k_for_query(t_q, T_q, T_k);

        const float *q_head = q + (size_t)t_q * stride + (size_t)h * (size_t)D_h;
        float *out_head = out + (size_t)t_q * stride + (size_t)h * (size_t)D_h;
        const float *k_base = k + (size_t)h * (size_t)D_h;
        const float *v_base = v + (size_t)h * (size_t)D_h;

        /* Zero the accumulator: 8 lanes at a time, scalar tail. */
        int d = 0;
        for (; d + 8 <= D_h; d += 8) {
            _mm256_storeu_ps(out_head + d, _mm256_setzero_ps());
        }
        for (; d < D_h; ++d) {
            out_head[d] = 0.0f;
        }

        /* Online-softmax state: running max m, running sum s. */
        float m = -INFINITY;
        float s = 0.0f;

        float scores[CK_FLASH_ATTN_TILE_K];

        for (int t_k0 = 0; t_k0 <= max_k; t_k0 += tile_k) {
            int blk_len = max_k - t_k0 + 1;
            if (blk_len > tile_k) {
                blk_len = tile_k;
            }

            /* Pass 1: scaled Q.K scores for this tile, and the tile max. */
            float m_block = -INFINITY;
            for (int bi = 0; bi < blk_len; ++bi) {
                const int t_k = t_k0 + bi;
                const float *k_head = k_base + (size_t)t_k * stride;

                float dot = ck_dot_f32_avx(q_head, k_head, D_h);

                float score = dot * scale;
                scores[bi] = score;
                if (score > m_block) {
                    m_block = score;
                }
            }

            /* New global max: rescale the previous sum and accumulator. */
            if (m_block > m) {
                float scale_old = (m == -INFINITY) ? 0.0f : ck_expf(m - m_block);
                s *= scale_old;
                __m256 scale_old_vec = _mm256_set1_ps(scale_old);
                d = 0;
                for (; d + 8 <= D_h; d += 8) {
                    __m256 out_v = _mm256_loadu_ps(out_head + d);
                    _mm256_storeu_ps(out_head + d, _mm256_mul_ps(out_v, scale_old_vec));
                }
                for (; d < D_h; ++d) {
                    out_head[d] *= scale_old;
                }
                m = m_block;
            }

#if CK_FLASH_ATTN_FAST_EXP
            /* Fast path: exponentiate the tile's scores in place, 8 at
             * a time, so pass 2 reads ready-made weights. */
            int bi_vec = 0;
            __m256 m_vec = _mm256_set1_ps(m);
            for (; bi_vec + 8 <= blk_len; bi_vec += 8) {
                __m256 s_vec = _mm256_loadu_ps(scores + bi_vec);
                s_vec = _mm256_sub_ps(s_vec, m_vec);
                __m256 w_vec = ck_fast_exp256_ps(s_vec);
                _mm256_storeu_ps(scores + bi_vec, w_vec);
            }
            for (; bi_vec < blk_len; ++bi_vec) {
                scores[bi_vec] = ck_fast_expf(scores[bi_vec] - m);
            }
#endif

            /* Pass 2: accumulate exp-weighted V rows for this tile. */
            for (int bi = 0; bi < blk_len; ++bi) {
                const int t_k = t_k0 + bi;
                const float *v_head = v_base + (size_t)t_k * stride;
#if CK_FLASH_ATTN_FAST_EXP
                float w = scores[bi]; /* already exponentiated above */
#else
                float w = ck_expf(scores[bi] - m);
#endif
                s += w;

                __m256 w_vec = _mm256_set1_ps(w);
                d = 0;
                for (; d + 8 <= D_h; d += 8) {
                    __m256 out_v = _mm256_loadu_ps(out_head + d);
                    __m256 v_v = _mm256_loadu_ps(v_head + d);
#if defined(__FMA__)
                    out_v = _mm256_fmadd_ps(w_vec, v_v, out_v);
#else
                    out_v = _mm256_add_ps(out_v, _mm256_mul_ps(w_vec, v_v));
#endif
                    _mm256_storeu_ps(out_head + d, out_v);
                }
                for (; d < D_h; ++d) {
                    out_head[d] += w * v_head[d];
                }
            }
        }

        /* Normalize by the softmax denominator. */
        if (s > 0.0f) {
            float inv_s = 1.0f / s;
            __m256 inv_s_vec = _mm256_set1_ps(inv_s);
            d = 0;
            for (; d + 8 <= D_h; d += 8) {
                __m256 out_v = _mm256_loadu_ps(out_head + d);
                _mm256_storeu_ps(out_head + d, _mm256_mul_ps(out_v, inv_s_vec));
            }
            for (; d < D_h; ++d) {
                out_head[d] *= inv_s;
            }
        } else {
            for (int d0 = 0; d0 < D_h; ++d0) {
                out_head[d0] = 0.0f;
            }
        }
    }
}
676 
677 #endif // __AVX__
678 
679 /* ============================================================================
680  * DISPATCHER FUNCTION
681  * ============================================================================ */
682 
683 /**
684  * @brief Main flash attention function with SIMD dispatch
685  *
686  * @param out Output [T_q, H, D_h]
687  * @param q Query [T_q, H, D_h]
688  * @param k Key [T_k, H, D_h]
689  * @param v Value [T_k, H, D_h]
690  * @param T_q Number of query tokens (1 for decode)
691  * @param T_k Number of key/value tokens (context length)
692  * @param H Number of heads
693  * @param D_h Head dimension
694  * @param scale 1/sqrt(D_h)
695  */
    float *out,        /* [T_q, H, D_h] attention output */
    const float *q,    /* [T_q, H, D_h] queries */
    const float *k,    /* [T_k, H, D_h] keys */
    const float *v,    /* [T_k, H, D_h] values */
    int T_q,
    int T_k,
    int H,
    int D_h,
    float scale)
{
    /* Defensive no-op on invalid arguments: per the kernel rules this
     * function must have no side effects, so bad input simply leaves
     * `out` untouched. */
    if (!out || !q || !k || !v) {
        return;
    }
    if (T_q <= 0 || T_k <= 0 || H <= 0 || D_h <= 0) {
        return;
    }

    /* Compile-time dispatch on CPU features: exactly one implementation
     * is selected per build (AVX-512 > AVX > scalar reference). */
#if defined(__AVX512F__)
    attention_flash_decode_avx512(out, q, k, v, T_q, T_k, H, D_h, scale);
#elif defined(__AVX__) && !defined(__AVX512F__)
    attention_flash_decode_avx(out, q, k, v, T_q, T_k, H, D_h, scale);
#else
    attention_flash_decode_scalar(out, q, k, v, T_q, T_k, H, D_h, scale);
#endif
}
723 
724 /* ============================================================================
725  * UTILITY FUNCTIONS
726  * ============================================================================ */
727 
/**
 * @brief Initialize flash attention buffers.
 *
 * Currently a deliberate no-op: the kernel needs no persistent scratch
 * (per-tile scores live on the stack). The entry point is kept in the
 * API so the orchestrator can pre-allocate bump-allocator workspace here
 * later without changing call sites.
 *
 * @param max_context  Maximum KV length the caller will use (unused).
 * @param max_heads    Maximum number of heads (unused).
 * @param max_head_dim Maximum head dimension (unused).
 */
void attention_flash_init(int max_context, int max_heads, int max_head_dim) {
    /* Explicitly discard the parameters until workspace pre-allocation
     * lands; keeps the build clean under -Wextra (-Wunused-parameter). */
    (void)max_context;
    (void)max_heads;
    (void)max_head_dim;
}
735 
736 /**
737  * @brief Clean up flash attention resources
738  */
    /* Intentionally a no-op: attention_flash_init() currently reserves
     * nothing, so there is nothing to release. Placeholder for future
     * workspace teardown at the orchestrator layer. */
}
int ck_flash_attn_choose_tile_k(int D_h)
void attention_flash_cleanup(void)
Clean up flash attention resources.
static void attention_flash_decode_scalar(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Scalar flash-style attention (online softmax)
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
static float ck_fast_expf(float x)
static int max_k_for_query(int t_q, int T_q, int T_k)
void attention_flash_init(int max_context, int max_heads, int max_head_dim)
Initialize flash attention buffers.
static int ck_flash_attn_tile_k(int D_h)
int ck_flash_attn_fast_exp_kind(void)
static float ck_expf(float x)
#define CK_FLASH_ATTN_TILE_K
int32_t float * score
Definition: tokenizer.h:327