31 #if defined(__AVX512F__)
32 #include <immintrin.h>
35 #if defined(__AVX__) && !defined(__AVX512F__)
36 #include <immintrin.h>
39 #ifndef CK_FLASH_ATTN_TILE_K
40 #define CK_FLASH_ATTN_TILE_K 32
43 #ifndef CK_FLASH_ATTN_FAST_EXP
44 #define CK_FLASH_ATTN_FAST_EXP 0
/* Interior of the scalar fast-exp helper (function header elided in this
 * listing).  Computes expf(x) as 2^z with z = x * log2(e): the integer part
 * of z is injected directly into the IEEE-754 exponent field, the fractional
 * part is approximated by a degree-4 polynomial.
 */
/* Clamp range so the biased exponent built below stays within [0, 254] —
 * avoids overflow into Inf and denormal results. */
48 const float max_val = 88.0f;
49 const float min_val = -88.0f;
52 }
/* NOTE(review): lines elided here — presumably early returns for the
 * saturated cases (x > max_val / x < min_val); confirm in the full file. */
else if (x < min_val) {
/* exp(x) = 2^(x * log2(e)); split z into integer and fractional parts. */
56 const float log2e = 1.4426950408889634f;
58 float zf = nearbyintf(z);
/* Coefficients c_k = ln(2)^k / k! for 2^f = exp(f * ln 2), |f| <= 0.5. */
61 const float c0 = 1.0f;
62 const float c1 = 0.6931471805599453f;
63 const float c2 = 0.2402265069591007f;
64 const float c3 = 0.05550410866482158f;
65 const float c4 = 0.009618129107628478f;
/* Horner evaluation; the visible line stops at c1 — the final `* f + c0`
 * step appears to be on an elided line (the SIMD variants fold in c0). */
67 float poly = ((c4 * f + c3) * f + c2) * f + c1;
/* Build 2^zf from raw bits: biased exponent = zf + 127, shifted into the
 * exponent field.  Safe because zf is within [-127, 127] after clamping. */
70 int32_t zi = (int32_t)zf + 127;
71 uint32_t bits = (uint32_t)zi << 23;
/* NOTE(review): heavily elided region — the fragments below belong to
 * several different helpers (per the symbol index at the end of this
 * listing): ck_expf, ck_flash_attn_tile_k, ck_flash_attn_fast_exp_kind and
 * max_k_for_query.  Verify against the original file before editing. */
81 #if CK_FLASH_ATTN_FAST_EXP
92 }
/* Fragment of ck_flash_attn_tile_k: the score-tile width appears to be
 * chosen from the head dimension D_h — TODO confirm thresholds. */
