← Back to C-Kernel-Engine Docs Doxygen Source Documentation
attention_kernels.c
Go to the documentation of this file.
1 /**
2  * @file attention_kernels.c
3  * @brief Attention score/softmax/output kernels with SIMD (SSE/AVX/AVX512)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Attention: softmax(Q @ K^T / sqrt(d)) @ V
15  * Supports GQA (grouped-query attention) with head broadcasting.
16  */
17 
18 #include "bf16_utils.h"
19 #include "ckernel_engine.h"
20 #include <math.h>
21 #include <stdlib.h>
22 
23 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
24 #include <immintrin.h>
25 #endif
26 
27 /* Convert BF16 tensor to FP32 using caller-provided buffer (no malloc!) */
/* Widen a BF16 tensor into a caller-owned FP32 buffer.
 * No allocation happens here — `dst` is provided by the caller, per the
 * kernel rules at the top of this file. A NULL src or dst is a no-op. */
static void convert_bf16_tensor_to_buf(const uint16_t *src, float *dst, size_t count)
{
    if (src == NULL || dst == NULL) {
        return;
    }
    bf16_tensor_to_float(src, dst, count);
}
33 
34 // Helpers for head-major layouts used in attention.
35 // Q/K/V layout: [head][token][head_dim] with stride aligned_head_dim.
// Flattened offset into a head-major Q/K/V tensor.
// Layout: [head][token][head_dim], with per-token stride aligned_head_dim.
static inline size_t qkv_index(int h,
                               int t,
                               int d,
                               int num_tokens,
                               int aligned_head_dim)
{
    size_t row = (size_t)h * (size_t)num_tokens + (size_t)t;
    return row * (size_t)aligned_head_dim + (size_t)d;
}
45 
46 // Scores layout matches causal_softmax_head_major:
47 // [head][query_token][key_token] with stride aligned_context_window.
// Flattened offset into the score/weight matrix.
// Layout matches causal_softmax_head_major: [head][query_token][key_token],
// row stride aligned_context_window, head stride aligned_context_window^2.
static inline size_t score_index(int h,
                                 int i,
                                 int j,
                                 int aligned_context_window)
{
    size_t w = (size_t)aligned_context_window;
    return ((size_t)h * w + (size_t)i) * w + (size_t)j;
}
57 
58 /**
59  * Causal attention forward (score-matrix version)
60  * @test test_attention.py::TestAttentionForward::test_causal_forward
61  * @test test_attention.py::TestAttentionForward::test_gqa_broadcast
62  * @test test_attention.py::TestAttentionForward::test_exact_vs_fast
63  * @test test_parity.py::test_attention_parity
64  *
65  * Computes softmax(Q @ K^T / sqrt(d)) @ V with causal masking.
66  * Uses O(N^2) memory for scores matrix.
67  *
68  * After changes: make test && make llamacpp-parity-full
69  */
void attention_forward_causal_head_major(const float *q,
                                         const float *k,
                                         const float *v,
                                         float *scores,
                                         float *output,
                                         int num_heads,
                                         int num_tokens,
                                         int head_dim,
                                         int aligned_head_dim,
                                         int aligned_context_window)
{
    /* NOTE(review): the opening signature line was lost in extraction and has
     * been reconstructed from the surrounding docs — confirm the exported
     * name and parameter order against ckernel_engine.h. */
    if (!q || !k || !v || !scores || !output) return;

    const float scale = 1.0f / sqrtf((float)head_dim);

    // Phase 1: scaled dot-product scores Q·K^T / sqrt(d_k),
    // lower triangle only (j <= i).
    for (int h = 0; h < num_heads; ++h) {
        for (int i = 0; i < num_tokens; ++i) {
            for (int j = 0; j <= i; ++j) {
                float dot = 0.0f;
                size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
                size_t base_k = qkv_index(h, j, 0, num_tokens, aligned_head_dim);

                for (int d = 0; d < head_dim; ++d) {
                    dot += q[base_q + d] * k[base_k + d];
                }

                scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
            }

            // Zero the upper triangle so no stale values survive into the
            // softmax kernel.
            for (int j = i + 1; j < num_tokens; ++j) {
                scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
            }
        }
    }

    // Phase 2: causal row-wise softmax in-place over j <= i.
    // NOTE(review): call line lost in extraction; kernel name taken from the
    // layout comment above score_index — confirm against its declaration.
    causal_softmax_head_major(scores,
                              num_heads,
                              num_tokens,
                              aligned_context_window);

    // Phase 3: attention weights · V.
    for (int h = 0; h < num_heads; ++h) {
        for (int i = 0; i < num_tokens; ++i) {
            size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);

            // Zero the full aligned head slice so padded dims stay clean.
            for (int d = 0; d < aligned_head_dim; ++d) {
                output[out_base + d] = 0.0f;
            }

            // Weighted sum over causal positions.
            for (int j = 0; j <= i; ++j) {
                float w = scores[score_index(h, i, j, aligned_context_window)];
                size_t v_base = qkv_index(h, j, 0, num_tokens, aligned_head_dim);

                for (int d = 0; d < head_dim; ++d) {
                    output[out_base + d] += w * v[v_base + d];
                }
            }
        }
    }
}
135 
136 /**
137  * Causal attention forward (exact version using stdlib expf)
138  * @test test_attention.py::TestAttentionForward::test_exact_single
139  * @test test_attention.py::TestAttentionForward::test_exact_vs_fast
140  *
141  * Uses standard library expf for numerical accuracy reference.
142  * Slower but provides maximum accuracy.
143  *
144  * After changes: make test
145  */
void attention_forward_causal_head_major_exact(const float *q,
                                               const float *k,
                                               const float *v,
                                               float *scores,
                                               float *output,
                                               int num_heads,
                                               int num_tokens,
                                               int head_dim,
                                               int aligned_head_dim,
                                               int aligned_context_window)
{
    /* NOTE(review): the opening signature line was lost in extraction and has
     * been reconstructed from the surrounding docs — confirm the exported
     * name and parameter order against ckernel_engine.h. */
    if (!q || !k || !v || !scores || !output) return;

    const float scale = 1.0f / sqrtf((float)head_dim);

    // Phase 1: scaled dot-product scores Q·K^T / sqrt(d_k),
    // lower triangle only (j <= i).
    for (int h = 0; h < num_heads; ++h) {
        for (int i = 0; i < num_tokens; ++i) {
            for (int j = 0; j <= i; ++j) {
                float dot = 0.0f;
                size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
                size_t base_k = qkv_index(h, j, 0, num_tokens, aligned_head_dim);

                for (int d = 0; d < head_dim; ++d) {
                    dot += q[base_q + d] * k[base_k + d];
                }

                scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
            }

            // Zero the upper triangle so no stale values survive into the
            // softmax kernel.
            for (int j = i + 1; j < num_tokens; ++j) {
                scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
            }
        }
    }

    // Phase 2: causal row-wise softmax using exact stdlib expf.
    // NOTE(review): call line lost in extraction; `_exact` kernel name
    // inferred from the non-exact variant — confirm against its declaration.
    causal_softmax_head_major_exact(scores,
                                    num_heads,
                                    num_tokens,
                                    aligned_context_window);

    // Phase 3: attention weights · V.
    for (int h = 0; h < num_heads; ++h) {
        for (int i = 0; i < num_tokens; ++i) {
            size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);

            // Zero the full aligned head slice so padded dims stay clean.
            for (int d = 0; d < aligned_head_dim; ++d) {
                output[out_base + d] = 0.0f;
            }

            // Weighted sum over causal positions.
            for (int j = 0; j <= i; ++j) {
                float w = scores[score_index(h, i, j, aligned_context_window)];
                size_t v_base = qkv_index(h, j, 0, num_tokens, aligned_head_dim);

                for (int d = 0; d < head_dim; ++d) {
                    output[out_base + d] += w * v[v_base + d];
                }
            }
        }
    }
}
211 
212 /**
213  * GQA causal attention forward (score-matrix version)
214  * @test test_attention.py::TestAttentionForward::test_gqa_forward
215  * @test test_attention.py::TestAttentionForward::test_gqa_broadcast
216  * @test test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward
217  * @test test_parity.py::test_attention_gqa_parity
218  *
219  * Grouped-query attention: Q has num_heads, K/V have num_kv_heads.
220  * Each query head maps to a KV head via ratio.
221  *
222  * After changes: make test && make llamacpp-parity-full
223  */
void attention_forward_causal_head_major_gqa(const float *q,
                                             const float *k,
                                             const float *v,
                                             float *scores,
                                             float *output,
                                             int num_heads,
                                             int num_kv_heads,
                                             int num_tokens,
                                             int head_dim,
                                             int aligned_head_dim,
                                             int aligned_context_window)
{
    /* NOTE(review): the opening signature line was lost in extraction and has
     * been reconstructed from the surrounding docs — confirm the exported
     * name and parameter order against ckernel_engine.h. */
    if (!q || !k || !v || !scores || !output) return;

    const float scale = 1.0f / sqrtf((float)head_dim);

    // Phase 1: Q·K^T / sqrt(d_k), lower triangle only, with each query head
    // reading from its mapped KV head (GQA broadcast).
    for (int h = 0; h < num_heads; ++h) {
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        for (int i = 0; i < num_tokens; ++i) {
            for (int j = 0; j <= i; ++j) {
                float dot = 0.0f;
                size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
                size_t base_k = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);

                for (int d = 0; d < head_dim; ++d) {
                    dot += q[base_q + d] * k[base_k + d];
                }

                scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
            }

            // Zero the upper triangle so no stale values survive into softmax.
            for (int j = i + 1; j < num_tokens; ++j) {
                scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
            }
        }
    }

    // Phase 2: causal row-wise softmax (per query head).
    // NOTE(review): call line lost in extraction; kernel name taken from the
    // layout comment above score_index — confirm against its declaration.
    causal_softmax_head_major(scores,
                              num_heads,
                              num_tokens,
                              aligned_context_window);

    // Phase 3: attention weights · V, again broadcasting V from the KV head.
    for (int h = 0; h < num_heads; ++h) {
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        for (int i = 0; i < num_tokens; ++i) {
            size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
            // Zero the full aligned head slice so padded dims stay clean.
            for (int d = 0; d < aligned_head_dim; ++d) {
                output[out_base + d] = 0.0f;
            }

            for (int j = 0; j <= i; ++j) {
                float w = scores[score_index(h, i, j, aligned_context_window)];
                size_t v_base = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);

                for (int d = 0; d < head_dim; ++d) {
                    output[out_base + d] += w * v[v_base + d];
                }
            }
        }
    }
}
283 
284 /**
285  * GQA causal attention forward (exact version using stdlib expf)
286  * @test test_attention.py::TestAttentionForward::test_gqa_exact
287  * @test bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa
288  *
289  * Uses standard library expf for numerical accuracy reference.
290  * Used by BF16 wrapper to avoid approximation error accumulation.
291  *
292  * After changes: make test
293  */
void attention_forward_causal_head_major_gqa_exact(const float *q,
                                                   const float *k,
                                                   const float *v,
                                                   float *scores,
                                                   float *output,
                                                   int num_heads,
                                                   int num_kv_heads,
                                                   int num_tokens,
                                                   int head_dim,
                                                   int aligned_head_dim,
                                                   int aligned_context_window)
{
    /* NOTE(review): the opening signature line was lost in extraction; the
     * name is recovered from the call site in the BF16 wrapper below. */
    if (!q || !k || !v || !scores || !output) return;

    const float scale = 1.0f / sqrtf((float)head_dim);

    // Phase 1: Q·K^T / sqrt(d_k), lower triangle only, with GQA head mapping.
    for (int h = 0; h < num_heads; ++h) {
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        for (int i = 0; i < num_tokens; ++i) {
            for (int j = 0; j <= i; ++j) {
                float dot = 0.0f;
                size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
                size_t base_k = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);

                for (int d = 0; d < head_dim; ++d) {
                    dot += q[base_q + d] * k[base_k + d];
                }

                scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
            }

            // Zero the upper triangle so no stale values survive into softmax.
            for (int j = i + 1; j < num_tokens; ++j) {
                scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
            }
        }
    }

    // Phase 2: exact softmax with standard library expf.
    // NOTE(review): call line lost in extraction; `_exact` kernel name
    // inferred from the non-exact variant — confirm against its declaration.
    causal_softmax_head_major_exact(scores,
                                    num_heads,
                                    num_tokens,
                                    aligned_context_window);

    // Phase 3: attention weights · V with GQA head mapping for V.
    for (int h = 0; h < num_heads; ++h) {
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        for (int i = 0; i < num_tokens; ++i) {
            size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
            // Zero the full aligned head slice so padded dims stay clean.
            for (int d = 0; d < aligned_head_dim; ++d) {
                output[out_base + d] = 0.0f;
            }

            for (int j = 0; j <= i; ++j) {
                float w = scores[score_index(h, i, j, aligned_context_window)];
                size_t v_base = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);

                for (int d = 0; d < head_dim; ++d) {
                    output[out_base + d] += w * v[v_base + d];
                }
            }
        }
    }
}
354 
355 /**
356  * BF16 GQA causal attention forward
357  * @test bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_forward
358  * @test bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa
359  * @test bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_flash
360  *
361  * Accepts BF16 inputs, converts to FP32, uses exact softmax.
362  * Caller provides scratch buffers (no per-call malloc).
363  *
364  * After changes: make test
365  */
void attention_forward_causal_head_major_gqa_bf16(const uint16_t *q,
                                                  const uint16_t *k,
                                                  const uint16_t *v,
                                                  float *scores,
                                                  float *output,
                                                  int num_heads,
                                                  int num_kv_heads,
                                                  int num_tokens,
                                                  int head_dim,
                                                  int aligned_head_dim,
                                                  int aligned_context_window,
                                                  float *scratch_q,
                                                  float *scratch_k,
                                                  float *scratch_v)
{
    /* NOTE(review): the opening signature line was lost in extraction and has
     * been reconstructed from the doc comment above — confirm the exported
     * name against ckernel_engine.h. */
    const size_t q_elems = (size_t)num_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
    const size_t kv_elems = (size_t)num_kv_heads * (size_t)num_tokens * (size_t)aligned_head_dim;

    if (!scratch_q || !scratch_k || !scratch_v) return;

    // Widen BF16 inputs into caller-owned FP32 scratch (no allocation here).
    convert_bf16_tensor_to_buf(q, scratch_q, q_elems);
    convert_bf16_tensor_to_buf(k, scratch_k, kv_elems);
    convert_bf16_tensor_to_buf(v, scratch_v, kv_elems);

    // Use the exact-softmax variant so fast-exp approximation error does not
    // stack on top of BF16 precision loss.
    attention_forward_causal_head_major_gqa_exact(scratch_q, scratch_k, scratch_v,
                                                  scores, output,
                                                  num_heads, num_kv_heads,
                                                  num_tokens, head_dim,
                                                  aligned_head_dim, aligned_context_window);
    /* No free — caller owns the scratch buffers. */
}
399 
400 // ============================================================================
401 // ATTENTION FORWARD - Flash-style (no scores materialization)
402 // ============================================================================
403 //
404 // Computes the same causal attention output as `attention_forward_causal_head_major_gqa`,
405 // but does not materialize the [H, T, T] score/weight matrices. This is useful for:
406 // - Prefill: avoids large scratch buffers and improves cache locality
407 // - Decode: supports KV-cache attention for a single token
408 //
409 // SIMD-optimized implementations for AVX-512, AVX2, and AVX follow.
410 
411 // ============================================================================
412 // AVX-512 SIMD Flash Attention (16 floats per vector)
413 // ============================================================================
414 #if defined(__AVX512F__)
// Online-softmax attention for a single query row against kv_tokens keys,
// AVX-512 path (16 fp32 lanes). Accumulates the un-normalized weighted V sum
// while tracking the running max / denominator, then divides once at the end.
// Padding lanes [head_dim, aligned_head_dim) are left zero.
static void attention_flash_query_causal_avx512(const float *q_vec,
                                                const float *k_head,
                                                const float *v_head,
                                                int kv_tokens,
                                                int head_dim,
                                                int aligned_head_dim,
                                                float scale,
                                                float *out_vec)
{
    float run_max = -INFINITY;  // running max of scores seen so far
    float denom = 0.0f;         // running sum of exp(score - run_max)

    // Clear the whole aligned output slice (vector body + scalar tail).
    int d = 0;
    for (; d + 16 <= aligned_head_dim; d += 16) {
        _mm512_storeu_ps(&out_vec[d], _mm512_setzero_ps());
    }
    for (; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }

    for (int t = 0; t < kv_tokens; ++t) {
        const float *k_row = k_head + (size_t)t * (size_t)aligned_head_dim;
        const float *v_row = v_head + (size_t)t * (size_t)aligned_head_dim;

        // score = scale * dot(q, k_row)
        __m512 acc = _mm512_setzero_ps();
        d = 0;
        for (; d + 16 <= head_dim; d += 16) {
            acc = _mm512_fmadd_ps(_mm512_loadu_ps(&q_vec[d]),
                                  _mm512_loadu_ps(&k_row[d]),
                                  acc);
        }
        float dot = _mm512_reduce_add_ps(acc);
        for (; d < head_dim; ++d) {
            dot += q_vec[d] * k_row[d];
        }
        float score = dot * scale;

        if (score > run_max) {
            // New maximum: rescale the accumulated output and denominator.
            float rescale = (run_max == -INFINITY) ? 0.0f : expf(run_max - score);
            denom *= rescale;

            __m512 rescale_v = _mm512_set1_ps(rescale);
            d = 0;
            for (; d + 16 <= head_dim; d += 16) {
                __m512 o = _mm512_loadu_ps(&out_vec[d]);
                __m512 val = _mm512_loadu_ps(&v_row[d]);
                _mm512_storeu_ps(&out_vec[d], _mm512_fmadd_ps(o, rescale_v, val));
            }
            for (; d < head_dim; ++d) {
                out_vec[d] = out_vec[d] * rescale + v_row[d];
            }

            denom += 1.0f;  // the new max contributes exp(0) == 1
            run_max = score;
        } else {
            float w = expf(score - run_max);
            denom += w;

            __m512 w_v = _mm512_set1_ps(w);
            d = 0;
            for (; d + 16 <= head_dim; d += 16) {
                __m512 o = _mm512_loadu_ps(&out_vec[d]);
                __m512 val = _mm512_loadu_ps(&v_row[d]);
                _mm512_storeu_ps(&out_vec[d], _mm512_fmadd_ps(w_v, val, o));
            }
            for (; d < head_dim; ++d) {
                out_vec[d] += w * v_row[d];
            }
        }
    }

    // Normalize by the softmax denominator.
    float norm = 1.0f / denom;
    __m512 norm_v = _mm512_set1_ps(norm);
    d = 0;
    for (; d + 16 <= head_dim; d += 16) {
        _mm512_storeu_ps(&out_vec[d],
                         _mm512_mul_ps(_mm512_loadu_ps(&out_vec[d]), norm_v));
    }
    for (; d < head_dim; ++d) {
        out_vec[d] *= norm;
    }

    // Keep padding lanes zero.
    for (d = head_dim; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }
}
511 #endif // __AVX512F__
512 
513 // ============================================================================
514 // AVX2 SIMD Flash Attention (8 floats per vector)
515 // ============================================================================
516 #if defined(__AVX2__)
// Horizontal sum of the 8 packed floats in an AVX register.
static inline float hsum256_ps_flash(__m256 v) {
    __m128 folded = _mm_add_ps(_mm256_castps256_ps128(v),
                               _mm256_extractf128_ps(v, 1));
    folded = _mm_hadd_ps(folded, folded);
    folded = _mm_hadd_ps(folded, folded);
    return _mm_cvtss_f32(folded);
}
525 
// Online-softmax attention for a single query row, AVX2 path (8 fp32 lanes,
// FMA). Same algorithm as the AVX-512 kernel with half the vector width.
static void attention_flash_query_causal_avx2(const float *q_vec,
                                              const float *k_head,
                                              const float *v_head,
                                              int kv_tokens,
                                              int head_dim,
                                              int aligned_head_dim,
                                              float scale,
                                              float *out_vec)
{
    float run_max = -INFINITY;  // running max of scores seen so far
    float denom = 0.0f;         // running sum of exp(score - run_max)

    // Clear the whole aligned output slice.
    int d = 0;
    for (; d + 8 <= aligned_head_dim; d += 8) {
        _mm256_storeu_ps(&out_vec[d], _mm256_setzero_ps());
    }
    for (; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }

    for (int t = 0; t < kv_tokens; ++t) {
        const float *k_row = k_head + (size_t)t * (size_t)aligned_head_dim;
        const float *v_row = v_head + (size_t)t * (size_t)aligned_head_dim;

        // score = scale * dot(q, k_row)
        __m256 acc = _mm256_setzero_ps();
        d = 0;
        for (; d + 8 <= head_dim; d += 8) {
            acc = _mm256_fmadd_ps(_mm256_loadu_ps(&q_vec[d]),
                                  _mm256_loadu_ps(&k_row[d]),
                                  acc);
        }
        float dot = hsum256_ps_flash(acc);
        for (; d < head_dim; ++d) {
            dot += q_vec[d] * k_row[d];
        }
        float score = dot * scale;

        if (score > run_max) {
            // New maximum: rescale the accumulated output and denominator.
            float rescale = (run_max == -INFINITY) ? 0.0f : expf(run_max - score);
            denom *= rescale;

            __m256 rescale_v = _mm256_set1_ps(rescale);
            d = 0;
            for (; d + 8 <= head_dim; d += 8) {
                __m256 o = _mm256_loadu_ps(&out_vec[d]);
                __m256 val = _mm256_loadu_ps(&v_row[d]);
                _mm256_storeu_ps(&out_vec[d], _mm256_fmadd_ps(o, rescale_v, val));
            }
            for (; d < head_dim; ++d) {
                out_vec[d] = out_vec[d] * rescale + v_row[d];
            }

            denom += 1.0f;  // the new max contributes exp(0) == 1
            run_max = score;
        } else {
            float w = expf(score - run_max);
            denom += w;

            __m256 w_v = _mm256_set1_ps(w);
            d = 0;
            for (; d + 8 <= head_dim; d += 8) {
                __m256 o = _mm256_loadu_ps(&out_vec[d]);
                __m256 val = _mm256_loadu_ps(&v_row[d]);
                _mm256_storeu_ps(&out_vec[d], _mm256_fmadd_ps(w_v, val, o));
            }
            for (; d < head_dim; ++d) {
                out_vec[d] += w * v_row[d];
            }
        }
    }

    // Normalize by the softmax denominator.
    float norm = 1.0f / denom;
    __m256 norm_v = _mm256_set1_ps(norm);
    d = 0;
    for (; d + 8 <= head_dim; d += 8) {
        _mm256_storeu_ps(&out_vec[d],
                         _mm256_mul_ps(_mm256_loadu_ps(&out_vec[d]), norm_v));
    }
    for (; d < head_dim; ++d) {
        out_vec[d] *= norm;
    }

    // Keep padding lanes zero.
    for (d = head_dim; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }
}
617 #endif // __AVX2__
618 
619 // ============================================================================
620 // AVX SIMD Flash Attention (8 floats per vector, no FMA)
621 // ============================================================================
622 #if defined(__AVX__) && !defined(__AVX2__)
// Horizontal sum of the 8 packed floats in an AVX register (AVX-only build).
static inline float hsum256_ps_flash_avx(__m256 v) {
    __m128 folded = _mm_add_ps(_mm256_castps256_ps128(v),
                               _mm256_extractf128_ps(v, 1));
    folded = _mm_hadd_ps(folded, folded);
    folded = _mm_hadd_ps(folded, folded);
    return _mm_cvtss_f32(folded);
}
631 
// Online-softmax attention for a single query row, AVX path (8 fp32 lanes,
// no FMA — multiply/add pairs instead). Same algorithm as the AVX2 kernel.
static void attention_flash_query_causal_avx(const float *q_vec,
                                             const float *k_head,
                                             const float *v_head,
                                             int kv_tokens,
                                             int head_dim,
                                             int aligned_head_dim,
                                             float scale,
                                             float *out_vec)
{
    float run_max = -INFINITY;  // running max of scores seen so far
    float denom = 0.0f;         // running sum of exp(score - run_max)

    // Clear the whole aligned output slice.
    int d = 0;
    for (; d + 8 <= aligned_head_dim; d += 8) {
        _mm256_storeu_ps(&out_vec[d], _mm256_setzero_ps());
    }
    for (; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }

    for (int t = 0; t < kv_tokens; ++t) {
        const float *k_row = k_head + (size_t)t * (size_t)aligned_head_dim;
        const float *v_row = v_head + (size_t)t * (size_t)aligned_head_dim;

        // score = scale * dot(q, k_row), mul + add since FMA is unavailable.
        __m256 acc = _mm256_setzero_ps();
        d = 0;
        for (; d + 8 <= head_dim; d += 8) {
            acc = _mm256_add_ps(acc,
                                _mm256_mul_ps(_mm256_loadu_ps(&q_vec[d]),
                                              _mm256_loadu_ps(&k_row[d])));
        }
        float dot = hsum256_ps_flash_avx(acc);
        for (; d < head_dim; ++d) {
            dot += q_vec[d] * k_row[d];
        }
        float score = dot * scale;

        if (score > run_max) {
            // New maximum: rescale the accumulated output and denominator.
            float rescale = (run_max == -INFINITY) ? 0.0f : expf(run_max - score);
            denom *= rescale;

            __m256 rescale_v = _mm256_set1_ps(rescale);
            d = 0;
            for (; d + 8 <= head_dim; d += 8) {
                __m256 o = _mm256_loadu_ps(&out_vec[d]);
                __m256 val = _mm256_loadu_ps(&v_row[d]);
                // out = out * rescale + v (no FMA)
                _mm256_storeu_ps(&out_vec[d],
                                 _mm256_add_ps(_mm256_mul_ps(o, rescale_v), val));
            }
            for (; d < head_dim; ++d) {
                out_vec[d] = out_vec[d] * rescale + v_row[d];
            }

            denom += 1.0f;  // the new max contributes exp(0) == 1
            run_max = score;
        } else {
            float w = expf(score - run_max);
            denom += w;

            __m256 w_v = _mm256_set1_ps(w);
            d = 0;
            for (; d + 8 <= head_dim; d += 8) {
                __m256 o = _mm256_loadu_ps(&out_vec[d]);
                __m256 val = _mm256_loadu_ps(&v_row[d]);
                // out = out + w * v (no FMA)
                _mm256_storeu_ps(&out_vec[d],
                                 _mm256_add_ps(o, _mm256_mul_ps(w_v, val)));
            }
            for (; d < head_dim; ++d) {
                out_vec[d] += w * v_row[d];
            }
        }
    }

    // Normalize by the softmax denominator.
    float norm = 1.0f / denom;
    __m256 norm_v = _mm256_set1_ps(norm);
    d = 0;
    for (; d + 8 <= head_dim; d += 8) {
        _mm256_storeu_ps(&out_vec[d],
                         _mm256_mul_ps(_mm256_loadu_ps(&out_vec[d]), norm_v));
    }
    for (; d < head_dim; ++d) {
        out_vec[d] *= norm;
    }

    // Keep padding lanes zero.
    for (d = head_dim; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }
}
725 #endif // __AVX__ && !__AVX2__
726 
727 // ============================================================================
728 // Scalar fallback (original implementation)
729 // ============================================================================
730 static void attention_flash_query_causal(const float *q_vec,
731  const float *k_head,
732  const float *v_head,
733  int kv_tokens,
734  int head_dim,
735  int aligned_head_dim,
736  float scale,
737  float *out_vec)
738 {
739  // Online softmax:
740  // m = running max, s = running sum(exp(score - m))
741  // out = sum(exp(score - m) * v)
742  float m = -INFINITY;
743  float s = 0.0f;
744 
745  for (int d = 0; d < head_dim; ++d) {
746  out_vec[d] = 0.0f;
747  }
748 
749  for (int j = 0; j < kv_tokens; ++j) {
750  const float *k_vec = k_head + (size_t)j * (size_t)aligned_head_dim;
751  const float *v_vec = v_head + (size_t)j * (size_t)aligned_head_dim;
752 
753  float dot = 0.0f;
754  for (int d = 0; d < head_dim; ++d) {
755  dot += q_vec[d] * k_vec[d];
756  }
757  float score = dot * scale;
758 
759  if (score > m) {
760  float exp_m = (m == -INFINITY) ? 0.0f : expf(m - score);
761  s *= exp_m;
762  for (int d = 0; d < head_dim; ++d) {
763  out_vec[d] *= exp_m;
764  }
765  s += 1.0f;
766  for (int d = 0; d < head_dim; ++d) {
767  out_vec[d] += v_vec[d];
768  }
769  m = score;
770  } else {
771  float e = expf(score - m);
772  s += e;
773  for (int d = 0; d < head_dim; ++d) {
774  out_vec[d] += e * v_vec[d];
775  }
776  }
777  }
778 
779  float inv_s = 1.0f / s;
780  for (int d = 0; d < head_dim; ++d) {
781  out_vec[d] *= inv_s;
782  }
783  for (int d = head_dim; d < aligned_head_dim; ++d) {
784  out_vec[d] = 0.0f;
785  }
786 }
787 
788 /**
789  * Flash attention forward for GQA (prefill, no score materialization)
790  * @test test_flash_attention.py::TestFlashAttention::test_flash_forward
791  * @test test_flash_attention.py::TestFlashAttention::test_flash_vs_score_matrix
792  * @test test_flash_attention.py::TestFlashAttention::test_flash_gqa
793  * @test test_attention.py::TestAttentionForward::test_flash_forward
794  *
795  * Online softmax with streaming KV. O(N) memory instead of O(N^2).
796  * For prefill: all tokens attend to previous tokens.
797  *
798  * After changes: make test && make llamacpp-parity-full
799  */
void attention_forward_flash_head_major_gqa(const float *q,
                                            const float *k,
                                            const float *v,
                                            float *output,
                                            int num_heads,
                                            int num_kv_heads,
                                            int num_tokens,
                                            int head_dim,
                                            int aligned_head_dim)
{
    /* NOTE(review): the opening signature line was lost in extraction and has
     * been reconstructed from the flash-attention docs above — confirm the
     * exported name against ckernel_engine.h. */
    if (!q || !k || !v || !output) {
        return;
    }
    if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
        return;
    }

    const float scale = 1.0f / sqrtf((float)head_dim);
    const int T = num_tokens;

    // Compile-time dispatch to the widest available SIMD per-query kernel.
