23 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
24 #include <immintrin.h>
30 if (!dst || !src)
return;
/*
 * Flatten a (head, token, element) coordinate into the head-major Q/K/V
 * layout: each head owns a contiguous [num_tokens x aligned_head_dim] slab.
 *
 * h                 head index
 * t                 token index within the head's slab
 * d                 element index within the (padded) head dimension
 * num_tokens        tokens per head slab
 * aligned_head_dim  padded per-token stride (>= head_dim)
 *
 * Bug fix: the previous return expression never added the element offset,
 * leaving the `d` parameter unused; callers that do `qkv_index(h, i, 0, ...)`
 * and then index `+ d` were unaffected, but any caller passing d != 0 got
 * the slab base instead of the element address.
 */
static size_t qkv_index(int h, int t, int d, int num_tokens, int aligned_head_dim)
{
    return ((size_t)h * (size_t)num_tokens + (size_t)t) * (size_t)aligned_head_dim
        + (size_t)d;
}
/*
 * Flatten a (head, query, key) coordinate into the attention-score layout:
 * each head owns a contiguous
 * [aligned_context_window x aligned_context_window] score matrix.
 *
 * h                       head index
 * i                       query (row) position
 * j                       key (column) position
 * aligned_context_window  padded matrix side (>= num_tokens)
 *
 * Bug fix: the previous return expression dropped the final `+ j` key
 * offset, collapsing every column of a row onto the row base. The backward
 * pass computes the same layout inline as `h*aw*aw + i*aw` plus `j`, which
 * this must agree with.
 */