else if (D_h > 64) {
/* Fragment of ck_flash_attn_fast_exp_kind: reports which fast-exp path
 * (AVX-512 / AVX / scalar) was compiled in. */
113 #if CK_FLASH_ATTN_FAST_EXP
114 #if defined(__AVX512F__)
116 #elif defined(__AVX__)
/* Fragment of max_k_for_query: causal masking with right-aligned queries.
 * When the KV sequence is longer than the query block (T_k > T_q), query
 * row t_q sits at absolute position q_pos_offset + t_q, so it may attend
 * to keys 0..max_k inclusive. */
127 int q_pos_offset = (T_k > T_q) ? (T_k - T_q) : 0;
128 int max_k = q_pos_offset + t_q;
/* Body of attention_flash_decode_scalar (signature elided in this listing):
 * flash-style attention with an online (streaming) softmax.  For each
 * (query position t_q, head h) pair it scans keys in tiles of tile_k,
 * keeping a running max `m` and running denominator `s` so no T_k-sized
 * score buffer is ever materialised.  Layout assumed: [T, H, D_h] row-major
 * for q/k/v/out — TODO confirm against the header. */
153 const int total = T_q * H;
154 const size_t stride = (size_t)H * (
size_t)D_h;
/* Flattened loop over all (t_q, h) pairs. */
157 for (
int idx = 0; idx < total; ++idx) {
158 const int t_q = idx / H;
159 const int h = idx - t_q * H;
/* Per-head base pointers: row t_q, head h. */
162 const float *q_head = q + (size_t)t_q * stride + (
size_t)h * (size_t)D_h;
163 float *out_head = out + (size_t)t_q * stride + (
size_t)h * (size_t)D_h;
164 const float *k_base = k + (size_t)h * (
size_t)D_h;
165 const float *v_base = v + (size_t)h * (
size_t)D_h;
/* Zero the (unnormalised) output accumulator for this head. */
167 for (
int d = 0; d < D_h; ++d) {
/* Tile loop over the causally-visible keys 0..max_k. */
176 for (
int t_k0 = 0; t_k0 <= max_k; t_k0 += tile_k) {
177 int blk_len = max_k - t_k0 + 1;
/* Clamp the final partial tile; the assignment is on an elided line. */
178 if (blk_len > tile_k) {
/* Pass 1: scores for this tile and the tile-local maximum. */
182 float m_block = -INFINITY;
183 for (
int bi = 0; bi < blk_len; ++bi) {
184 const int t_k = t_k0 + bi;
185 const float *k_head = k_base + (size_t)t_k * stride;
188 for (
int d = 0; d < D_h; ++d) {
189 dot += q_head[d] * k_head[d];
192 float score = dot * scale;
/* NOTE(review): the store into scores[bi] is elided here — presumably
 * `scores[bi] = score;` precedes or follows this max update. */
194 if (
score > m_block) {
/* Online-softmax rescale: if the running max m rises to m_block, the
 * previous accumulator and denominator are scaled by exp(m_old - m_new).
 * The m/s update statements themselves are on elided lines. */
200 float scale_old = (m == -INFINITY) ? 0.0f :
ck_expf(m - m_block);
202 for (
int d = 0; d < D_h; ++d) {
203 out_head[d] *= scale_old;
/* Pass 2: accumulate w * V for every key in the tile. */
208 for (
int bi = 0; bi < blk_len; ++bi) {
209 const int t_k = t_k0 + bi;
210 const float *v_head = v_base + (size_t)t_k * stride;
211 float w =
ck_expf(scores[bi] - m);
213 for (
int d = 0; d < D_h; ++d) {
214 out_head[d] += w * v_head[d];
/* Final normalisation by the accumulated softmax denominator s. */
220 float inv_s = 1.0f / s;
221 for (
int d = 0; d < D_h; ++d) {
222 out_head[d] *= inv_s;
/* Trailing loop over d — purpose elided (possibly a NaN/empty fallback). */
225 for (
int d = 0; d < D_h; ++d) {
232 #if defined(__AVX512F__)
238 #if CK_FLASH_ATTN_FAST_EXP
/* Fast vectorized expf for AVX-512.
 *
 * exp(x) = 2^(x * log2(e)); the integer part of the exponent is injected
 * directly into the IEEE-754 exponent field, and the fractional part f in
 * [-0.5, 0.5] is approximated by a degree-4 polynomial with coefficients
 * c_k = ln(2)^k / k!.  Inputs are clamped to [-88, 88] so the biased
 * exponent stays inside [0, 254] (no Inf, no denormal bit patterns).
 * Accuracy is a few ulp relative to expf — adequate for softmax weights.
 *
 * Note: AVX-512F itself provides _mm512_fmadd_ps, so no FMA fallback is
 * needed inside an __AVX512F__-guarded translation unit.
 */
static inline __m512 ck_fast_exp512_ps(__m512 x) {
    const __m512 max_val = _mm512_set1_ps(88.0f);
    const __m512 min_val = _mm512_set1_ps(-88.0f);
    x = _mm512_min_ps(x, max_val);
    x = _mm512_max_ps(x, min_val);

    /* z = x * log2(e); zf = round(z); f = z - zf in [-0.5, 0.5]. */
    const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
    __m512 z = _mm512_mul_ps(x, log2e);
    __m512 zf = _mm512_roundscale_ps(z, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m512 f = _mm512_sub_ps(z, zf);

    const __m512 c0 = _mm512_set1_ps(1.0f);
    const __m512 c1 = _mm512_set1_ps(0.6931471805599453f);
    const __m512 c2 = _mm512_set1_ps(0.2402265069591007f);
    const __m512 c3 = _mm512_set1_ps(0.05550410866482158f);
    const __m512 c4 = _mm512_set1_ps(0.009618129107628478f);

    /* Horner evaluation of 2^f = exp(f * ln 2). */
    __m512 poly = _mm512_fmadd_ps(c4, f, c3);
    poly = _mm512_fmadd_ps(poly, f, c2);
    poly = _mm512_fmadd_ps(poly, f, c1);
    poly = _mm512_fmadd_ps(poly, f, c0);

    /* 2^zf from raw bits: biased exponent = zf + 127, shifted into place. */
    __m512i zi = _mm512_cvtps_epi32(zf);
    zi = _mm512_add_epi32(zi, _mm512_set1_epi32(127));
    zi = _mm512_slli_epi32(zi, 23);
    __m512 pow2 = _mm512_castsi512_ps(zi);
    return _mm512_mul_ps(poly, pow2);
}
/* Dot product of two float vectors of length D_h using AVX-512.
 *
 * Two independent 16-lane accumulators in the 32-wide main loop hide FMA
 * latency; a single 16-lane loop mops up one remaining full chunk, and a
 * scalar tail handles the last D_h % 16 elements.  Unaligned loads are
 * used, so q and k need no particular alignment.
 */
static inline float ck_dot_f32_avx512(
const float *q,
const float *k,
int D_h) {
    __m512 sum0 = _mm512_setzero_ps();
    __m512 sum1 = _mm512_setzero_ps();

    int d = 0;
    /* Main loop: 32 floats per iteration across two accumulators. */
    for (; d + 32 <= D_h; d += 32) {
        __m512 q0 = _mm512_loadu_ps(q + d);
        __m512 k0 = _mm512_loadu_ps(k + d);
        __m512 q1 = _mm512_loadu_ps(q + d + 16);
        __m512 k1 = _mm512_loadu_ps(k + d + 16);
        sum0 = _mm512_fmadd_ps(q0, k0, sum0);
        sum1 = _mm512_fmadd_ps(q1, k1, sum1);
    }
    /* At most one remaining full 16-lane chunk. */
    for (; d + 16 <= D_h; d += 16) {
        __m512 q0 = _mm512_loadu_ps(q + d);
        __m512 k0 = _mm512_loadu_ps(k + d);
        sum0 = _mm512_fmadd_ps(q0, k0, sum0);
    }

    sum0 = _mm512_add_ps(sum0, sum1);
    float dot = _mm512_reduce_add_ps(sum0);
    /* Scalar tail for the last D_h % 16 elements. */
    for (; d < D_h; ++d) {
        dot += q[d] * k[d];
    }
    return dot;
}
/* AVX-512 flash-style attention decode (parameter list elided in this
 * listing; per the symbol index it mirrors the scalar variant:
 * (out, q, k, v, T_q, T_k, H, D_h, scale)).  Same online-softmax algorithm
 * as the scalar path, with D_h loops vectorised 16 floats at a time and,
 * when CK_FLASH_ATTN_FAST_EXP is set, the per-tile exp() computed in bulk
 * with ck_fast_exp512_ps. */
306 static void attention_flash_decode_avx512(
317 const int total = T_q * H;
318 const size_t stride = (size_t)H * (
size_t)D_h;
/* Flattened loop over all (t_q, h) pairs. */
321 for (
int idx = 0; idx < total; ++idx) {
322 const int t_q = idx / H;
323 const int h = idx - t_q * H;
/* Per-head base pointers: row t_q, head h (layout [T, H, D_h]). */
326 const float *q_head = q + (size_t)t_q * stride + (
size_t)h * (size_t)D_h;
327 float *out_head = out + (size_t)t_q * stride + (
size_t)h * (size_t)D_h;
328 const float *k_base = k + (size_t)h * (
size_t)D_h;
329 const float *v_base = v + (size_t)h * (
size_t)D_h;
/* Zero the output accumulator, 16 lanes at a time plus scalar tail. */
332 for (; d + 16 <= D_h; d += 16) {
333 _mm512_storeu_ps(out_head + d, _mm512_setzero_ps());
335 for (; d < D_h; ++d) {
/* Tile loop over the causally-visible keys 0..max_k. */
344 for (
int t_k0 = 0; t_k0 <= max_k; t_k0 += tile_k) {
345 int blk_len = max_k - t_k0 + 1;
/* Clamp the final partial tile (assignment elided). */
346 if (blk_len > tile_k) {
/* Pass 1: scaled Q.K scores for this tile and the tile-local maximum.
 * NOTE(review): the `scores[bi] = score;` store is on an elided line. */
350 float m_block = -INFINITY;
351 for (
int bi = 0; bi < blk_len; ++bi) {
352 const int t_k = t_k0 + bi;
353 const float *k_head = k_base + (size_t)t_k * stride;
355 float dot = ck_dot_f32_avx512(q_head, k_head, D_h);
357 float score = dot * scale;
359 if (
score > m_block) {
/* Online-softmax rescale when the running max m rises: previous
 * accumulator (and, on elided lines, the denominator s) shrink by
 * exp(m_old - m_new). */
365 float scale_old = (m == -INFINITY) ? 0.0f :
ck_expf(m - m_block);
367 __m512 scale_old_vec = _mm512_set1_ps(scale_old);
369 for (; d + 16 <= D_h; d += 16) {
370 __m512 out_v = _mm512_loadu_ps(out_head + d);
371 _mm512_storeu_ps(out_head + d, _mm512_mul_ps(out_v, scale_old_vec));
373 for (; d < D_h; ++d) {
374 out_head[d] *= scale_old;
/* Fast-exp path: convert the whole score tile to weights in place,
 * 16 at a time, with a scalar remainder loop. */
379 #if CK_FLASH_ATTN_FAST_EXP
381 __m512 m_vec = _mm512_set1_ps(m);
382 for (; bi_vec + 16 <= blk_len; bi_vec += 16) {
383 __m512 s_vec = _mm512_loadu_ps(scores + bi_vec);
384 s_vec = _mm512_sub_ps(s_vec, m_vec);
385 __m512 w_vec = ck_fast_exp512_ps(s_vec);
386 _mm512_storeu_ps(scores + bi_vec, w_vec);
388 for (; bi_vec < blk_len; ++bi_vec) {
/* Pass 2: accumulate w * V for every key in the tile. */
393 for (
int bi = 0; bi < blk_len; ++bi) {
394 const int t_k = t_k0 + bi;
395 const float *v_head = v_base + (size_t)t_k * stride;
396 #if CK_FLASH_ATTN_FAST_EXP
397 float w = scores[bi];
399 float w =
ck_expf(scores[bi] - m);
403 __m512 w_vec = _mm512_set1_ps(w);
405 for (; d + 16 <= D_h; d += 16) {
406 __m512 out_v = _mm512_loadu_ps(out_head + d);
407 __m512 v_v = _mm512_loadu_ps(v_head + d);
408 out_v = _mm512_fmadd_ps(w_vec, v_v, out_v);
409 _mm512_storeu_ps(out_head + d, out_v);
411 for (; d < D_h; ++d) {
412 out_head[d] += w * v_head[d];
/* Final normalisation by the accumulated softmax denominator s. */
418 float inv_s = 1.0f / s;
419 __m512 inv_s_vec = _mm512_set1_ps(inv_s);
421 for (; d + 16 <= D_h; d += 16) {
422 __m512 out_v = _mm512_loadu_ps(out_head + d);
423 _mm512_storeu_ps(out_head + d, _mm512_mul_ps(out_v, inv_s_vec));
425 for (; d < D_h; ++d) {
426 out_head[d] *= inv_s;
/* Trailing loop over d0 — purpose elided (possibly a fallback fill). */
429 for (
int d0 = 0; d0 < D_h; ++d0) {
438 #if defined(__AVX__) && !defined(__AVX512F__)
444 #if CK_FLASH_ATTN_FAST_EXP
/* Per-lane 2^zf for AVX (pre-AVX2): 256-bit integer operations are not
 * available, so split zf into two 128-bit halves, build the IEEE-754 bit
 * pattern (biased exponent = zf + 127, shifted into the exponent field)
 * with SSE2 integer ops, and re-join the halves.
 *
 * zf must already hold integral values within roughly [-127, 127]
 * (guaranteed by the clamp in ck_fast_exp256_ps).
 */
static inline __m256 ck_pow2_256_ps(__m256 zf) {
    __m128 z0 = _mm256_castps256_ps128(zf);
    __m128 z1 = _mm256_extractf128_ps(zf, 1);

    __m128i i0 = _mm_cvtps_epi32(z0);
    __m128i i1 = _mm_cvtps_epi32(z1);
    i0 = _mm_add_epi32(i0, _mm_set1_epi32(127));
    i1 = _mm_add_epi32(i1, _mm_set1_epi32(127));
    i0 = _mm_slli_epi32(i0, 23);
    i1 = _mm_slli_epi32(i1, 23);

    __m128 f0 = _mm_castsi128_ps(i0);
    __m128 f1 = _mm_castsi128_ps(i1);
    __m256 out = _mm256_castps128_ps256(f0);
    return _mm256_insertf128_ps(out, f1, 1);
}
462 static inline __m256 ck_fast_exp256_ps(__m256 x) {
463 const __m256 max_val = _mm256_set1_ps(88.0f);
464 const __m256 min_val = _mm256_set1_ps(-88.0f);
465 x = _mm256_min_ps(x, max_val);
466 x = _mm256_max_ps(x, min_val);
468 const __m256 log2e = _mm256_set1_ps(1.4426950408889634f);
469 __m256 z = _mm256_mul_ps(x, log2e);
470 __m256 zf = _mm256_round_ps(z, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
471 __m256 f = _mm256_sub_ps(z, zf);
473 const __m256 c0 = _mm256_set1_ps(1.0f);
474 const __m256 c1 = _mm256_set1_ps(0.6931471805599453f);
475 const __m256 c2 = _mm256_set1_ps(0.2402265069591007f);
476 const __m256 c3 = _mm256_set1_ps(0.05550410866482158f);
477 const __m256 c4 = _mm256_set1_ps(0.009618129107628478f);
480 __m256 poly = _mm256_fmadd_ps(c4, f, c3);
481 poly = _mm256_fmadd_ps(poly, f, c2);
482 poly = _mm256_fmadd_ps(poly, f, c1);
483 poly = _mm256_fmadd_ps(poly, f, c0);
485 __m256 poly = _mm256_add_ps(_mm256_mul_ps(c4, f), c3);
486 poly = _mm256_add_ps(_mm256_mul_ps(poly, f), c2);
487 poly = _mm256_add_ps(_mm256_mul_ps(poly, f), c1);
488 poly = _mm256_add_ps(_mm256_mul_ps(poly, f), c0);
491 __m256 pow2 = ck_pow2_256_ps(zf);
492 return _mm256_mul_ps(poly, pow2);
/* Horizontal sum of all 8 lanes of a __m256.
 *
 * Folds the upper 128-bit half onto the lower, then reduces the remaining
 * four lanes with the classic movehdup/movehl shuffle sequence (avoids the
 * high-latency haddps).  Returns the scalar sum.
 */
static inline float hsum256_ps(__m256 v) {
    /* 8 -> 4 lanes: add the two 128-bit halves. */
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 hi = _mm256_extractf128_ps(v, 1);
    __m128 sum128 = _mm_add_ps(lo, hi);
    /* 4 -> 2 lanes: add odd lanes onto even lanes. */
    __m128 shuf = _mm_movehdup_ps(sum128);
    __m128 sums = _mm_add_ps(sum128, shuf);
    /* 2 -> 1 lane: add the upper pair onto the lower. */
    shuf = _mm_movehl_ps(shuf, sums);
    sums = _mm_add_ps(shuf, sums);
    return _mm_cvtss_f32(sums);
}
/* Dot product of two float vectors of length D_h using AVX.
 *
 * Two independent 8-lane accumulators in the 16-wide main loop hide
 * multiply/add latency; a single 8-lane loop mops up one remaining full
 * chunk, and a scalar tail handles the last D_h % 8 elements.  Uses FMA
 * when the build enables it, mul+add otherwise (plain AVX does not imply
 * FMA).  Unaligned loads are used throughout.
 */
static inline float ck_dot_f32_avx(
const float *q,
const float *k,
int D_h) {
    __m256 sum0 = _mm256_setzero_ps();
    __m256 sum1 = _mm256_setzero_ps();

    int d = 0;
    /* Main loop: 16 floats per iteration across two accumulators. */
    for (; d + 16 <= D_h; d += 16) {
        __m256 q0 = _mm256_loadu_ps(q + d);
        __m256 k0 = _mm256_loadu_ps(k + d);
        __m256 q1 = _mm256_loadu_ps(q + d + 8);
        __m256 k1 = _mm256_loadu_ps(k + d + 8);
#if defined(__FMA__)
        sum0 = _mm256_fmadd_ps(q0, k0, sum0);
        sum1 = _mm256_fmadd_ps(q1, k1, sum1);
#else
        sum0 = _mm256_add_ps(sum0, _mm256_mul_ps(q0, k0));
        sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(q1, k1));
#endif
    }
    /* At most one remaining full 8-lane chunk. */
    for (; d + 8 <= D_h; d += 8) {
        __m256 q0 = _mm256_loadu_ps(q + d);
        __m256 k0 = _mm256_loadu_ps(k + d);
#if defined(__FMA__)
        sum0 = _mm256_fmadd_ps(q0, k0, sum0);
#else
        sum0 = _mm256_add_ps(sum0, _mm256_mul_ps(q0, k0));
#endif
    }

    __m256 sum = _mm256_add_ps(sum0, sum1);
    float dot = hsum256_ps(sum);
    /* Scalar tail for the last D_h % 8 elements. */
    for (; d < D_h; ++d) {
        dot += q[d] * k[d];
    }
    return dot;
}
/* AVX flash-style attention decode (parameter list elided in this listing;
 * per the symbol index it mirrors the scalar variant:
 * (out, q, k, v, T_q, T_k, H, D_h, scale)).  Same online-softmax algorithm
 * as the scalar path, with D_h loops vectorised 8 floats at a time and,
 * when CK_FLASH_ATTN_FAST_EXP is set, the per-tile exp() computed in bulk
 * with ck_fast_exp256_ps. */
543 static void attention_flash_decode_avx(
554 const int total = T_q * H;
555 const size_t stride = (size_t)H * (
size_t)D_h;
/* Flattened loop over all (t_q, h) pairs. */
558 for (
int idx = 0; idx < total; ++idx) {
559 const int t_q = idx / H;
560 const int h = idx - t_q * H;
/* Per-head base pointers: row t_q, head h (layout [T, H, D_h]). */
563 const float *q_head = q + (size_t)t_q * stride + (
size_t)h * (size_t)D_h;
564 float *out_head = out + (size_t)t_q * stride + (
size_t)h * (size_t)D_h;
565 const float *k_base = k + (size_t)h * (
size_t)D_h;
566 const float *v_base = v + (size_t)h * (
size_t)D_h;
/* Zero the output accumulator, 8 lanes at a time plus scalar tail. */
569 for (; d + 8 <= D_h; d += 8) {
570 _mm256_storeu_ps(out_head + d, _mm256_setzero_ps());
572 for (; d < D_h; ++d) {
/* Tile loop over the causally-visible keys 0..max_k. */
581 for (
int t_k0 = 0; t_k0 <= max_k; t_k0 += tile_k) {
582 int blk_len = max_k - t_k0 + 1;
/* Clamp the final partial tile (assignment elided). */
583 if (blk_len > tile_k) {
/* Pass 1: scaled Q.K scores for this tile and the tile-local maximum.
 * NOTE(review): the `scores[bi] = score;` store is on an elided line. */
587 float m_block = -INFINITY;
588 for (
int bi = 0; bi < blk_len; ++bi) {
589 const int t_k = t_k0 + bi;
590 const float *k_head = k_base + (size_t)t_k * stride;
592 float dot = ck_dot_f32_avx(q_head, k_head, D_h);
594 float score = dot * scale;
596 if (
score > m_block) {
/* Online-softmax rescale when the running max m rises: previous
 * accumulator (and, on elided lines, the denominator s) shrink by
 * exp(m_old - m_new). */
602 float scale_old = (m == -INFINITY) ? 0.0f :
ck_expf(m - m_block);
604 __m256 scale_old_vec = _mm256_set1_ps(scale_old);
606 for (; d + 8 <= D_h; d += 8) {
607 __m256 out_v = _mm256_loadu_ps(out_head + d);
608 _mm256_storeu_ps(out_head + d, _mm256_mul_ps(out_v, scale_old_vec));
610 for (; d < D_h; ++d) {
611 out_head[d] *= scale_old;
/* Fast-exp path: convert the whole score tile to weights in place,
 * 8 at a time, with a scalar remainder loop. */
616 #if CK_FLASH_ATTN_FAST_EXP
618 __m256 m_vec = _mm256_set1_ps(m);
619 for (; bi_vec + 8 <= blk_len; bi_vec += 8) {
620 __m256 s_vec = _mm256_loadu_ps(scores + bi_vec);
621 s_vec = _mm256_sub_ps(s_vec, m_vec);
622 __m256 w_vec = ck_fast_exp256_ps(s_vec);
623 _mm256_storeu_ps(scores + bi_vec, w_vec);
625 for (; bi_vec < blk_len; ++bi_vec) {
/* Pass 2: accumulate w * V for every key in the tile.  The FMA /
 * non-FMA accumulation variants below were guarded by preprocessor
 * conditionals on lines elided from this listing. */
630 for (
int bi = 0; bi < blk_len; ++bi) {
631 const int t_k = t_k0 + bi;
632 const float *v_head = v_base + (size_t)t_k * stride;
633 #if CK_FLASH_ATTN_FAST_EXP
634 float w = scores[bi];
636 float w =
ck_expf(scores[bi] - m);
640 __m256 w_vec = _mm256_set1_ps(w);
642 for (; d + 8 <= D_h; d += 8) {
643 __m256 out_v = _mm256_loadu_ps(out_head + d);
644 __m256 v_v = _mm256_loadu_ps(v_head + d);
646 out_v = _mm256_fmadd_ps(w_vec, v_v, out_v);
648 out_v = _mm256_add_ps(out_v, _mm256_mul_ps(w_vec, v_v));
650 _mm256_storeu_ps(out_head + d, out_v);
652 for (; d < D_h; ++d) {
653 out_head[d] += w * v_head[d];
/* Final normalisation by the accumulated softmax denominator s. */
659 float inv_s = 1.0f / s;
660 __m256 inv_s_vec = _mm256_set1_ps(inv_s);
662 for (; d + 8 <= D_h; d += 8) {
663 __m256 out_v = _mm256_loadu_ps(out_head + d);
664 _mm256_storeu_ps(out_head + d, _mm256_mul_ps(out_v, inv_s_vec));
666 for (; d < D_h; ++d) {
667 out_head[d] *= inv_s;
/* Trailing loop over d0 — purpose elided (possibly a fallback fill). */
670 for (
int d0 = 0; d0 < D_h; ++d0) {
/* Fragment of attention_flash_decode (entry point; header and the scalar
 * fallback branch are elided from this listing).  Guard clauses reject NULL
 * buffers and non-positive dimensions, then dispatch to the widest SIMD
 * implementation the build supports. */
707 if (!out || !q || !k || !v) {
710 if (T_q <= 0 || T_k <= 0 || H <= 0 || D_h <= 0) {
715 #if defined(__AVX512F__)
716 attention_flash_decode_avx512(out, q, k, v, T_q, T_k, H, D_h, scale);
/* NOTE(review): the `!defined(__AVX512F__)` below is redundant — an #elif
 * is only evaluated when the preceding `#if defined(__AVX512F__)` was
 * false.  Harmless, but can be simplified to `#elif defined(__AVX__)`. */
717 #elif defined(__AVX__) && !defined(__AVX512F__)
718 attention_flash_decode_avx(out, q, k, v, T_q, T_k, H, D_h, scale);
int ck_flash_attn_choose_tile_k(int D_h)
void attention_flash_cleanup(void)
Clean up flash attention resources.
static void attention_flash_decode_scalar(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Scalar flash-style attention (online softmax)
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
static float ck_fast_expf(float x)
static int max_k_for_query(int t_q, int T_q, int T_k)
void attention_flash_init(int max_context, int max_heads, int max_head_dim)
Initialize flash attention buffers.
static int ck_flash_attn_tile_k(int D_h)
int ck_flash_attn_fast_exp_kind(void)
static float ck_expf(float x)
#define CK_FLASH_ATTN_TILE_K