#if defined(__AVX512F__)
    #define FLASH_QUERY_IMPL attention_flash_query_causal_avx512
#elif defined(__AVX2__)
    #define FLASH_QUERY_IMPL attention_flash_query_causal_avx2
#elif defined(__AVX__)
    #define FLASH_QUERY_IMPL attention_flash_query_causal_avx
#else
    #define FLASH_QUERY_IMPL attention_flash_query_causal
#endif

    for (int h = 0; h < num_heads; ++h) {
        // GQA: map query head h to its shared KV head.
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        const float *k_head = k + (size_t)kv_head * (size_t)T * (size_t)aligned_head_dim;
        const float *v_head = v + (size_t)kv_head * (size_t)T * (size_t)aligned_head_dim;

        for (int i = 0; i < T; ++i) {
            const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
            float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
            // Causal: query i attends to keys [0 .. i].
            FLASH_QUERY_IMPL(q_vec, k_head, v_head,
                             /*kv_tokens=*/i + 1,
                             head_dim, aligned_head_dim,
                             scale, out_vec);
        }
    }

#undef FLASH_QUERY_IMPL
}
848 
849 /**
850  * Flash attention forward with custom KV stride (for KV cache)
851  * @test test_flash_attention.py::TestFlashAttention::test_flash_strided
852  * @test test_kv_cache_attention.py::TestKVCacheAttention::test_flash_attention
853  *
854  * Variant with configurable kv_stride_tokens for KV cache layouts
855  * where K/V may not be contiguous in memory.
856  *
857  * After changes: make test
858  */
void attention_forward_flash_head_major_gqa_strided(const float *q,
                                                    const float *k,
                                                    const float *v,
                                                    float *output,
                                                    int num_heads,
                                                    int num_kv_heads,
                                                    int num_tokens,
                                                    int head_dim,
                                                    int aligned_head_dim,
                                                    int kv_stride_tokens)
{
    /* NOTE(review): the opening signature line was lost in extraction and has
     * been reconstructed from the doc comment above — confirm the exported
     * name against ckernel_engine.h. */
    if (!q || !k || !v || !output) {
        return;
    }
    if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
        return;
    }
    // KV cache must have room for at least num_tokens entries per head.
    if (kv_stride_tokens < num_tokens) {
        return;
    }

    const float scale = 1.0f / sqrtf((float)head_dim);
    const int T = num_tokens;
    // Per-head stride in floats within the (possibly larger) KV cache.
    const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;

    // Compile-time dispatch to the widest available SIMD per-query kernel.