static size_t score_index(int h, int i, int j, int aligned_context_window)
{
    size_t aw = (size_t)aligned_context_window;
    return (size_t)h * aw * aw + (size_t)i * aw + (size_t)j;
}
79 int aligned_context_window)
81 const float scale = 1.0f / sqrtf((
float)head_dim);
85 for (
int h = 0; h < num_heads; ++h) {
86 for (
int i = 0; i < num_tokens; ++i) {
87 for (
int j = 0; j <= i; ++j) {
89 size_t base_q =
qkv_index(h, i, 0, num_tokens, aligned_head_dim);
90 size_t base_k =
qkv_index(h, j, 0, num_tokens, aligned_head_dim);
92 for (
int d = 0; d < head_dim; ++d) {
93 dot += q[base_q + d] * k[base_k + d];
96 scores[
score_index(h, i, j, aligned_context_window)] = dot * scale;
101 for (
int j = i + 1; j < num_tokens; ++j) {
102 scores[
score_index(h, i, j, aligned_context_window)] = 0.0f;
111 aligned_context_window);
114 for (
int h = 0; h < num_heads; ++h) {
115 for (
int i = 0; i < num_tokens; ++i) {
116 size_t out_base =
qkv_index(h, i, 0, num_tokens, aligned_head_dim);
119 for (
int d = 0; d < aligned_head_dim; ++d) {
120 output[out_base + d] = 0.0f;
124 for (
int j = 0; j <= i; ++j) {
125 float w = scores[
score_index(h, i, j, aligned_context_window)];
126 size_t v_base =
qkv_index(h, j, 0, num_tokens, aligned_head_dim);
128 for (
int d = 0; d < head_dim; ++d) {
129 output[out_base + d] += w * v[v_base + d];
154 int aligned_head_dim,
155 int aligned_context_window)
157 const float scale = 1.0f / sqrtf((
float)head_dim);
161 for (
int h = 0; h < num_heads; ++h) {
162 for (
int i = 0; i < num_tokens; ++i) {
163 for (
int j = 0; j <= i; ++j) {
165 size_t base_q =
qkv_index(h, i, 0, num_tokens, aligned_head_dim);
166 size_t base_k =
qkv_index(h, j, 0, num_tokens, aligned_head_dim);
168 for (
int d = 0; d < head_dim; ++d) {
169 dot += q[base_q + d] * k[base_k + d];
172 scores[
score_index(h, i, j, aligned_context_window)] = dot * scale;
177 for (
int j = i + 1; j < num_tokens; ++j) {
178 scores[
score_index(h, i, j, aligned_context_window)] = 0.0f;
187 aligned_context_window);
190 for (
int h = 0; h < num_heads; ++h) {
191 for (
int i = 0; i < num_tokens; ++i) {
192 size_t out_base =
qkv_index(h, i, 0, num_tokens, aligned_head_dim);
195 for (
int d = 0; d < aligned_head_dim; ++d) {
196 output[out_base + d] = 0.0f;
200 for (
int j = 0; j <= i; ++j) {
201 float w = scores[
score_index(h, i, j, aligned_context_window)];
202 size_t v_base =
qkv_index(h, j, 0, num_tokens, aligned_head_dim);
204 for (
int d = 0; d < head_dim; ++d) {
205 output[out_base + d] += w * v[v_base + d];
233 int aligned_head_dim,
234 int aligned_context_window)
236 const float scale = 1.0f / sqrtf((
float)head_dim);
238 for (
int h = 0; h < num_heads; ++h) {
239 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
240 for (
int i = 0; i < num_tokens; ++i) {
241 for (
int j = 0; j <= i; ++j) {
243 size_t base_q =
qkv_index(h, i, 0, num_tokens, aligned_head_dim);
244 size_t base_k =
qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
246 for (
int d = 0; d < head_dim; ++d) {
247 dot += q[base_q + d] * k[base_k + d];
250 scores[
score_index(h, i, j, aligned_context_window)] = dot * scale;
253 for (
int j = i + 1; j < num_tokens; ++j) {
254 scores[
score_index(h, i, j, aligned_context_window)] = 0.0f;
262 aligned_context_window);
264 for (
int h = 0; h < num_heads; ++h) {
265 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
266 for (
int i = 0; i < num_tokens; ++i) {
267 size_t out_base =
qkv_index(h, i, 0, num_tokens, aligned_head_dim);
268 for (
int d = 0; d < aligned_head_dim; ++d) {
269 output[out_base + d] = 0.0f;
272 for (
int j = 0; j <= i; ++j) {
273 float w = scores[
score_index(h, i, j, aligned_context_window)];
274 size_t v_base =
qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
276 for (
int d = 0; d < head_dim; ++d) {
277 output[out_base + d] += w * v[v_base + d];
303 int aligned_head_dim,
304 int aligned_context_window)
306 const float scale = 1.0f / sqrtf((
float)head_dim);
308 for (
int h = 0; h < num_heads; ++h) {
309 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
310 for (
int i = 0; i < num_tokens; ++i) {
311 for (
int j = 0; j <= i; ++j) {
313 size_t base_q =
qkv_index(h, i, 0, num_tokens, aligned_head_dim);
314 size_t base_k =
qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
316 for (
int d = 0; d < head_dim; ++d) {
317 dot += q[base_q + d] * k[base_k + d];
320 scores[
score_index(h, i, j, aligned_context_window)] = dot * scale;
323 for (
int j = i + 1; j < num_tokens; ++j) {
324 scores[
score_index(h, i, j, aligned_context_window)] = 0.0f;
333 aligned_context_window);
335 for (
int h = 0; h < num_heads; ++h) {
336 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
337 for (
int i = 0; i < num_tokens; ++i) {
338 size_t out_base =
qkv_index(h, i, 0, num_tokens, aligned_head_dim);
339 for (
int d = 0; d < aligned_head_dim; ++d) {
340 output[out_base + d] = 0.0f;
343 for (
int j = 0; j <= i; ++j) {
344 float w = scores[
score_index(h, i, j, aligned_context_window)];
345 size_t v_base =
qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
347 for (
int d = 0; d < head_dim; ++d) {
348 output[out_base + d] += w * v[v_base + d];
375 int aligned_head_dim,
376 int aligned_context_window,
381 const size_t q_elems = (size_t)num_heads * (
size_t)num_tokens * (size_t)aligned_head_dim;
382 const size_t kv_elems = (size_t)num_kv_heads * (
size_t)num_tokens * (size_t)aligned_head_dim;
384 if (!scratch_q || !scratch_k || !scratch_v)
return;
394 num_heads, num_kv_heads,
395 num_tokens, head_dim,
396 aligned_head_dim, aligned_context_window);
414 #if defined(__AVX512F__)
415 static void attention_flash_query_causal_avx512(
const float *q_vec,
420 int aligned_head_dim,
430 for (; d + 16 <= aligned_head_dim; d += 16) {
431 _mm512_storeu_ps(&out_vec[d], _mm512_setzero_ps());
433 for (; d < aligned_head_dim; ++d) {
437 for (
int j = 0; j < kv_tokens; ++j) {
438 const float *k_vec = k_head + (size_t)j * (
size_t)aligned_head_dim;
439 const float *v_vec = v_head + (size_t)j * (
size_t)aligned_head_dim;
442 __m512 dot_acc = _mm512_setzero_ps();
444 for (; d + 16 <= head_dim; d += 16) {
445 __m512 q_v = _mm512_loadu_ps(&q_vec[d]);
446 __m512 k_v = _mm512_loadu_ps(&k_vec[d]);
447 dot_acc = _mm512_fmadd_ps(q_v, k_v, dot_acc);
449 float dot = _mm512_reduce_add_ps(dot_acc);
451 for (; d < head_dim; ++d) {
452 dot += q_vec[d] * k_vec[d];
454 float score = dot * scale;
457 float exp_m = (m == -INFINITY) ? 0.0f : expf(m -
score);
461 __m512 exp_m_vec = _mm512_set1_ps(exp_m);
463 for (; d + 16 <= head_dim; d += 16) {
464 __m512 out_v = _mm512_loadu_ps(&out_vec[d]);
465 __m512 v_v = _mm512_loadu_ps(&v_vec[d]);
466 out_v = _mm512_fmadd_ps(out_v, exp_m_vec, v_v);
467 _mm512_storeu_ps(&out_vec[d], out_v);
469 for (; d < head_dim; ++d) {
470 out_vec[d] = out_vec[d] * exp_m + v_vec[d];
476 float e = expf(
score - m);
480 __m512 e_vec = _mm512_set1_ps(e);
482 for (; d + 16 <= head_dim; d += 16) {
483 __m512 out_v = _mm512_loadu_ps(&out_vec[d]);
484 __m512 v_v = _mm512_loadu_ps(&v_vec[d]);
485 out_v = _mm512_fmadd_ps(e_vec, v_v, out_v);
486 _mm512_storeu_ps(&out_vec[d], out_v);
488 for (; d < head_dim; ++d) {
489 out_vec[d] += e * v_vec[d];
495 float inv_s = 1.0f / s;
496 __m512 inv_s_vec = _mm512_set1_ps(inv_s);
498 for (; d + 16 <= head_dim; d += 16) {
499 __m512 out_v = _mm512_loadu_ps(&out_vec[d]);
500 _mm512_storeu_ps(&out_vec[d], _mm512_mul_ps(out_v, inv_s_vec));
502 for (; d < head_dim; ++d) {
507 for (d = head_dim; d < aligned_head_dim; ++d) {
516 #if defined(__AVX2__)
/* Reduce the eight float lanes of an AVX register to their scalar sum. */
static inline float hsum256_ps_flash(__m256 v) {
    /* Fold the upper 128-bit half onto the lower half, then finish with
     * two horizontal adds so every lane holds the total. */
    __m128 upper = _mm256_extractf128_ps(v, 1);
    __m128 acc = _mm_add_ps(_mm256_castps256_ps128(v), upper);
    acc = _mm_hadd_ps(acc, acc);
    acc = _mm_hadd_ps(acc, acc);
    return _mm_cvtss_f32(acc);
}
526 static void attention_flash_query_causal_avx2(
const float *q_vec,
531 int aligned_head_dim,
540 for (; d + 8 <= aligned_head_dim; d += 8) {
541 _mm256_storeu_ps(&out_vec[d], _mm256_setzero_ps());
543 for (; d < aligned_head_dim; ++d) {
547 for (
int j = 0; j < kv_tokens; ++j) {
548 const float *k_vec = k_head + (size_t)j * (
size_t)aligned_head_dim;
549 const float *v_vec = v_head + (size_t)j * (
size_t)aligned_head_dim;
552 __m256 dot_acc = _mm256_setzero_ps();
554 for (; d + 8 <= head_dim; d += 8) {
555 __m256 q_v = _mm256_loadu_ps(&q_vec[d]);
556 __m256 k_v = _mm256_loadu_ps(&k_vec[d]);
557 dot_acc = _mm256_fmadd_ps(q_v, k_v, dot_acc);
559 float dot = hsum256_ps_flash(dot_acc);
560 for (; d < head_dim; ++d) {
561 dot += q_vec[d] * k_vec[d];
563 float score = dot * scale;
566 float exp_m = (m == -INFINITY) ? 0.0f : expf(m -
score);
569 __m256 exp_m_vec = _mm256_set1_ps(exp_m);
571 for (; d + 8 <= head_dim; d += 8) {
572 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
573 __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
574 out_v = _mm256_fmadd_ps(out_v, exp_m_vec, v_v);
575 _mm256_storeu_ps(&out_vec[d], out_v);
577 for (; d < head_dim; ++d) {
578 out_vec[d] = out_vec[d] * exp_m + v_vec[d];
584 float e = expf(
score - m);
587 __m256 e_vec = _mm256_set1_ps(e);
589 for (; d + 8 <= head_dim; d += 8) {
590 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
591 __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
592 out_v = _mm256_fmadd_ps(e_vec, v_v, out_v);
593 _mm256_storeu_ps(&out_vec[d], out_v);
595 for (; d < head_dim; ++d) {
596 out_vec[d] += e * v_vec[d];
602 float inv_s = 1.0f / s;
603 __m256 inv_s_vec = _mm256_set1_ps(inv_s);
605 for (; d + 8 <= head_dim; d += 8) {
606 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
607 _mm256_storeu_ps(&out_vec[d], _mm256_mul_ps(out_v, inv_s_vec));
609 for (; d < head_dim; ++d) {
613 for (d = head_dim; d < aligned_head_dim; ++d) {
622 #if defined(__AVX__) && !defined(__AVX2__)
/* AVX-only horizontal sum of a __m256 (used where AVX2 is unavailable). */
static inline float hsum256_ps_flash_avx(__m256 v) {
    /* Add high half into low half, then collapse four lanes to one via
     * paired horizontal adds. */
    __m128 high_half = _mm256_extractf128_ps(v, 1);
    __m128 total = _mm_add_ps(_mm256_castps256_ps128(v), high_half);
    total = _mm_hadd_ps(total, total);
    total = _mm_hadd_ps(total, total);
    return _mm_cvtss_f32(total);
}
632 static void attention_flash_query_causal_avx(
const float *q_vec,
637 int aligned_head_dim,
646 for (; d + 8 <= aligned_head_dim; d += 8) {
647 _mm256_storeu_ps(&out_vec[d], _mm256_setzero_ps());
649 for (; d < aligned_head_dim; ++d) {
653 for (
int j = 0; j < kv_tokens; ++j) {
654 const float *k_vec = k_head + (size_t)j * (
size_t)aligned_head_dim;
655 const float *v_vec = v_head + (size_t)j * (
size_t)aligned_head_dim;
658 __m256 dot_acc = _mm256_setzero_ps();
660 for (; d + 8 <= head_dim; d += 8) {
661 __m256 q_v = _mm256_loadu_ps(&q_vec[d]);
662 __m256 k_v = _mm256_loadu_ps(&k_vec[d]);
663 dot_acc = _mm256_add_ps(dot_acc, _mm256_mul_ps(q_v, k_v));
665 float dot = hsum256_ps_flash_avx(dot_acc);
666 for (; d < head_dim; ++d) {
667 dot += q_vec[d] * k_vec[d];
669 float score = dot * scale;
672 float exp_m = (m == -INFINITY) ? 0.0f : expf(m -
score);
675 __m256 exp_m_vec = _mm256_set1_ps(exp_m);
677 for (; d + 8 <= head_dim; d += 8) {
678 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
679 __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
681 out_v = _mm256_add_ps(_mm256_mul_ps(out_v, exp_m_vec), v_v);
682 _mm256_storeu_ps(&out_vec[d], out_v);
684 for (; d < head_dim; ++d) {
685 out_vec[d] = out_vec[d] * exp_m + v_vec[d];
691 float e = expf(
score - m);
694 __m256 e_vec = _mm256_set1_ps(e);
696 for (; d + 8 <= head_dim; d += 8) {
697 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
698 __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
700 out_v = _mm256_add_ps(out_v, _mm256_mul_ps(e_vec, v_v));
701 _mm256_storeu_ps(&out_vec[d], out_v);
703 for (; d < head_dim; ++d) {
704 out_vec[d] += e * v_vec[d];
710 float inv_s = 1.0f / s;
711 __m256 inv_s_vec = _mm256_set1_ps(inv_s);
713 for (; d + 8 <= head_dim; d += 8) {
714 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
715 _mm256_storeu_ps(&out_vec[d], _mm256_mul_ps(out_v, inv_s_vec));
717 for (; d < head_dim; ++d) {
721 for (d = head_dim; d < aligned_head_dim; ++d) {
735 int aligned_head_dim,
745 for (
int d = 0; d < head_dim; ++d) {
749 for (
int j = 0; j < kv_tokens; ++j) {
750 const float *k_vec = k_head + (size_t)j * (
size_t)aligned_head_dim;
751 const float *v_vec = v_head + (size_t)j * (
size_t)aligned_head_dim;
754 for (
int d = 0; d < head_dim; ++d) {
755 dot += q_vec[d] * k_vec[d];
757 float score = dot * scale;
760 float exp_m = (m == -INFINITY) ? 0.0f : expf(m -
score);
762 for (
int d = 0; d < head_dim; ++d) {
766 for (
int d = 0; d < head_dim; ++d) {
767 out_vec[d] += v_vec[d];
771 float e = expf(
score - m);
773 for (
int d = 0; d < head_dim; ++d) {
774 out_vec[d] += e * v_vec[d];
779 float inv_s = 1.0f / s;
780 for (
int d = 0; d < head_dim; ++d) {
783 for (
int d = head_dim; d < aligned_head_dim; ++d) {
808 int aligned_head_dim)
810 if (!q || !k || !v || !output) {
813 if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
817 const float scale = 1.0f / sqrtf((
float)head_dim);
818 const int T = num_tokens;
821 #if defined(__AVX512F__)
822 #define FLASH_QUERY_IMPL attention_flash_query_causal_avx512
823 #elif defined(__AVX2__)
824 #define FLASH_QUERY_IMPL attention_flash_query_causal_avx2
825 #elif defined(__AVX__)
826 #define FLASH_QUERY_IMPL attention_flash_query_causal_avx
828 #define FLASH_QUERY_IMPL attention_flash_query_causal
831 for (
int h = 0; h < num_heads; ++h) {
832 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
833 const float *k_head = k + (size_t)kv_head * (
size_t)T * (size_t)aligned_head_dim;
834 const float *v_head = v + (size_t)kv_head * (
size_t)T * (size_t)aligned_head_dim;
836 for (
int i = 0; i < T; ++i) {
837 const float *q_vec = q +
qkv_index(h, i, 0, T, aligned_head_dim);
838 float *out_vec = output +
qkv_index(h, i, 0, T, aligned_head_dim);
841 head_dim, aligned_head_dim,
846 #undef FLASH_QUERY_IMPL
867 int aligned_head_dim,
868 int kv_stride_tokens)
870 if (!q || !k || !v || !output) {
873 if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
876 if (kv_stride_tokens < num_tokens) {
880 const float scale = 1.0f / sqrtf((
float)head_dim);
881 const int T = num_tokens;
882 const size_t kv_head_stride = (size_t)kv_stride_tokens * (
size_t)aligned_head_dim;
885 #if defined(__AVX512F__)
886 #define FLASH_QUERY_IMPL attention_flash_query_causal_avx512
887 #elif defined(__AVX2__)
888 #define FLASH_QUERY_IMPL attention_flash_query_causal_avx2
889 #elif defined(__AVX__)
890 #define FLASH_QUERY_IMPL attention_flash_query_causal_avx
892 #define FLASH_QUERY_IMPL attention_flash_query_causal
895 for (
int h = 0; h < num_heads; ++h) {
896 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
897 const float *k_head = k + (size_t)kv_head * kv_head_stride;
898 const float *v_head = v + (size_t)kv_head * kv_head_stride;
900 for (
int i = 0; i < T; ++i) {
901 const float *q_vec = q +
qkv_index(h, i, 0, T, aligned_head_dim);
902 float *out_vec = output +
qkv_index(h, i, 0, T, aligned_head_dim);
905 head_dim, aligned_head_dim,
910 #undef FLASH_QUERY_IMPL
928 #if defined(__AVX512F__)
929 static void attention_flash_query_sliding_avx512(
const float *q_vec,
935 int aligned_head_dim,
944 int window_start = 0;
945 if (sliding_window > 0) {
946 window_start = query_pos - sliding_window + 1;
947 if (window_start < 0) window_start = 0;
952 for (; d + 16 <= aligned_head_dim; d += 16) {
953 _mm512_storeu_ps(&out_vec[d], _mm512_setzero_ps());
955 for (; d < aligned_head_dim; ++d) {
960 int effective_kv_end = query_pos < kv_tokens ? query_pos : kv_tokens - 1;
961 for (
int j = window_start; j <= effective_kv_end; ++j) {
962 const float *k_vec = k_head + (size_t)j * (
size_t)aligned_head_dim;
963 const float *v_vec = v_head + (size_t)j * (
size_t)aligned_head_dim;
966 __m512 dot_acc = _mm512_setzero_ps();
968 for (; d + 16 <= head_dim; d += 16) {
969 __m512 q_v = _mm512_loadu_ps(&q_vec[d]);
970 __m512 k_v = _mm512_loadu_ps(&k_vec[d]);
971 dot_acc = _mm512_fmadd_ps(q_v, k_v, dot_acc);
973 float dot = _mm512_reduce_add_ps(dot_acc);
974 for (; d < head_dim; ++d) {
975 dot += q_vec[d] * k_vec[d];
977 float score = dot * scale;
980 float exp_m = (m == -INFINITY) ? 0.0f : expf(m -
score);
983 __m512 exp_m_vec = _mm512_set1_ps(exp_m);
985 for (; d + 16 <= head_dim; d += 16) {
986 __m512 out_v = _mm512_loadu_ps(&out_vec[d]);
987 __m512 v_v = _mm512_loadu_ps(&v_vec[d]);
988 out_v = _mm512_fmadd_ps(out_v, exp_m_vec, v_v);
989 _mm512_storeu_ps(&out_vec[d], out_v);
991 for (; d < head_dim; ++d) {
992 out_vec[d] = out_vec[d] * exp_m + v_vec[d];
998 float e = expf(
score - m);
1001 __m512 e_vec = _mm512_set1_ps(e);
1003 for (; d + 16 <= head_dim; d += 16) {
1004 __m512 out_v = _mm512_loadu_ps(&out_vec[d]);
1005 __m512 v_v = _mm512_loadu_ps(&v_vec[d]);
1006 out_v = _mm512_fmadd_ps(e_vec, v_v, out_v);
1007 _mm512_storeu_ps(&out_vec[d], out_v);
1009 for (; d < head_dim; ++d) {
1010 out_vec[d] += e * v_vec[d];
1016 float inv_s = 1.0f / s;
1017 __m512 inv_s_vec = _mm512_set1_ps(inv_s);
1019 for (; d + 16 <= head_dim; d += 16) {
1020 __m512 out_v = _mm512_loadu_ps(&out_vec[d]);
1021 _mm512_storeu_ps(&out_vec[d], _mm512_mul_ps(out_v, inv_s_vec));
1023 for (; d < head_dim; ++d) {
1024 out_vec[d] *= inv_s;
1028 for (d = head_dim; d < aligned_head_dim; ++d) {
1037 #if defined(__AVX2__)
1038 static void attention_flash_query_sliding_avx2(
const float *q_vec,
1039 const float *k_head,
1040 const float *v_head,
1044 int aligned_head_dim,
1049 float m = -INFINITY;
1052 int window_start = 0;
1053 if (sliding_window > 0) {
1054 window_start = query_pos - sliding_window + 1;
1055 if (window_start < 0) window_start = 0;
1059 for (; d + 8 <= aligned_head_dim; d += 8) {
1060 _mm256_storeu_ps(&out_vec[d], _mm256_setzero_ps());
1062 for (; d < aligned_head_dim; ++d) {
1066 int effective_kv_end = query_pos < kv_tokens ? query_pos : kv_tokens - 1;
1067 for (
int j = window_start; j <= effective_kv_end; ++j) {
1068 const float *k_vec = k_head + (size_t)j * (
size_t)aligned_head_dim;
1069 const float *v_vec = v_head + (size_t)j * (
size_t)aligned_head_dim;
1071 __m256 dot_acc = _mm256_setzero_ps();
1073 for (; d + 8 <= head_dim; d += 8) {
1074 __m256 q_v = _mm256_loadu_ps(&q_vec[d]);
1075 __m256 k_v = _mm256_loadu_ps(&k_vec[d]);
1076 dot_acc = _mm256_fmadd_ps(q_v, k_v, dot_acc);
1078 float dot = hsum256_ps_flash(dot_acc);
1079 for (; d < head_dim; ++d) {
1080 dot += q_vec[d] * k_vec[d];
1082 float score = dot * scale;
1085 float exp_m = (m == -INFINITY) ? 0.0f : expf(m -
score);
1088 __m256 exp_m_vec = _mm256_set1_ps(exp_m);
1090 for (; d + 8 <= head_dim; d += 8) {
1091 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
1092 __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
1093 out_v = _mm256_fmadd_ps(out_v, exp_m_vec, v_v);
1094 _mm256_storeu_ps(&out_vec[d], out_v);
1096 for (; d < head_dim; ++d) {
1097 out_vec[d] = out_vec[d] * exp_m + v_vec[d];
1103 float e = expf(
score - m);
1106 __m256 e_vec = _mm256_set1_ps(e);
1108 for (; d + 8 <= head_dim; d += 8) {
1109 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
1110 __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
1111 out_v = _mm256_fmadd_ps(e_vec, v_v, out_v);
1112 _mm256_storeu_ps(&out_vec[d], out_v);
1114 for (; d < head_dim; ++d) {
1115 out_vec[d] += e * v_vec[d];
1120 float inv_s = 1.0f / s;
1121 __m256 inv_s_vec = _mm256_set1_ps(inv_s);
1123 for (; d + 8 <= head_dim; d += 8) {
1124 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
1125 _mm256_storeu_ps(&out_vec[d], _mm256_mul_ps(out_v, inv_s_vec));
1127 for (; d < head_dim; ++d) {
1128 out_vec[d] *= inv_s;
1131 for (d = head_dim; d < aligned_head_dim; ++d) {
1140 #if defined(__AVX__) && !defined(__AVX2__)
1141 static void attention_flash_query_sliding_avx(
const float *q_vec,
1142 const float *k_head,
1143 const float *v_head,
1147 int aligned_head_dim,
1152 float m = -INFINITY;
1155 int window_start = 0;
1156 if (sliding_window > 0) {
1157 window_start = query_pos - sliding_window + 1;
1158 if (window_start < 0) window_start = 0;
1162 for (; d + 8 <= aligned_head_dim; d += 8) {
1163 _mm256_storeu_ps(&out_vec[d], _mm256_setzero_ps());
1165 for (; d < aligned_head_dim; ++d) {
1169 int effective_kv_end = query_pos < kv_tokens ? query_pos : kv_tokens - 1;
1170 for (
int j = window_start; j <= effective_kv_end; ++j) {
1171 const float *k_vec = k_head + (size_t)j * (
size_t)aligned_head_dim;
1172 const float *v_vec = v_head + (size_t)j * (
size_t)aligned_head_dim;
1174 __m256 dot_acc = _mm256_setzero_ps();
1176 for (; d + 8 <= head_dim; d += 8) {
1177 __m256 q_v = _mm256_loadu_ps(&q_vec[d]);
1178 __m256 k_v = _mm256_loadu_ps(&k_vec[d]);
1179 dot_acc = _mm256_add_ps(dot_acc, _mm256_mul_ps(q_v, k_v));
1181 float dot = hsum256_ps_flash_avx(dot_acc);
1182 for (; d < head_dim; ++d) {
1183 dot += q_vec[d] * k_vec[d];
1185 float score = dot * scale;
1188 float exp_m = (m == -INFINITY) ? 0.0f : expf(m -
score);
1191 __m256 exp_m_vec = _mm256_set1_ps(exp_m);
1193 for (; d + 8 <= head_dim; d += 8) {
1194 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
1195 __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
1196 out_v = _mm256_add_ps(_mm256_mul_ps(out_v, exp_m_vec), v_v);
1197 _mm256_storeu_ps(&out_vec[d], out_v);
1199 for (; d < head_dim; ++d) {
1200 out_vec[d] = out_vec[d] * exp_m + v_vec[d];
1206 float e = expf(
score - m);
1209 __m256 e_vec = _mm256_set1_ps(e);
1211 for (; d + 8 <= head_dim; d += 8) {
1212 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
1213 __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
1214 out_v = _mm256_add_ps(out_v, _mm256_mul_ps(e_vec, v_v));
1215 _mm256_storeu_ps(&out_vec[d], out_v);
1217 for (; d < head_dim; ++d) {
1218 out_vec[d] += e * v_vec[d];
1223 float inv_s = 1.0f / s;
1224 __m256 inv_s_vec = _mm256_set1_ps(inv_s);
1226 for (; d + 8 <= head_dim; d += 8) {
1227 __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
1228 _mm256_storeu_ps(&out_vec[d], _mm256_mul_ps(out_v, inv_s_vec));
1230 for (; d < head_dim; ++d) {
1231 out_vec[d] *= inv_s;
1234 for (d = head_dim; d < aligned_head_dim; ++d) {
1244 const float *k_head,
1245 const float *v_head,
1249 int aligned_head_dim,
1254 float m = -INFINITY;
1257 int window_start = 0;
1258 if (sliding_window > 0) {
1259 window_start = query_pos - sliding_window + 1;
1260 if (window_start < 0) window_start = 0;
1263 for (
int d = 0; d < head_dim; ++d) {
1267 int effective_kv_end = query_pos < kv_tokens ? query_pos : kv_tokens - 1;
1268 for (
int j = window_start; j <= effective_kv_end; ++j) {
1269 const float *k_vec = k_head + (size_t)j * (
size_t)aligned_head_dim;
1270 const float *v_vec = v_head + (size_t)j * (
size_t)aligned_head_dim;
1273 for (
int d = 0; d < head_dim; ++d) {
1274 dot += q_vec[d] * k_vec[d];
1276 float score = dot * scale;
1279 float exp_m = (m == -INFINITY) ? 0.0f : expf(m -
score);
1281 for (
int d = 0; d < head_dim; ++d) {
1282 out_vec[d] *= exp_m;
1285 for (
int d = 0; d < head_dim; ++d) {
1286 out_vec[d] += v_vec[d];
1290 float e = expf(
score - m);
1292 for (
int d = 0; d < head_dim; ++d) {
1293 out_vec[d] += e * v_vec[d];
1298 float inv_s = 1.0f / s;
1299 for (
int d = 0; d < head_dim; ++d) {
1300 out_vec[d] *= inv_s;
1302 for (
int d = head_dim; d < aligned_head_dim; ++d) {
1325 int aligned_head_dim,
1326 int kv_stride_tokens,
1329 if (!q || !k || !v || !output) {
1332 if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
1335 if (kv_stride_tokens < num_tokens) {
1339 const float scale = 1.0f / sqrtf((
float)head_dim);
1340 const int T = num_tokens;
1341 const size_t kv_head_stride = (size_t)kv_stride_tokens * (
size_t)aligned_head_dim;
1343 #if defined(__AVX512F__)
1344 #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx512
1345 #elif defined(__AVX2__)
1346 #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx2
1347 #elif defined(__AVX__)
1348 #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx
1350 #define SLIDING_FLASH_IMPL attention_flash_query_sliding
1353 for (
int h = 0; h < num_heads; ++h) {
1354 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
1355 const float *k_head = k + (size_t)kv_head * kv_head_stride;
1356 const float *v_head = v + (size_t)kv_head * kv_head_stride;
1358 for (
int i = 0; i < T; ++i) {
1359 const float *q_vec = q +
qkv_index(h, i, 0, T, aligned_head_dim);
1360 float *out_vec = output +
qkv_index(h, i, 0, T, aligned_head_dim);
1364 head_dim, aligned_head_dim,
1370 #undef SLIDING_FLASH_IMPL
1383 const float *q_token,
1384 const float *k_cache,
1385 const float *v_cache,
1392 int aligned_head_dim,
1395 if (!q_token || !k_cache || !v_cache || !out_token) {
1398 if (num_heads <= 0 || num_kv_heads <= 0 || cache_capacity <= 0) {
1401 if (kv_tokens <= 0 || kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
1405 const float scale = 1.0f / sqrtf((
float)head_dim);
1406 const size_t head_stride = (size_t)cache_capacity * (
size_t)aligned_head_dim;
1409 int effective_kv_tokens = kv_tokens;
1410 if (sliding_window > 0 && sliding_window < kv_tokens) {
1411 effective_kv_tokens = sliding_window;
1415 if (effective_kv_tokens <= 0) {
1420 int kv_start_offset = kv_tokens - effective_kv_tokens;
1422 #if defined(__AVX512F__)
1423 #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx512
1424 #elif defined(__AVX2__)
1425 #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx2
1426 #elif defined(__AVX__)
1427 #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx
1429 #define SLIDING_DECODE_IMPL attention_flash_query_sliding
1432 for (
int h = 0; h < num_heads; ++h) {
1433 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
1434 const float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
1436 const float *k_head = k_cache + (size_t)kv_head * head_stride
1437 + (
size_t)kv_start_offset * (size_t)aligned_head_dim;
1438 const float *v_head = v_cache + (size_t)kv_head * head_stride
1439 + (
size_t)kv_start_offset * (size_t)aligned_head_dim;
1440 float *out_head = out_token + (size_t)h * (
size_t)aligned_head_dim;
1445 effective_kv_tokens - 1,
1446 effective_kv_tokens,
1447 head_dim, aligned_head_dim,
1452 #undef SLIDING_DECODE_IMPL
1468 const float *k_cache,
1469 const float *v_cache,
1476 int aligned_head_dim)
1478 if (!q_token || !k_cache || !v_cache || !out_token) {
1481 if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
1484 if (kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
1488 const float scale = 1.0f / sqrtf((
float)head_dim);
1489 const size_t head_stride = (size_t)cache_capacity * (
size_t)aligned_head_dim;
1491 for (
int h = 0; h < num_heads; ++h) {
1492 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
1493 const float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
1494 const float *k_head = k_cache + (size_t)kv_head * head_stride;
1495 const float *v_head = v_cache + (size_t)kv_head * head_stride;
1496 float *out_head = out_token + (size_t)h * (
size_t)aligned_head_dim;
1525 const float *k_cache,
1526 const float *v_cache,
1533 int aligned_head_dim)
1535 if (!q_token || !k_cache || !v_cache || !out_token) {
1538 if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
1541 if (kv_tokens > cache_capacity) {
1545 const float scale = 1.0f / sqrtf((
float)head_dim);
1546 const size_t head_stride = (size_t)cache_capacity * (
size_t)aligned_head_dim;
1549 #if defined(__AVX512F__)
1550 #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx512
1551 #elif defined(__AVX2__)
1552 #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx2
1553 #elif defined(__AVX__)
1554 #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx
1556 #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal
1559 #pragma omp parallel for schedule(static) if(num_heads > 1)
1560 for (
int h = 0; h < num_heads; ++h) {
1561 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
1562 const float *q_vec = q_token + (size_t)h * (
size_t)aligned_head_dim;
1563 const float *k_head = k_cache + (size_t)kv_head * head_stride;
1564 const float *v_head = v_cache + (size_t)kv_head * head_stride;
1565 float *out_vec = out_token + (size_t)h * (
size_t)aligned_head_dim;
1568 kv_tokens, head_dim, aligned_head_dim,
1572 #undef FLASH_QUERY_IMPL_DECODE
1620 const uint16_t *d_output,
1625 const float *attn_weights,
1634 int aligned_head_dim,
1635 int aligned_context_window,
1636 float *scratch_d_output,
1642 const size_t head_elems = (size_t)num_heads * (
size_t)num_tokens * (size_t)aligned_head_dim;
1643 const size_t kv_elems = (size_t)num_kv_heads * (
size_t)num_tokens * (size_t)aligned_head_dim;
1645 if (!scratch_d_output || !scratch_q || !scratch_k || !scratch_v)
return;
1654 d_q, d_k, d_v, d_scores,
1655 num_heads, num_kv_heads,
1656 num_tokens, head_dim,
1657 aligned_head_dim, aligned_context_window);
1673 const float *d_output,
1677 const float *attn_weights,
1686 int aligned_head_dim,
1687 int aligned_context_window)
1689 const float scale = 1.0f / sqrtf((
float)head_dim);
1692 int H_kv = num_kv_heads;
1694 int ad = aligned_head_dim;
1695 int aw = aligned_context_window;
1697 const size_t d_q_elems = (size_t)H * (
size_t)T * (size_t)ad;
1698 const size_t kv_elems = (size_t)H_kv * (
size_t)T * (size_t)ad;
1700 for (
size_t idx = 0; idx < d_q_elems; ++idx) {
1703 for (
size_t idx = 0; idx < kv_elems; ++idx) {
1709 for (
int h = 0; h < H; ++h) {
1711 int kv_h = (int)((
long long)h * (
long long)H_kv / (
long long)H);
1719 for (
int i = 0; i < T; ++i) {
1720 size_t d_out_base =
qkv_index(h, i, 0, T, ad);
1722 for (
int j = 0; j <= i; ++j) {
1723 size_t v_base =
qkv_index(kv_h, j, 0, T, ad);
1725 float w = attn_weights[w_idx];
1729 for (
int dd = 0; dd < hd; ++dd) {
1730 dot += d_output[d_out_base + dd] * v[v_base + dd];
1732 d_scores[w_idx] = dot;
1735 for (
int dd = 0; dd < hd; ++dd) {
1736 d_v[v_base + dd] += w * d_output[d_out_base + dd];
1741 for (
int j = i + 1; j < T; ++j) {
1745 for (
int j = T; j < aw; ++j) {
1756 for (
int i = 0; i < T; ++i) {
1757 int base = h * aw * aw + i * aw;
1760 float dot_product = 0.0f;
1761 for (
int j = 0; j <= i; ++j) {
1762 float wt = attn_weights[base + j];
1763 float dw = d_scores[base + j];
1764 dot_product += wt * dw;
1768 for (
int j = 0; j <= i; ++j) {
1769 float wt = attn_weights[base + j];
1770 float dw = d_scores[base + j];
1771 d_scores[base + j] = wt * (dw - dot_product);
1780 for (
int i = 0; i < T; ++i) {
1781 size_t d_q_base =
qkv_index(h, i, 0, T, ad);
1782 size_t q_base =
qkv_index(h, i, 0, T, ad);
1786 for (
int j = 0; j <= i; ++j) {
1787 size_t k_base =
qkv_index(kv_h, j, 0, T, ad);
1788 size_t d_k_base =
qkv_index(kv_h, j, 0, T, ad);
1789 float ds = d_scores[
score_index(h, i, j, aw)] * scale;
1791 for (
int dd = 0; dd < hd; ++dd) {
1792 d_q[d_q_base + dd] += ds * k[k_base + dd];
1793 d_k[d_k_base + dd] += ds * q[q_base + dd];
1812 const float *d_output,
1816 const float *attn_weights,
1824 int aligned_head_dim,
1825 int aligned_context_window)
1828 d_output, q, k, v, attn_weights,
1829 d_q, d_k, d_v, d_scores,
1830 num_heads, num_heads,
1831 num_tokens, head_dim, aligned_head_dim, aligned_context_window);
static size_t qkv_index(int h, int t, int d, int num_tokens, int aligned_head_dim)
void attention_forward_causal_head_major_gqa_bf16(const uint16_t *q, const uint16_t *k, const uint16_t *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_q, float *scratch_k, float *scratch_v)
void attention_forward_causal_head_major_gqa_exact(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void attention_backward_causal_head_major_gqa_bf16(const uint16_t *d_output, float *d_x, const uint16_t *q, const uint16_t *k, const uint16_t *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_d_output, float *scratch_q, float *scratch_k, float *scratch_v)
void attention_forward_causal_head_major(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void attention_forward_causal_head_major_gqa(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void attention_forward_causal_head_major_gqa_flash_strided_sliding(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)
void attention_forward_decode_head_major_gqa_flash_sliding(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)
void attention_forward_causal_head_major_exact(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
static size_t score_index(int h, int i, int j, int aligned_context_window)
#define SLIDING_DECODE_IMPL
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
static void attention_flash_query_sliding(const float *q_vec, const float *k_head, const float *v_head, int query_pos, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float *out_vec, int sliding_window)
static void attention_flash_query_causal(const float *q_vec, const float *k_head, const float *v_head, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float *out_vec)
void attention_forward_causal_head_major_gqa_flash(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
#define SLIDING_FLASH_IMPL
#define FLASH_QUERY_IMPL_DECODE
void attention_backward_causal_head_major(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void attention_backward_causal_head_major_gqa(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
static void convert_bf16_tensor_to_buf(const uint16_t *src, float *dst, size_t count)
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void causal_softmax_head_major_exact(float *scores, int num_heads, int num_tokens, int aligned_context_window)
void causal_softmax_head_major(float *scores, int num_heads, int num_tokens, int aligned_context_window)