#if defined(__AVX512F__)
    #define FLASH_QUERY_IMPL attention_flash_query_causal_avx512
#elif defined(__AVX2__)
    #define FLASH_QUERY_IMPL attention_flash_query_causal_avx2
#elif defined(__AVX__)
    #define FLASH_QUERY_IMPL attention_flash_query_causal_avx
#else
    #define FLASH_QUERY_IMPL attention_flash_query_causal
#endif

    for (int h = 0; h < num_heads; ++h) {
        // GQA: map query head h to its shared KV head.
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        const float *k_head = k + (size_t)kv_head * kv_head_stride;
        const float *v_head = v + (size_t)kv_head * kv_head_stride;

        for (int i = 0; i < T; ++i) {
            const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
            float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
            // Causal: query i attends to keys [0 .. i].
            FLASH_QUERY_IMPL(q_vec, k_head, v_head,
                             /*kv_tokens=*/i + 1,
                             head_dim, aligned_head_dim,
                             scale, out_vec);
        }
    }

#undef FLASH_QUERY_IMPL
}
912 
913 // ============================================================================
914 // SLIDING-WINDOW ATTENTION - Flash-style with sliding window mask
915 // ============================================================================
916 //
917 // Sliding-window attention: each token attends only to the last W tokens.
918 // For token at position i, the valid key range is [max(0, i - W + 1) .. i].
919 // This is equivalent to causal attention with a window size limit.
920 //
921 // Key difference from regular causal attention:
922 // - Regular causal: token i attends to [0 .. i] (all previous tokens)
923 // - Sliding window: token i attends to [max(0, i - W + 1) .. i] (last W tokens only)
924 
925 // ============================================================================
926 // AVX-512 Sliding-Window Flash Attention
927 // ============================================================================
928 #if defined(__AVX512F__)
// Online-softmax attention for one query with a sliding-window causal mask,
// AVX-512 path. Only keys in [max(0, query_pos - sliding_window + 1) ..
// min(query_pos, kv_tokens - 1)] contribute; sliding_window == 0 means no
// window limit (plain causal).
static void attention_flash_query_sliding_avx512(const float *q_vec,
                                                 const float *k_head,
                                                 const float *v_head,
                                                 int query_pos,      // 0-indexed position of the query token
                                                 int kv_tokens,      // total KV tokens available
                                                 int head_dim,
                                                 int aligned_head_dim,
                                                 float scale,
                                                 float *out_vec,
                                                 int sliding_window) // window size (0 = no limit)
{
    float run_max = -INFINITY;  // running max of scores
    float denom = 0.0f;         // running sum of exp(score - run_max)

    // First key index inside the sliding window.
    int window_start = 0;
    if (sliding_window > 0) {
        window_start = query_pos - sliding_window + 1;
        if (window_start < 0) window_start = 0;
    }

    // Clear the whole aligned output slice.
    int d = 0;
    for (; d + 16 <= aligned_head_dim; d += 16) {
        _mm512_storeu_ps(&out_vec[d], _mm512_setzero_ps());
    }
    for (; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }

    // Last key index: causal bound clamped to the available KV tokens.
    int last = (query_pos < kv_tokens) ? query_pos : kv_tokens - 1;
    for (int t = window_start; t <= last; ++t) {
        const float *k_row = k_head + (size_t)t * (size_t)aligned_head_dim;
        const float *v_row = v_head + (size_t)t * (size_t)aligned_head_dim;

        // score = scale * dot(q, k_row)
        __m512 acc = _mm512_setzero_ps();
        d = 0;
        for (; d + 16 <= head_dim; d += 16) {
            acc = _mm512_fmadd_ps(_mm512_loadu_ps(&q_vec[d]),
                                  _mm512_loadu_ps(&k_row[d]),
                                  acc);
        }
        float dot = _mm512_reduce_add_ps(acc);
        for (; d < head_dim; ++d) {
            dot += q_vec[d] * k_row[d];
        }
        float score = dot * scale;

        if (score > run_max) {
            // New maximum: rescale the accumulated output and denominator.
            float rescale = (run_max == -INFINITY) ? 0.0f : expf(run_max - score);
            denom *= rescale;

            __m512 rescale_v = _mm512_set1_ps(rescale);
            d = 0;
            for (; d + 16 <= head_dim; d += 16) {
                __m512 o = _mm512_loadu_ps(&out_vec[d]);
                __m512 val = _mm512_loadu_ps(&v_row[d]);
                _mm512_storeu_ps(&out_vec[d], _mm512_fmadd_ps(o, rescale_v, val));
            }
            for (; d < head_dim; ++d) {
                out_vec[d] = out_vec[d] * rescale + v_row[d];
            }

            denom += 1.0f;  // the new max contributes exp(0) == 1
            run_max = score;
        } else {
            float w = expf(score - run_max);
            denom += w;

            __m512 w_v = _mm512_set1_ps(w);
            d = 0;
            for (; d + 16 <= head_dim; d += 16) {
                __m512 o = _mm512_loadu_ps(&out_vec[d]);
                __m512 val = _mm512_loadu_ps(&v_row[d]);
                _mm512_storeu_ps(&out_vec[d], _mm512_fmadd_ps(w_v, val, o));
            }
            for (; d < head_dim; ++d) {
                out_vec[d] += w * v_row[d];
            }
        }
    }

    // Normalize by the softmax denominator.
    // NOTE(review): if the window is empty (kv_tokens == 0) denom stays 0 and
    // this divides by zero — callers appear to guarantee kv_tokens > 0; confirm.
    float norm = 1.0f / denom;
    __m512 norm_v = _mm512_set1_ps(norm);
    d = 0;
    for (; d + 16 <= head_dim; d += 16) {
        _mm512_storeu_ps(&out_vec[d],
                         _mm512_mul_ps(_mm512_loadu_ps(&out_vec[d]), norm_v));
    }
    for (; d < head_dim; ++d) {
        out_vec[d] *= norm;
    }

    // Keep padding lanes zero.
    for (d = head_dim; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }
}
1032 #endif // __AVX512F__
1033 
1034 // ============================================================================
1035 // AVX2 Sliding-Window Flash Attention
1036 // ============================================================================
1037 #if defined(__AVX2__)
/*
 * AVX2 sliding-window flash attention for a single query vector.
 *
 * Online-softmax ("flash") accumulation: a running score maximum `m` and a
 * running softmax denominator `s` are maintained so the score row is never
 * materialized. The query attends to key positions
 * [window_start .. min(query_pos, kv_tokens - 1)];
 * sliding_window == 0 means no window limit (plain causal attention).
 *
 * q_vec:   query vector, [aligned_head_dim]
 * k_head:  K rows for one head, stride aligned_head_dim per token
 * v_head:  V rows for one head, stride aligned_head_dim per token
 * out_vec: output, [aligned_head_dim]; padding lanes [head_dim..) are zeroed
 *
 * NOTE(review): if the window is empty (window_start > effective end), s
 * stays 0 and the final 1/s produces inf/NaN — callers are expected to pass
 * query_pos < kv_tokens so the query always attends to at least itself.
 */
static void attention_flash_query_sliding_avx2(const float *q_vec,
                                               const float *k_head,
                                               const float *v_head,
                                               int query_pos,
                                               int kv_tokens,
                                               int head_dim,
                                               int aligned_head_dim,
                                               float scale,
                                               float *out_vec,
                                               int sliding_window)
{
    /* Running max of seen scores and running softmax denominator. */
    float m = -INFINITY;
    float s = 0.0f;

    /* First key index this query may attend to. */
    int window_start = 0;
    if (sliding_window > 0) {
        window_start = query_pos - sliding_window + 1;
        if (window_start < 0) window_start = 0;
    }

    /* Zero the (aligned) output accumulator, 8 lanes at a time. */
    int d = 0;
    for (; d + 8 <= aligned_head_dim; d += 8) {
        _mm256_storeu_ps(&out_vec[d], _mm256_setzero_ps());
    }
    for (; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }

    /* Last attendable key: causal bound clamped to available KV tokens. */
    int effective_kv_end = query_pos < kv_tokens ? query_pos : kv_tokens - 1;
    for (int j = window_start; j <= effective_kv_end; ++j) {
        const float *k_vec = k_head + (size_t)j * (size_t)aligned_head_dim;
        const float *v_vec = v_head + (size_t)j * (size_t)aligned_head_dim;

        /* score = (Q . K_j) * scale: vectorized dot product + scalar tail. */
        __m256 dot_acc = _mm256_setzero_ps();
        d = 0;
        for (; d + 8 <= head_dim; d += 8) {
            __m256 q_v = _mm256_loadu_ps(&q_vec[d]);
            __m256 k_v = _mm256_loadu_ps(&k_vec[d]);
            dot_acc = _mm256_fmadd_ps(q_v, k_v, dot_acc);
        }
        float dot = hsum256_ps_flash(dot_acc);  /* horizontal sum helper (defined elsewhere in this file) */
        for (; d < head_dim; ++d) {
            dot += q_vec[d] * k_vec[d];
        }
        float score = dot * scale;

        if (score > m) {
            /* New maximum: rescale accumulator and denominator by
             * exp(old_m - new_m), then add this token with weight exp(0)==1. */
            float exp_m = (m == -INFINITY) ? 0.0f : expf(m - score);
            s *= exp_m;

            __m256 exp_m_vec = _mm256_set1_ps(exp_m);
            d = 0;
            for (; d + 8 <= head_dim; d += 8) {
                __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
                __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
                out_v = _mm256_fmadd_ps(out_v, exp_m_vec, v_v);  /* out*exp_m + v */
                _mm256_storeu_ps(&out_vec[d], out_v);
            }
            for (; d < head_dim; ++d) {
                out_vec[d] = out_vec[d] * exp_m + v_vec[d];
            }

            s += 1.0f;
            m = score;
        } else {
            /* Maximum unchanged: accumulate with weight exp(score - m). */
            float e = expf(score - m);
            s += e;

            __m256 e_vec = _mm256_set1_ps(e);
            d = 0;
            for (; d + 8 <= head_dim; d += 8) {
                __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
                __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
                out_v = _mm256_fmadd_ps(e_vec, v_v, out_v);  /* out + e*v */
                _mm256_storeu_ps(&out_vec[d], out_v);
            }
            for (; d < head_dim; ++d) {
                out_vec[d] += e * v_vec[d];
            }
        }
    }

    /* Final softmax normalization: out /= s. */
    float inv_s = 1.0f / s;
    __m256 inv_s_vec = _mm256_set1_ps(inv_s);
    d = 0;
    for (; d + 8 <= head_dim; d += 8) {
        __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
        _mm256_storeu_ps(&out_vec[d], _mm256_mul_ps(out_v, inv_s_vec));
    }
    for (; d < head_dim; ++d) {
        out_vec[d] *= inv_s;
    }

    /* Keep alignment padding deterministic. */
    for (d = head_dim; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }
}
1135 #endif // __AVX2__
1136 
1137 // ============================================================================
1138 // AVX Sliding-Window Flash Attention (no FMA)
1139 // ============================================================================
1140 #if defined(__AVX__) && !defined(__AVX2__)
/*
 * AVX (no FMA) sliding-window flash attention for a single query vector.
 *
 * Same online-softmax algorithm as the AVX2 variant, but every fused
 * multiply-add is expressed as separate _mm256_mul_ps + _mm256_add_ps
 * since plain AVX has no FMA instructions.
 *
 * sliding_window == 0 disables the window (plain causal attention).
 * Padding lanes [head_dim .. aligned_head_dim) of out_vec are zeroed.
 */
static void attention_flash_query_sliding_avx(const float *q_vec,
                                              const float *k_head,
                                              const float *v_head,
                                              int query_pos,
                                              int kv_tokens,
                                              int head_dim,
                                              int aligned_head_dim,
                                              float scale,
                                              float *out_vec,
                                              int sliding_window)
{
    /* Running score max and softmax denominator. */
    float m = -INFINITY;
    float s = 0.0f;

    /* First key index inside the sliding window. */
    int window_start = 0;
    if (sliding_window > 0) {
        window_start = query_pos - sliding_window + 1;
        if (window_start < 0) window_start = 0;
    }

    /* Zero the (aligned) output accumulator. */
    int d = 0;
    for (; d + 8 <= aligned_head_dim; d += 8) {
        _mm256_storeu_ps(&out_vec[d], _mm256_setzero_ps());
    }
    for (; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }

    /* Last attendable key: causal bound clamped to available KV tokens. */
    int effective_kv_end = query_pos < kv_tokens ? query_pos : kv_tokens - 1;
    for (int j = window_start; j <= effective_kv_end; ++j) {
        const float *k_vec = k_head + (size_t)j * (size_t)aligned_head_dim;
        const float *v_vec = v_head + (size_t)j * (size_t)aligned_head_dim;

        /* score = (Q . K_j) * scale, mul+add in place of FMA. */
        __m256 dot_acc = _mm256_setzero_ps();
        d = 0;
        for (; d + 8 <= head_dim; d += 8) {
            __m256 q_v = _mm256_loadu_ps(&q_vec[d]);
            __m256 k_v = _mm256_loadu_ps(&k_vec[d]);
            dot_acc = _mm256_add_ps(dot_acc, _mm256_mul_ps(q_v, k_v));
        }
        float dot = hsum256_ps_flash_avx(dot_acc);  /* horizontal sum helper (defined elsewhere in this file) */
        for (; d < head_dim; ++d) {
            dot += q_vec[d] * k_vec[d];
        }
        float score = dot * scale;

        if (score > m) {
            /* New max: rescale previous accumulation by exp(old_m - new_m),
             * then add this token with weight exp(0) == 1. */
            float exp_m = (m == -INFINITY) ? 0.0f : expf(m - score);
            s *= exp_m;

            __m256 exp_m_vec = _mm256_set1_ps(exp_m);
            d = 0;
            for (; d + 8 <= head_dim; d += 8) {
                __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
                __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
                out_v = _mm256_add_ps(_mm256_mul_ps(out_v, exp_m_vec), v_v);
                _mm256_storeu_ps(&out_vec[d], out_v);
            }
            for (; d < head_dim; ++d) {
                out_vec[d] = out_vec[d] * exp_m + v_vec[d];
            }

            s += 1.0f;
            m = score;
        } else {
            /* Max unchanged: accumulate with weight exp(score - m). */
            float e = expf(score - m);
            s += e;

            __m256 e_vec = _mm256_set1_ps(e);
            d = 0;
            for (; d + 8 <= head_dim; d += 8) {
                __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
                __m256 v_v = _mm256_loadu_ps(&v_vec[d]);
                out_v = _mm256_add_ps(out_v, _mm256_mul_ps(e_vec, v_v));
                _mm256_storeu_ps(&out_vec[d], out_v);
            }
            for (; d < head_dim; ++d) {
                out_vec[d] += e * v_vec[d];
            }
        }
    }

    /* Final normalization: out /= s. */
    float inv_s = 1.0f / s;
    __m256 inv_s_vec = _mm256_set1_ps(inv_s);
    d = 0;
    for (; d + 8 <= head_dim; d += 8) {
        __m256 out_v = _mm256_loadu_ps(&out_vec[d]);
        _mm256_storeu_ps(&out_vec[d], _mm256_mul_ps(out_v, inv_s_vec));
    }
    for (; d < head_dim; ++d) {
        out_vec[d] *= inv_s;
    }

    /* Keep alignment padding deterministic. */
    for (d = head_dim; d < aligned_head_dim; ++d) {
        out_vec[d] = 0.0f;
    }
}
1238 #endif // __AVX__ && !__AVX2__
1239 
1240 // ============================================================================
1241 // Scalar Sliding-Window Flash Attention Fallback
1242 // ============================================================================
1243 static void attention_flash_query_sliding(const float *q_vec,
1244  const float *k_head,
1245  const float *v_head,
1246  int query_pos,
1247  int kv_tokens,
1248  int head_dim,
1249  int aligned_head_dim,
1250  float scale,
1251  float *out_vec,
1252  int sliding_window)
1253 {
1254  float m = -INFINITY;
1255  float s = 0.0f;
1256 
1257  int window_start = 0;
1258  if (sliding_window > 0) {
1259  window_start = query_pos - sliding_window + 1;
1260  if (window_start < 0) window_start = 0;
1261  }
1262 
1263  for (int d = 0; d < head_dim; ++d) {
1264  out_vec[d] = 0.0f;
1265  }
1266 
1267  int effective_kv_end = query_pos < kv_tokens ? query_pos : kv_tokens - 1;
1268  for (int j = window_start; j <= effective_kv_end; ++j) {
1269  const float *k_vec = k_head + (size_t)j * (size_t)aligned_head_dim;
1270  const float *v_vec = v_head + (size_t)j * (size_t)aligned_head_dim;
1271 
1272  float dot = 0.0f;
1273  for (int d = 0; d < head_dim; ++d) {
1274  dot += q_vec[d] * k_vec[d];
1275  }
1276  float score = dot * scale;
1277 
1278  if (score > m) {
1279  float exp_m = (m == -INFINITY) ? 0.0f : expf(m - score);
1280  s *= exp_m;
1281  for (int d = 0; d < head_dim; ++d) {
1282  out_vec[d] *= exp_m;
1283  }
1284  s += 1.0f;
1285  for (int d = 0; d < head_dim; ++d) {
1286  out_vec[d] += v_vec[d];
1287  }
1288  m = score;
1289  } else {
1290  float e = expf(score - m);
1291  s += e;
1292  for (int d = 0; d < head_dim; ++d) {
1293  out_vec[d] += e * v_vec[d];
1294  }
1295  }
1296  }
1297 
1298  float inv_s = 1.0f / s;
1299  for (int d = 0; d < head_dim; ++d) {
1300  out_vec[d] *= inv_s;
1301  }
1302  for (int d = head_dim; d < aligned_head_dim; ++d) {
1303  out_vec[d] = 0.0f;
1304  }
1305 }
1306 
1307 /**
1308  * Flash attention forward with sliding window (prefill)
1309  * @test test_attention.py::TestAttentionForward::test_sliding_window_prefill
1310  *
1311  * Sliding-window attention for prefill: each token attends to the last W tokens.
1312  * When sliding_window <= 0, behaves like regular causal attention.
1313  *
1314  * After changes: make test
1315  */
void attention_forward_causal_head_major_gqa_flash_strided_sliding(
    const float *q,       /* [num_heads, num_tokens, aligned_head_dim] */
    const float *k,       /* [num_kv_heads, kv_stride_tokens, aligned_head_dim] */
    const float *v,       /* [num_kv_heads, kv_stride_tokens, aligned_head_dim] */
    float *output,        /* [num_heads, num_tokens, aligned_head_dim] */
    int num_heads,
    int num_kv_heads,
    int num_tokens,
    int head_dim,
    int aligned_head_dim,
    int kv_stride_tokens, /* row stride (in tokens) of the K/V buffers */
    int sliding_window)   /* window size; <= 0 means plain causal */
{
    /* Defensive validation: kernels silently no-op on invalid arguments. */
    if (!q || !k || !v || !output) {
        return;
    }
    if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
        return;
    }
    /* Consistent with the sliding-decode kernel's dimension checks. */
    if (head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }
    if (kv_stride_tokens < num_tokens) {
        return;
    }

    const float scale = 1.0f / sqrtf((float)head_dim);
    const int T = num_tokens;
    const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;

    /* Select the best compile-time SIMD per-query kernel. */
#if defined(__AVX512F__)
    #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx512
#elif defined(__AVX2__)
    #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx2
#elif defined(__AVX__)
    #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx
#else
    #define SLIDING_FLASH_IMPL attention_flash_query_sliding
#endif

    for (int h = 0; h < num_heads; ++h) {
        /* GQA: map each query head onto its shared KV head. */
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        const float *k_head = k + (size_t)kv_head * kv_head_stride;
        const float *v_head = v + (size_t)kv_head * kv_head_stride;

        for (int i = 0; i < T; ++i) {
            const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
            float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
            SLIDING_FLASH_IMPL(q_vec, k_head, v_head,
                               /*query_pos=*/i,
                               /*kv_tokens=*/T,
                               head_dim, aligned_head_dim,
                               scale, out_vec,
                               sliding_window);
        }
    }

#undef SLIDING_FLASH_IMPL
}
1372 
1373 /**
1374  * Flash attention decode with sliding window
1375  * @test test_attention.py::TestAttentionForward::test_sliding_window_decode
1376  *
1377  * Single query token attends to the last W tokens in the KV cache.
1378  * For decode: effective_kv_tokens = min(kv_tokens, sliding_window)
1379  *
1380  * After changes: make test
1381  */
void attention_forward_decode_head_major_gqa_flash_sliding(
    const float *q_token,  /* [num_heads, aligned_head_dim] single query token */
    const float *k_cache,  /* [num_kv_heads, cache_capacity, aligned_head_dim] */
    const float *v_cache,  /* [num_kv_heads, cache_capacity, aligned_head_dim] */
    float *out_token,      /* [num_heads, aligned_head_dim] */
    int num_heads,
    int num_kv_heads,
    int kv_tokens,         /* valid tokens currently in the cache */
    int cache_capacity,    /* allocated cache rows per head */
    int head_dim,
    int aligned_head_dim,
    int sliding_window)    /* window size; <= 0 means attend to all kv_tokens */
{
    /* Defensive validation: kernels silently no-op on invalid arguments. */
    if (!q_token || !k_cache || !v_cache || !out_token) {
        return;
    }
    if (num_heads <= 0 || num_kv_heads <= 0 || cache_capacity <= 0) {
        return;
    }
    if (kv_tokens <= 0 || kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }

    const float scale = 1.0f / sqrtf((float)head_dim);
    const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;

    /* Effective KV tokens: the decode query sees at most the last
     * sliding_window cache entries. */
    int effective_kv_tokens = kv_tokens;
    if (sliding_window > 0 && sliding_window < kv_tokens) {
        effective_kv_tokens = sliding_window;
    }

    /* Guard against an empty window (cannot happen with kv_tokens >= 1). */
    if (effective_kv_tokens <= 0) {
        return;
    }

    /* Offset so we read only the last effective_kv_tokens cache entries. */
    int kv_start_offset = kv_tokens - effective_kv_tokens;

    /* Select the best compile-time SIMD per-query kernel. */
#if defined(__AVX512F__)
    #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx512
#elif defined(__AVX2__)
    #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx2
#elif defined(__AVX__)
    #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx
#else
    #define SLIDING_DECODE_IMPL attention_flash_query_sliding
#endif

    for (int h = 0; h < num_heads; ++h) {
        /* GQA: map each query head onto its shared KV head. */
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
        /* Offset K/V pointers to the first token inside the window. */
        const float *k_head = k_cache + (size_t)kv_head * head_stride
                              + (size_t)kv_start_offset * (size_t)aligned_head_dim;
        const float *v_head = v_cache + (size_t)kv_head * head_stride
                              + (size_t)kv_start_offset * (size_t)aligned_head_dim;
        float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;

        /* query_pos is relative to the windowed KV (last token attends to
         * everything); sliding_window = 0 because the windowing was already
         * applied via the K/V pointer offset above. */
        SLIDING_DECODE_IMPL(q_head, k_head, v_head,
                            /*query_pos=*/effective_kv_tokens - 1,
                            /*kv_tokens=*/effective_kv_tokens,
                            head_dim, aligned_head_dim,
                            scale, out_head,
                            /*sliding_window=*/0);
    }

#undef SLIDING_DECODE_IMPL
}
1454 
1455 /**
1456  * Flash attention decode (single token attends to KV cache)
1457  * @test test_flash_attention.py::TestFlashAttention::test_flash_decode
1458  * @test test_kv_cache_attention.py::TestKVCacheAttention::test_flash_decode
1459  * @test test_fused_attention_decode.py::TestFusedAttentionDecode::test_flash_decode
1460  * @test test_attention.py::TestAttentionForward::test_flash_decode
1461  *
1462  * Single query token attends to kv_tokens in KV cache.
1463  * Uses true flash attention from attention_flash_true.c.
1464  *
1465  * After changes: make test && make llamacpp-parity-full
1466  */
void attention_forward_decode_head_major_gqa_flash(
    const float *q_token,  /* [num_heads, aligned_head_dim] single query token */
    const float *k_cache,  /* [num_kv_heads, cache_capacity, aligned_head_dim] */
    const float *v_cache,  /* [num_kv_heads, cache_capacity, aligned_head_dim] */
    float *out_token,      /* [num_heads, aligned_head_dim] */
    int num_heads,
    int num_kv_heads,
    int kv_tokens,         /* valid tokens currently in the cache */
    int cache_capacity,    /* allocated cache rows per head */
    int head_dim,
    int aligned_head_dim)
{
    /* Defensive validation: kernels silently no-op on invalid arguments. */
    if (!q_token || !k_cache || !v_cache || !out_token) {
        return;
    }
    if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
        return;
    }
    if (kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }

    const float scale = 1.0f / sqrtf((float)head_dim);
    const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;

    for (int h = 0; h < num_heads; ++h) {
        /* GQA: map each query head onto its shared KV head. */
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
        const float *k_head = k_cache + (size_t)kv_head * head_stride;
        const float *v_head = v_cache + (size_t)kv_head * head_stride;
        float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;

        /* Delegate to the true flash-attention kernel (attention_flash_true.c):
         * one query row (T_q=1, H=1) against kv_tokens cached rows. */
        attention_flash_decode(out_head,
                               q_head,
                               k_head,
                               v_head,
                               1,
                               kv_tokens,
                               1,
                               aligned_head_dim,
                               scale);
    }
}
1509 
1510 /**
1511  * @brief WARNING: This is NOT true flash attention!
1512  *
1513  * This function is named "flash" but implements regular attention with O(n) complexity.
1514  * It's kept for reference and as a fallback.
1515  *
1516  * TRUE flash attention is implemented in attention_flash_true.c
1517  * @test test_kv_cache_attention.py::TestKVCacheAttention::test_regular_decode
1518  * @test test_attention.py::TestAttentionForward::test_regular_decode
1519  *
1520  * Regular attention decode (score-matrix version) for fallback.
1521  *
1522  * After changes: make test
1523  */
void attention_forward_decode_head_major_gqa_regular(
    const float *q_token,  /* [num_heads, aligned_head_dim] single query token */
    const float *k_cache,  /* [num_kv_heads, cache_capacity, aligned_head_dim] */
    const float *v_cache,  /* [num_kv_heads, cache_capacity, aligned_head_dim] */
    float *out_token,      /* [num_heads, aligned_head_dim] */
    int num_heads,
    int num_kv_heads,
    int kv_tokens,         /* valid tokens currently in the cache */
    int cache_capacity,    /* allocated cache rows per head */
    int head_dim,
    int aligned_head_dim)
{
    /* Defensive validation: kernels silently no-op on invalid arguments. */
    if (!q_token || !k_cache || !v_cache || !out_token) {
        return;
    }
    if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
        return;
    }
    if (kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }

    const float scale = 1.0f / sqrtf((float)head_dim);
    const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;

    /* Select SIMD implementation based on compile-time CPU features. */
#if defined(__AVX512F__)
    #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx512
#elif defined(__AVX2__)
    #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx2
#elif defined(__AVX__)
    #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx
#else
    #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal
#endif

    /* Serial head loop by design: CK-ENGINE kernel rule #2 forbids OpenMP;
     * parallelization happens at the orchestrator/codegen layer. */
    for (int h = 0; h < num_heads; ++h) {
        /* GQA: map each query head onto its shared KV head. */
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        const float *q_vec = q_token + (size_t)h * (size_t)aligned_head_dim;
        const float *k_head = k_cache + (size_t)kv_head * head_stride;
        const float *v_head = v_cache + (size_t)kv_head * head_stride;
        float *out_vec = out_token + (size_t)h * (size_t)aligned_head_dim;

        FLASH_QUERY_IMPL_DECODE(q_vec, k_head, v_head,
                                kv_tokens, head_dim, aligned_head_dim,
                                scale, out_vec);
    }

#undef FLASH_QUERY_IMPL_DECODE
}
1574 
1575 // ============================================================================
1576 // ATTENTION BACKWARD - Causal, Head-Major, GQA-aware
1577 // ============================================================================
1578 //
1579 // Backward pass for scaled dot-product attention with causal mask.
1580 //
1581 // Given:
1582 // d_output: gradient from the layer above [num_heads, T, head_dim]
1583 // q, k, v: saved activations from forward pass
1584 // attn_weights: saved softmax output from forward [num_heads, T, T]
1585 //
1586 // Computes:
1587 // d_q: gradient w.r.t. queries [num_heads, T, head_dim]
1588 // d_k: gradient w.r.t. keys [num_kv_heads, T, head_dim]
1589 // d_v: gradient w.r.t. values [num_kv_heads, T, head_dim]
1590 //
1591 // Math derivation:
1592 // Forward: scores = Q @ K^T / sqrt(d)
1593 // weights = causal_softmax(scores)
1594 // output = weights @ V
1595 //
1596 // Backward through V multiply:
1597 // d_weights = d_output @ V^T [H, T, T]
1598 // d_v = weights^T @ d_output [H_kv, T, d]
1599 //
1600 // Backward through softmax:
1601 // d_scores = softmax_backward(d_weights, weights)
1602 //
1603 // Backward through Q @ K^T:
1604 // d_q = d_scores @ K / sqrt(d) [H, T, d]
1605 // d_k = d_scores^T @ Q / sqrt(d) [H_kv, T, d]
1606 //
1607 // For GQA: multiple query heads share the same KV head, so we accumulate
1608 // gradients from all query heads that map to each KV head.
1609 //
1610 /**
1611  * BF16 attention backward with caller-provided scratch buffers
1612  * @test bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_backward
1613  *
1614  * Accepts BF16 inputs, converts to FP32, runs FP32 backward.
1615  * Caller provides scratch buffers (no per-call malloc).
1616  *
1617  * After changes: make test
1618  */
void attention_backward_causal_head_major_gqa_bf16(
    const uint16_t *d_output,  // [num_heads, T, aligned_head_dim] (BF16)
    float *d_x,                // [num_heads, T, aligned_head_dim] (unused here)
    const uint16_t *q,         // [num_heads, T, aligned_head_dim] (BF16)
    const uint16_t *k,         // [num_kv_heads, T, aligned_head_dim] (BF16)
    const uint16_t *v,         // [num_kv_heads, T, aligned_head_dim] (BF16)
    const float *attn_weights, // [num_heads, T, aligned_context_window]
    float *d_q,                // [num_heads, T, aligned_head_dim] output
    float *d_k,                // [num_kv_heads, T, aligned_head_dim] output
    float *d_v,                // [num_kv_heads, T, aligned_head_dim] output
    float *d_scores,           // [num_heads, T, aligned_context_window] scratch
    int num_heads,
    int num_kv_heads,
    int num_tokens,
    int head_dim,
    int aligned_head_dim,
    int aligned_context_window,
    float *scratch_d_output,   // FP32 scratch, num_heads*T*aligned_head_dim elems
    float *scratch_q,          // FP32 scratch, num_heads*T*aligned_head_dim elems
    float *scratch_k,          // FP32 scratch, num_kv_heads*T*aligned_head_dim elems
    float *scratch_v)          // FP32 scratch, num_kv_heads*T*aligned_head_dim elems
{
    (void)d_x;  /* kept for API symmetry with other backward entry points */
    const size_t head_elems = (size_t)num_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
    const size_t kv_elems = (size_t)num_kv_heads * (size_t)num_tokens * (size_t)aligned_head_dim;

    /* Caller must supply all scratch buffers — no per-call malloc (rule #1). */
    if (!scratch_d_output || !scratch_q || !scratch_k || !scratch_v) return;

    /* Upconvert BF16 activations/gradients to FP32 working copies. */
    convert_bf16_tensor_to_buf(d_output, scratch_d_output, head_elems);
    convert_bf16_tensor_to_buf(q, scratch_q, head_elems);
    convert_bf16_tensor_to_buf(k, scratch_k, kv_elems);
    convert_bf16_tensor_to_buf(v, scratch_v, kv_elems);

    /* Run the FP32 backward pass on the converted tensors. */
    attention_backward_causal_head_major_gqa(scratch_d_output, scratch_q, scratch_k, scratch_v,
                                             attn_weights,
                                             d_q, d_k, d_v, d_scores,
                                             num_heads, num_kv_heads,
                                             num_tokens, head_dim,
                                             aligned_head_dim, aligned_context_window);
    /* No free — the caller owns the scratch buffers. */
}
1660 
1661 /**
1662  * GQA causal attention backward (score-matrix version)
1663  * @test test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward
1664  * @test test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_vs_separate
1665  * @test test_parity.py::test_attention_backward_parity
1666  *
1667  * Computes dQ, dK, dV given dOutput and attention weights.
1668  * Supports grouped-query attention with head broadcasting.
1669  *
1670  * After changes: make test && make llamacpp-parity-full
1671  */
void attention_backward_causal_head_major_gqa(
    const float *d_output,     // [num_heads, T, aligned_head_dim]
    const float *q,            // [num_heads, T, aligned_head_dim]
    const float *k,            // [num_kv_heads, T, aligned_head_dim]
    const float *v,            // [num_kv_heads, T, aligned_head_dim]
    const float *attn_weights, // [num_heads, T, aligned_context_window]
    float *d_q,                // [num_heads, T, aligned_head_dim] output
    float *d_k,                // [num_kv_heads, T, aligned_head_dim] output
    float *d_v,                // [num_kv_heads, T, aligned_head_dim] output
    float *d_scores,           // [num_heads, T, aligned_context_window] scratch
    int num_heads,
    int num_kv_heads,
    int num_tokens,
    int head_dim,
    int aligned_head_dim,
    int aligned_context_window)
{
    const float scale = 1.0f / sqrtf((float)head_dim);
    int T = num_tokens;
    int H = num_heads;
    int H_kv = num_kv_heads;
    int hd = head_dim;
    int ad = aligned_head_dim;
    int aw = aligned_context_window;

    const size_t d_q_elems = (size_t)H * (size_t)T * (size_t)ad;
    const size_t kv_elems = (size_t)H_kv * (size_t)T * (size_t)ad;
    /* Zero the aligned outputs so padded lanes never leak garbage to downstream GEMMs. */
    for (size_t idx = 0; idx < d_q_elems; ++idx) {
        d_q[idx] = 0.0f;
    }
    for (size_t idx = 0; idx < kv_elems; ++idx) {
        d_k[idx] = 0.0f;
        d_v[idx] = 0.0f;
    }

    // Process each query head
    for (int h = 0; h < H; ++h) {
        // Which KV head does this query head use? (GQA broadcast mapping)
        int kv_h = (int)((long long)h * (long long)H_kv / (long long)H);

        // ----------------------------------------------------------------
        // Step 1: d_weights = d_output @ V^T and d_v += weights^T @ d_output
        // ----------------------------------------------------------------
        // For each query position i, compute d_weights[i, j] for j <= i
        // and accumulate d_v[j] contributions.

        for (int i = 0; i < T; ++i) {
            size_t d_out_base = qkv_index(h, i, 0, T, ad);

            for (int j = 0; j <= i; ++j) {
                size_t v_base = qkv_index(kv_h, j, 0, T, ad);
                size_t w_idx = score_index(h, i, j, aw);
                float w = attn_weights[w_idx];

                // d_weights[h, i, j] = d_output[h, i, :] @ v[kv_h, j, :]^T
                float dot = 0.0f;
                for (int dd = 0; dd < hd; ++dd) {
                    dot += d_output[d_out_base + dd] * v[v_base + dd];
                }
                d_scores[w_idx] = dot;

                // d_v[kv_h, j, :] += weights[h, i, j] * d_output[h, i, :]
                for (int dd = 0; dd < hd; ++dd) {
                    d_v[v_base + dd] += w * d_output[d_out_base + dd];
                }
            }

            // Zero out the (causally masked) upper triangle of d_scores.
            for (int j = i + 1; j < T; ++j) {
                d_scores[score_index(h, i, j, aw)] = 0.0f;
            }
            /* Scores scratch uses aligned_context_window; zero the padded columns. */
            for (int j = T; j < aw; ++j) {
                d_scores[score_index(h, i, j, aw)] = 0.0f;
            }
        }

        // ----------------------------------------------------------------
        // Step 2: Backward through softmax (in-place on d_scores for this head)
        // ----------------------------------------------------------------
        // d_scores = softmax_backward(d_scores, attn_weights)
        // Formula: d_score[i,j] = w[i,j] * (d_w[i,j] - sum_k(w[i,k] * d_w[i,k]))

        for (int i = 0; i < T; ++i) {
            /* size_t via score_index: the old `int` expression h*aw*aw + i*aw
             * overflows 32-bit int for large head counts / context windows. */
            size_t base = score_index(h, i, 0, aw);

            // Compute dot product: sum_j w[i,j] * d_w[i,j]
            float dot_product = 0.0f;
            for (int j = 0; j <= i; ++j) {
                float wt = attn_weights[base + j];
                float dw = d_scores[base + j];
                dot_product += wt * dw;
            }

            // Apply softmax backward formula
            for (int j = 0; j <= i; ++j) {
                float wt = attn_weights[base + j];
                float dw = d_scores[base + j];
                d_scores[base + j] = wt * (dw - dot_product);
            }
        }

        // ----------------------------------------------------------------
        // Step 3: d_q = d_scores @ K * scale
        //         d_k += d_scores^T @ Q * scale
        // ----------------------------------------------------------------

        for (int i = 0; i < T; ++i) {
            size_t d_q_base = qkv_index(h, i, 0, T, ad);
            size_t q_base = qkv_index(h, i, 0, T, ad);

            // d_q[h, i, :]   = sum_j d_scores[h, i, j] * k[kv_h, j, :] * scale
            // d_k[kv_h, j, :] += d_scores[h, i, j] * q[h, i, :] * scale
            for (int j = 0; j <= i; ++j) {
                size_t k_base = qkv_index(kv_h, j, 0, T, ad);
                size_t d_k_base = qkv_index(kv_h, j, 0, T, ad);
                float ds = d_scores[score_index(h, i, j, aw)] * scale;

                for (int dd = 0; dd < hd; ++dd) {
                    d_q[d_q_base + dd] += ds * k[k_base + dd];
                    d_k[d_k_base + dd] += ds * q[q_base + dd];
                }
            }
        }
    }
}
1799 
1800 /**
1801  * Causal attention backward (non-GQA version)
1802  * @test test_attention_backward.py::TestAttentionBackward::test_backward
1803  * @test test_attention_backward.py::TestAttentionBackward::test_backward_vs_separate
1804  * @test test_parity.py::test_attention_backward_parity
1805  *
1806  * Non-GQA version where num_heads == num_kv_heads.
1807  * Simpler than GQA, no head broadcasting needed.
1808  *
1809  * After changes: make test && make llamacpp-parity-full
1810  */
void attention_backward_causal_head_major(
    const float *d_output,     // [num_heads, T, aligned_head_dim]
    const float *q,            // [num_heads, T, aligned_head_dim]
    const float *k,            // [num_heads, T, aligned_head_dim]
    const float *v,            // [num_heads, T, aligned_head_dim]
    const float *attn_weights, // [num_heads, T, aligned_context_window]
    float *d_q,                // [num_heads, T, aligned_head_dim] output
    float *d_k,                // [num_heads, T, aligned_head_dim] output
    float *d_v,                // [num_heads, T, aligned_head_dim] output
    float *d_scores,           // [num_heads, T, aligned_context_window] scratch
    int num_heads,
    int num_tokens,
    int head_dim,
    int aligned_head_dim,
    int aligned_context_window)
{
    /* Non-GQA case is just GQA with num_kv_heads == num_heads
     * (the head->kv_head mapping becomes the identity). */
    attention_backward_causal_head_major_gqa(
        d_output, q, k, v, attn_weights,
        d_q, d_k, d_v, d_scores,
        num_heads, num_heads, // num_kv_heads == num_heads
        num_tokens, head_dim, aligned_head_dim, aligned_context_window);
}
static size_t qkv_index(int h, int t, int d, int num_tokens, int aligned_head_dim)
void attention_forward_causal_head_major_gqa_bf16(const uint16_t *q, const uint16_t *k, const uint16_t *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_q, float *scratch_k, float *scratch_v)
void attention_forward_causal_head_major_gqa_exact(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void attention_backward_causal_head_major_gqa_bf16(const uint16_t *d_output, float *d_x, const uint16_t *q, const uint16_t *k, const uint16_t *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_d_output, float *scratch_q, float *scratch_k, float *scratch_v)
void attention_forward_causal_head_major(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void attention_forward_causal_head_major_gqa(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void attention_forward_causal_head_major_gqa_flash_strided_sliding(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)
void attention_forward_decode_head_major_gqa_flash_sliding(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)
void attention_forward_causal_head_major_exact(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
#define FLASH_QUERY_IMPL
static size_t score_index(int h, int i, int j, int aligned_context_window)
#define SLIDING_DECODE_IMPL
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
static void attention_flash_query_sliding(const float *q_vec, const float *k_head, const float *v_head, int query_pos, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float *out_vec, int sliding_window)
static void attention_flash_query_causal(const float *q_vec, const float *k_head, const float *v_head, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float *out_vec)
void attention_forward_causal_head_major_gqa_flash(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
#define SLIDING_FLASH_IMPL
#define FLASH_QUERY_IMPL_DECODE
void attention_backward_causal_head_major(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void attention_backward_causal_head_major_gqa(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
static void convert_bf16_tensor_to_buf(const uint16_t *src, float *dst, size_t count)
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
Definition: bf16_utils.h:250
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void causal_softmax_head_major_exact(float *scores, int num_heads, int num_tokens, int aligned_context_window)
void causal_softmax_head_major(float *scores, int num_heads, int num_tokens, int aligned_context_window)
int32_t float * score
Definition: tokenizer.h:327