← Back to C-Kernel-Engine Docs Doxygen Source Documentation
attention_kernels.c File Reference

Attention score/softmax/output kernels with SIMD (SSE/AVX/AVX512) More...

#include "bf16_utils.h"
#include "ckernel_engine.h"
#include <math.h>
#include <stdlib.h>

Go to the source code of this file.

Macros

#define FLASH_QUERY_IMPL   attention_flash_query_causal
 
#define FLASH_QUERY_IMPL   attention_flash_query_causal
 
#define FLASH_QUERY_IMPL_DECODE   attention_flash_query_causal
 
#define SLIDING_DECODE_IMPL   attention_flash_query_sliding
 
#define SLIDING_FLASH_IMPL   attention_flash_query_sliding
 

Functions

void attention_backward_causal_head_major (const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_backward_causal_head_major_gqa (const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_backward_causal_head_major_gqa_bf16 (const uint16_t *d_output, float *d_x, const uint16_t *q, const uint16_t *k, const uint16_t *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_d_output, float *scratch_q, float *scratch_k, float *scratch_v)
 
static void attention_flash_query_causal (const float *q_vec, const float *k_head, const float *v_head, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float *out_vec)
 
static void attention_flash_query_sliding (const float *q_vec, const float *k_head, const float *v_head, int query_pos, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float *out_vec, int sliding_window)
 
void attention_forward_causal_head_major (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_forward_causal_head_major_exact (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_forward_causal_head_major_gqa (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_forward_causal_head_major_gqa_bf16 (const uint16_t *q, const uint16_t *k, const uint16_t *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_q, float *scratch_k, float *scratch_v)
 
void attention_forward_causal_head_major_gqa_exact (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_forward_causal_head_major_gqa_flash (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
 
void attention_forward_causal_head_major_gqa_flash_strided (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
 
void attention_forward_causal_head_major_gqa_flash_strided_sliding (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)
 
void attention_forward_decode_head_major_gqa_flash (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
 
void attention_forward_decode_head_major_gqa_flash_sliding (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)
 
void attention_forward_decode_head_major_gqa_regular (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
 WARNING: This is NOT true flash attention! More...
 
static void convert_bf16_tensor_to_buf (const uint16_t *src, float *dst, size_t count)
 
static size_t qkv_index (int h, int t, int d, int num_tokens, int aligned_head_dim)
 
static size_t score_index (int h, int i, int j, int aligned_context_window)
 

Detailed Description

Attention score/softmax/output kernels with SIMD (SSE/AVX/AVX512)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Attention: softmax(Q @ K^T / sqrt(d)) @ V. Supports GQA (grouped-query attention) with head broadcasting.

Definition in file attention_kernels.c.

Macro Definition Documentation

◆ FLASH_QUERY_IMPL [1/2]

#define FLASH_QUERY_IMPL   attention_flash_query_causal

◆ FLASH_QUERY_IMPL [2/2]

#define FLASH_QUERY_IMPL   attention_flash_query_causal

◆ FLASH_QUERY_IMPL_DECODE

#define FLASH_QUERY_IMPL_DECODE   attention_flash_query_causal

◆ SLIDING_DECODE_IMPL

#define SLIDING_DECODE_IMPL   attention_flash_query_sliding

◆ SLIDING_FLASH_IMPL

#define SLIDING_FLASH_IMPL   attention_flash_query_sliding

Function Documentation

◆ attention_backward_causal_head_major()

void attention_backward_causal_head_major ( const float *  d_output,
const float *  q,
const float *  k,
const float *  v,
const float *  attn_weights,
float *  d_q,
float *  d_k,
float *  d_v,
float *  d_scores,
int  num_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

Causal attention backward (non-GQA version)

Test:

test_attention_backward.py::TestAttentionBackward::test_backward

test_attention_backward.py::TestAttentionBackward::test_backward_vs_separate

test_parity.py::test_attention_backward_parity

Non-GQA version where num_heads == num_kv_heads. Simpler than GQA, no head broadcasting needed.

After changes: make test && make llamacpp-parity-full

Definition at line 1811 of file attention_kernels.c.

1826 {
1827  attention_backward_causal_head_major_gqa(
1828  d_output, q, k, v, attn_weights,
1829  d_q, d_k, d_v, d_scores,
1830  num_heads, num_heads, // num_kv_heads == num_heads
1831  num_tokens, head_dim, aligned_head_dim, aligned_context_window);
1832 }
void attention_backward_causal_head_major_gqa(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)

References attention_backward_causal_head_major_gqa().

◆ attention_backward_causal_head_major_gqa()

void attention_backward_causal_head_major_gqa ( const float *  d_output,
const float *  q,
const float *  k,
const float *  v,
const float *  attn_weights,
float *  d_q,
float *  d_k,
float *  d_v,
float *  d_scores,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

GQA causal attention backward (score-matrix version)

Test:

test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward

test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_vs_separate

test_parity.py::test_attention_backward_parity

Computes dQ, dK, dV given dOutput and attention weights. Supports grouped-query attention with head broadcasting.

After changes: make test && make llamacpp-parity-full

Definition at line 1672 of file attention_kernels.c.

1688 {
1689  const float scale = 1.0f / sqrtf((float)head_dim);
1690  int T = num_tokens;
1691  int H = num_heads;
1692  int H_kv = num_kv_heads;
1693  int hd = head_dim;
1694  int ad = aligned_head_dim;
1695  int aw = aligned_context_window;
1696 
1697  const size_t d_q_elems = (size_t)H * (size_t)T * (size_t)ad;
1698  const size_t kv_elems = (size_t)H_kv * (size_t)T * (size_t)ad;
1699  /* Zero the aligned outputs so padded lanes never leak garbage to downstream GEMMs. */
1700  for (size_t idx = 0; idx < d_q_elems; ++idx) {
1701  d_q[idx] = 0.0f;
1702  }
1703  for (size_t idx = 0; idx < kv_elems; ++idx) {
1704  d_k[idx] = 0.0f;
1705  d_v[idx] = 0.0f;
1706  }
1707 
1708  // Process each query head
1709  for (int h = 0; h < H; ++h) {
1710  // Which KV head does this query head use?
1711  int kv_h = (int)((long long)h * (long long)H_kv / (long long)H);
1712 
1713  // ----------------------------------------------------------------
1714  // Step 1: d_weights = d_output @ V^T and d_v += weights^T @ d_output
1715  // ----------------------------------------------------------------
1716  // For each query position i, compute d_weights[i, j] for j <= i
1717  // and accumulate d_v[j] contributions
1718 
1719  for (int i = 0; i < T; ++i) {
1720  size_t d_out_base = qkv_index(h, i, 0, T, ad);
1721 
1722  for (int j = 0; j <= i; ++j) {
1723  size_t v_base = qkv_index(kv_h, j, 0, T, ad);
1724  size_t w_idx = score_index(h, i, j, aw);
1725  float w = attn_weights[w_idx];
1726 
1727  // d_weights[h, i, j] = d_output[h, i, :] @ v[kv_h, j, :]^T
1728  float dot = 0.0f;
1729  for (int dd = 0; dd < hd; ++dd) {
1730  dot += d_output[d_out_base + dd] * v[v_base + dd];
1731  }
1732  d_scores[w_idx] = dot;
1733 
1734  // d_v[kv_h, j, :] += weights[h, i, j] * d_output[h, i, :]
1735  for (int dd = 0; dd < hd; ++dd) {
1736  d_v[v_base + dd] += w * d_output[d_out_base + dd];
1737  }
1738  }
1739 
1740  // Zero out upper triangle of d_scores
1741  for (int j = i + 1; j < T; ++j) {
1742  d_scores[score_index(h, i, j, aw)] = 0.0f;
1743  }
1744  /* Scores scratch uses aligned_context_window, zero the padded columns. */
1745  for (int j = T; j < aw; ++j) {
1746  d_scores[score_index(h, i, j, aw)] = 0.0f;
1747  }
1748  }
1749 
1750  // ----------------------------------------------------------------
1751  // Step 2: Backward through softmax (in-place on d_scores for this head)
1752  // ----------------------------------------------------------------
1753  // d_scores = softmax_backward(d_scores, attn_weights)
1754  // Formula: d_score[i,j] = w[i,j] * (d_w[i,j] - sum_k(w[i,k] * d_w[i,k]))
1755 
1756  for (int i = 0; i < T; ++i) {
1757  int base = h * aw * aw + i * aw;
1758 
1759  // Compute dot product: sum_j w[i,j] * d_w[i,j]
1760  float dot_product = 0.0f;
1761  for (int j = 0; j <= i; ++j) {
1762  float wt = attn_weights[base + j];
1763  float dw = d_scores[base + j];
1764  dot_product += wt * dw;
1765  }
1766 
1767  // Apply softmax backward formula
1768  for (int j = 0; j <= i; ++j) {
1769  float wt = attn_weights[base + j];
1770  float dw = d_scores[base + j];
1771  d_scores[base + j] = wt * (dw - dot_product);
1772  }
1773  }
1774 
1775  // ----------------------------------------------------------------
1776  // Step 3: d_q = d_scores @ K * scale
1777  // d_k += d_scores^T @ Q * scale
1778  // ----------------------------------------------------------------
1779 
1780  for (int i = 0; i < T; ++i) {
1781  size_t d_q_base = qkv_index(h, i, 0, T, ad);
1782  size_t q_base = qkv_index(h, i, 0, T, ad);
1783 
1784  // d_q[h, i, :] = sum_j d_scores[h, i, j] * k[kv_h, j, :] * scale
1785  // d_k[kv_h, j, :] += d_scores[h, i, j] * q[h, i, :] * scale
1786  for (int j = 0; j <= i; ++j) {
1787  size_t k_base = qkv_index(kv_h, j, 0, T, ad);
1788  size_t d_k_base = qkv_index(kv_h, j, 0, T, ad);
1789  float ds = d_scores[score_index(h, i, j, aw)] * scale;
1790 
1791  for (int dd = 0; dd < hd; ++dd) {
1792  d_q[d_q_base + dd] += ds * k[k_base + dd];
1793  d_k[d_k_base + dd] += ds * q[q_base + dd];
1794  }
1795  }
1796  }
1797  }
1798 }
static size_t qkv_index(int h, int t, int d, int num_tokens, int aligned_head_dim)
static size_t score_index(int h, int i, int j, int aligned_context_window)

References qkv_index(), and score_index().

Referenced by attention_backward_causal_head_major(), attention_backward_causal_head_major_gqa_bf16(), and ck_layer_backward_rmsnorm_swiglu().

◆ attention_backward_causal_head_major_gqa_bf16()

void attention_backward_causal_head_major_gqa_bf16 ( const uint16_t *  d_output,
float *  d_x,
const uint16_t *  q,
const uint16_t *  k,
const uint16_t *  v,
const float *  attn_weights,
float *  d_q,
float *  d_k,
float *  d_v,
float *  d_scores,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window,
float *  scratch_d_output,
float *  scratch_q,
float *  scratch_k,
float *  scratch_v 
)

BF16 attention backward with caller-provided scratch buffers

Test:
bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_backward

Accepts BF16 inputs, converts to FP32, runs FP32 backward. Caller provides scratch buffers (no per-call malloc).

After changes: make test

Definition at line 1619 of file attention_kernels.c.

1640 {
1641  (void)d_x;
1642  const size_t head_elems = (size_t)num_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
1643  const size_t kv_elems = (size_t)num_kv_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
1644 
1645  if (!scratch_d_output || !scratch_q || !scratch_k || !scratch_v) return;
1646 
1647  convert_bf16_tensor_to_buf(d_output, scratch_d_output, head_elems);
1648  convert_bf16_tensor_to_buf(q, scratch_q, head_elems);
1649  convert_bf16_tensor_to_buf(k, scratch_k, kv_elems);
1650  convert_bf16_tensor_to_buf(v, scratch_v, kv_elems);
1651 
1652  attention_backward_causal_head_major_gqa(scratch_d_output, scratch_q, scratch_k, scratch_v,
1653  attn_weights,
1654  d_q, d_k, d_v, d_scores,
1655  num_heads, num_kv_heads,
1656  num_tokens, head_dim,
1657  aligned_head_dim, aligned_context_window);
1658  /* No free - caller owns scratch buffers */
1659 }
static void convert_bf16_tensor_to_buf(const uint16_t *src, float *dst, size_t count)

References attention_backward_causal_head_major_gqa(), and convert_bf16_tensor_to_buf().

◆ attention_flash_query_causal()

static void attention_flash_query_causal ( const float *  q_vec,
const float *  k_head,
const float *  v_head,
int  kv_tokens,
int  head_dim,
int  aligned_head_dim,
float  scale,
float *  out_vec 
)
static

Definition at line 730 of file attention_kernels.c.

738 {
739  // Online softmax:
740  // m = running max, s = running sum(exp(score - m))
741  // out = sum(exp(score - m) * v)
742  float m = -INFINITY;
743  float s = 0.0f;
744 
745  for (int d = 0; d < head_dim; ++d) {
746  out_vec[d] = 0.0f;
747  }
748 
749  for (int j = 0; j < kv_tokens; ++j) {
750  const float *k_vec = k_head + (size_t)j * (size_t)aligned_head_dim;
751  const float *v_vec = v_head + (size_t)j * (size_t)aligned_head_dim;
752 
753  float dot = 0.0f;
754  for (int d = 0; d < head_dim; ++d) {
755  dot += q_vec[d] * k_vec[d];
756  }
757  float score = dot * scale;
758 
759  if (score > m) {
760  float exp_m = (m == -INFINITY) ? 0.0f : expf(m - score);
761  s *= exp_m;
762  for (int d = 0; d < head_dim; ++d) {
763  out_vec[d] *= exp_m;
764  }
765  s += 1.0f;
766  for (int d = 0; d < head_dim; ++d) {
767  out_vec[d] += v_vec[d];
768  }
769  m = score;
770  } else {
771  float e = expf(score - m);
772  s += e;
773  for (int d = 0; d < head_dim; ++d) {
774  out_vec[d] += e * v_vec[d];
775  }
776  }
777  }
778 
779  float inv_s = 1.0f / s;
780  for (int d = 0; d < head_dim; ++d) {
781  out_vec[d] *= inv_s;
782  }
783  for (int d = head_dim; d < aligned_head_dim; ++d) {
784  out_vec[d] = 0.0f;
785  }
786 }
float * score
Definition: tokenizer.h:327

References score.

◆ attention_flash_query_sliding()

static void attention_flash_query_sliding ( const float *  q_vec,
const float *  k_head,
const float *  v_head,
int  query_pos,
int  kv_tokens,
int  head_dim,
int  aligned_head_dim,
float  scale,
float *  out_vec,
int  sliding_window 
)
static

Definition at line 1243 of file attention_kernels.c.

1253 {
1254  float m = -INFINITY;
1255  float s = 0.0f;
1256 
1257  int window_start = 0;
1258  if (sliding_window > 0) {
1259  window_start = query_pos - sliding_window + 1;
1260  if (window_start < 0) window_start = 0;
1261  }
1262 
1263  for (int d = 0; d < head_dim; ++d) {
1264  out_vec[d] = 0.0f;
1265  }
1266 
1267  int effective_kv_end = query_pos < kv_tokens ? query_pos : kv_tokens - 1;
1268  for (int j = window_start; j <= effective_kv_end; ++j) {
1269  const float *k_vec = k_head + (size_t)j * (size_t)aligned_head_dim;
1270  const float *v_vec = v_head + (size_t)j * (size_t)aligned_head_dim;
1271 
1272  float dot = 0.0f;
1273  for (int d = 0; d < head_dim; ++d) {
1274  dot += q_vec[d] * k_vec[d];
1275  }
1276  float score = dot * scale;
1277 
1278  if (score > m) {
1279  float exp_m = (m == -INFINITY) ? 0.0f : expf(m - score);
1280  s *= exp_m;
1281  for (int d = 0; d < head_dim; ++d) {
1282  out_vec[d] *= exp_m;
1283  }
1284  s += 1.0f;
1285  for (int d = 0; d < head_dim; ++d) {
1286  out_vec[d] += v_vec[d];
1287  }
1288  m = score;
1289  } else {
1290  float e = expf(score - m);
1291  s += e;
1292  for (int d = 0; d < head_dim; ++d) {
1293  out_vec[d] += e * v_vec[d];
1294  }
1295  }
1296  }
1297 
1298  float inv_s = 1.0f / s;
1299  for (int d = 0; d < head_dim; ++d) {
1300  out_vec[d] *= inv_s;
1301  }
1302  for (int d = head_dim; d < aligned_head_dim; ++d) {
1303  out_vec[d] = 0.0f;
1304  }
1305 }

References score.

◆ attention_forward_causal_head_major()

void attention_forward_causal_head_major ( const float *  q,
const float *  k,
const float *  v,
float *  scores,
float *  output,
int  num_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

Causal attention forward (score-matrix version)

Test:

test_attention.py::TestAttentionForward::test_causal_forward

test_attention.py::TestAttentionForward::test_gqa_broadcast

test_attention.py::TestAttentionForward::test_exact_vs_fast

test_parity.py::test_attention_parity

Computes softmax(Q @ K^T / sqrt(d)) @ V with causal masking. Uses O(N^2) memory for scores matrix.

After changes: make test && make llamacpp-parity-full

Definition at line 70 of file attention_kernels.c.

80 {
81  const float scale = 1.0f / sqrtf((float)head_dim);
82 
83  // Phase 1: compute scaled dot-product scores Q·K^T / sqrt(d_k),
84  // lower triangle only (j <= i).
85  for (int h = 0; h < num_heads; ++h) {
86  for (int i = 0; i < num_tokens; ++i) {
87  for (int j = 0; j <= i; ++j) {
88  float dot = 0.0f;
89  size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
90  size_t base_k = qkv_index(h, j, 0, num_tokens, aligned_head_dim);
91 
92  for (int d = 0; d < head_dim; ++d) {
93  dot += q[base_q + d] * k[base_k + d];
94  }
95 
96  scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
97  }
98 
99  // Ensure upper triangle is zeroed so there are no stale values
100  // before the softmax kernel runs.
101  for (int j = i + 1; j < num_tokens; ++j) {
102  scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
103  }
104  }
105  }
106 
107  // Phase 2: apply causal row-wise softmax in-place over j <= i.
108  causal_softmax_head_major(scores,
109  num_heads,
110  num_tokens,
111  aligned_context_window);
112 
113  // Phase 3: attention weights · V.
114  for (int h = 0; h < num_heads; ++h) {
115  for (int i = 0; i < num_tokens; ++i) {
116  size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
117 
118  // Zero the full aligned head slice so padded dims stay clean.
119  for (int d = 0; d < aligned_head_dim; ++d) {
120  output[out_base + d] = 0.0f;
121  }
122 
123  // Weighted sum over causal positions.
124  for (int j = 0; j <= i; ++j) {
125  float w = scores[score_index(h, i, j, aligned_context_window)];
126  size_t v_base = qkv_index(h, j, 0, num_tokens, aligned_head_dim);
127 
128  for (int d = 0; d < head_dim; ++d) {
129  output[out_base + d] += w * v[v_base + d];
130  }
131  }
132  }
133  }
134 }
void causal_softmax_head_major(float *scores, int num_heads, int num_tokens, int aligned_context_window)

References causal_softmax_head_major(), qkv_index(), and score_index().

◆ attention_forward_causal_head_major_exact()

void attention_forward_causal_head_major_exact ( const float *  q,
const float *  k,
const float *  v,
float *  scores,
float *  output,
int  num_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

Causal attention forward (exact version using stdlib expf)

Test:

test_attention.py::TestAttentionForward::test_exact_single

test_attention.py::TestAttentionForward::test_exact_vs_fast

Uses standard library expf for numerical accuracy reference. Slower but provides maximum accuracy.

After changes: make test

Definition at line 146 of file attention_kernels.c.

156 {
157  const float scale = 1.0f / sqrtf((float)head_dim);
158 
159  // Phase 1: compute scaled dot-product scores Q·K^T / sqrt(d_k),
160  // lower triangle only (j <= i).
161  for (int h = 0; h < num_heads; ++h) {
162  for (int i = 0; i < num_tokens; ++i) {
163  for (int j = 0; j <= i; ++j) {
164  float dot = 0.0f;
165  size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
166  size_t base_k = qkv_index(h, j, 0, num_tokens, aligned_head_dim);
167 
168  for (int d = 0; d < head_dim; ++d) {
169  dot += q[base_q + d] * k[base_k + d];
170  }
171 
172  scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
173  }
174 
175  // Ensure upper triangle is zeroed so there are no stale values
176  // before the softmax kernel runs.
177  for (int j = i + 1; j < num_tokens; ++j) {
178  scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
179  }
180  }
181  }
182 
183  // Phase 2: apply causal row-wise softmax using exact expf.
184  causal_softmax_head_major_exact(scores,
185  num_heads,
186  num_tokens,
187  aligned_context_window);
188 
189  // Phase 3: attention weights · V.
190  for (int h = 0; h < num_heads; ++h) {
191  for (int i = 0; i < num_tokens; ++i) {
192  size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
193 
194  // Zero the full aligned head slice so padded dims stay clean.
195  for (int d = 0; d < aligned_head_dim; ++d) {
196  output[out_base + d] = 0.0f;
197  }
198 
199  // Weighted sum over causal positions.
200  for (int j = 0; j <= i; ++j) {
201  float w = scores[score_index(h, i, j, aligned_context_window)];
202  size_t v_base = qkv_index(h, j, 0, num_tokens, aligned_head_dim);
203 
204  for (int d = 0; d < head_dim; ++d) {
205  output[out_base + d] += w * v[v_base + d];
206  }
207  }
208  }
209  }
210 }
void causal_softmax_head_major_exact(float *scores, int num_heads, int num_tokens, int aligned_context_window)

References causal_softmax_head_major_exact(), qkv_index(), and score_index().

◆ attention_forward_causal_head_major_gqa()

void attention_forward_causal_head_major_gqa ( const float *  q,
const float *  k,
const float *  v,
float *  scores,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

GQA causal attention forward (score-matrix version)

Test:

test_attention.py::TestAttentionForward::test_gqa_forward

test_attention.py::TestAttentionForward::test_gqa_broadcast

test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward

test_parity.py::test_attention_gqa_parity

Grouped-query attention: Q has num_heads, K/V have num_kv_heads. Each query head maps to a KV head via ratio.

After changes: make test && make llamacpp-parity-full

Definition at line 224 of file attention_kernels.c.

235 {
236  const float scale = 1.0f / sqrtf((float)head_dim);
237 
238  for (int h = 0; h < num_heads; ++h) {
239  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
240  for (int i = 0; i < num_tokens; ++i) {
241  for (int j = 0; j <= i; ++j) {
242  float dot = 0.0f;
243  size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
244  size_t base_k = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
245 
246  for (int d = 0; d < head_dim; ++d) {
247  dot += q[base_q + d] * k[base_k + d];
248  }
249 
250  scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
251  }
252 
253  for (int j = i + 1; j < num_tokens; ++j) {
254  scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
255  }
256  }
257  }
258 
259  causal_softmax_head_major(scores,
260  num_heads,
261  num_tokens,
262  aligned_context_window);
263 
264  for (int h = 0; h < num_heads; ++h) {
265  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
266  for (int i = 0; i < num_tokens; ++i) {
267  size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
268  for (int d = 0; d < aligned_head_dim; ++d) {
269  output[out_base + d] = 0.0f;
270  }
271 
272  for (int j = 0; j <= i; ++j) {
273  float w = scores[score_index(h, i, j, aligned_context_window)];
274  size_t v_base = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
275 
276  for (int d = 0; d < head_dim; ++d) {
277  output[out_base + d] += w * v[v_base + d];
278  }
279  }
280  }
281  }
282 }

References causal_softmax_head_major(), qkv_index(), and score_index().

Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), and ck_layer_forward_rmsnorm_swiglu_ref().

◆ attention_forward_causal_head_major_gqa_bf16()

void attention_forward_causal_head_major_gqa_bf16 ( const uint16_t *  q,
const uint16_t *  k,
const uint16_t *  v,
float *  scores,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window,
float *  scratch_q,
float *  scratch_k,
float *  scratch_v 
)

BF16 GQA causal attention forward

Test:

bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_forward

bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa

bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_flash

Accepts BF16 inputs, converts to FP32, uses exact softmax. Caller provides scratch buffers (no per-call malloc).

After changes: make test

Definition at line 366 of file attention_kernels.c.

380 {
381  const size_t q_elems = (size_t)num_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
382  const size_t kv_elems = (size_t)num_kv_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
383 
384  if (!scratch_q || !scratch_k || !scratch_v) return;
385 
386  convert_bf16_tensor_to_buf(q, scratch_q, q_elems);
387  convert_bf16_tensor_to_buf(k, scratch_k, kv_elems);
388  convert_bf16_tensor_to_buf(v, scratch_v, kv_elems);
389 
390  // Use exact version to avoid fast exp approximation error accumulating
391  // with BF16 precision loss.
392  attention_forward_causal_head_major_gqa_exact(scratch_q, scratch_k, scratch_v,
393  scores, output,
394  num_heads, num_kv_heads,
395  num_tokens, head_dim,
396  aligned_head_dim, aligned_context_window);
397  /* No free - caller owns scratch buffers */
398 }
void attention_forward_causal_head_major_gqa_exact(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)

References attention_forward_causal_head_major_gqa_exact(), and convert_bf16_tensor_to_buf().

◆ attention_forward_causal_head_major_gqa_exact()

void attention_forward_causal_head_major_gqa_exact ( const float *  q,
const float *  k,
const float *  v,
float *  scores,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

GQA causal attention forward (exact version using stdlib expf)

Test:

test_attention.py::TestAttentionForward::test_gqa_exact

bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa

Uses standard library expf for numerical accuracy reference. Used by BF16 wrapper to avoid approximation error accumulation.

After changes: make test

Definition at line 294 of file attention_kernels.c.

305 {
306  const float scale = 1.0f / sqrtf((float)head_dim);
307 
308  for (int h = 0; h < num_heads; ++h) {
309  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
310  for (int i = 0; i < num_tokens; ++i) {
311  for (int j = 0; j <= i; ++j) {
312  float dot = 0.0f;
313  size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
314  size_t base_k = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
315 
316  for (int d = 0; d < head_dim; ++d) {
317  dot += q[base_q + d] * k[base_k + d];
318  }
319 
320  scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
321  }
322 
323  for (int j = i + 1; j < num_tokens; ++j) {
324  scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
325  }
326  }
327  }
328 
329  // Use exact softmax with standard library expf
330  causal_softmax_head_major_exact(scores,
331  num_heads,
332  num_tokens,
333  aligned_context_window);
334 
335  for (int h = 0; h < num_heads; ++h) {
336  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
337  for (int i = 0; i < num_tokens; ++i) {
338  size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
339  for (int d = 0; d < aligned_head_dim; ++d) {
340  output[out_base + d] = 0.0f;
341  }
342 
343  for (int j = 0; j <= i; ++j) {
344  float w = scores[score_index(h, i, j, aligned_context_window)];
345  size_t v_base = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
346 
347  for (int d = 0; d < head_dim; ++d) {
348  output[out_base + d] += w * v[v_base + d];
349  }
350  }
351  }
352  }
353 }

References causal_softmax_head_major_exact(), qkv_index(), and score_index().

Referenced by attention_forward_causal_head_major_gqa_bf16().

◆ attention_forward_causal_head_major_gqa_flash()

void attention_forward_causal_head_major_gqa_flash ( const float *  q,
const float *  k,
const float *  v,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim 
)

Flash attention forward for GQA (prefill, no score materialization)

Test:

test_flash_attention.py::TestFlashAttention::test_flash_forward

test_flash_attention.py::TestFlashAttention::test_flash_vs_score_matrix

test_flash_attention.py::TestFlashAttention::test_flash_gqa

test_attention.py::TestAttentionForward::test_flash_forward

Online softmax with streaming KV. O(N) memory instead of O(N^2). For prefill: all tokens attend to previous tokens.

After changes: make test && make llamacpp-parity-full

Definition at line 800 of file attention_kernels.c.

809 {
810  if (!q || !k || !v || !output) {
811  return;
812  }
813  if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
814  return;
815  }
816 
817  const float scale = 1.0f / sqrtf((float)head_dim);
818  const int T = num_tokens;
819 
820  // Select SIMD implementation based on compile-time CPU features
821 #if defined(__AVX512F__)
822  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx512
823 #elif defined(__AVX2__)
824  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx2
825 #elif defined(__AVX__)
826  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx
827 #else
828  #define FLASH_QUERY_IMPL attention_flash_query_causal
829 #endif
830 
831  for (int h = 0; h < num_heads; ++h) {
832  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
833  const float *k_head = k + (size_t)kv_head * (size_t)T * (size_t)aligned_head_dim;
834  const float *v_head = v + (size_t)kv_head * (size_t)T * (size_t)aligned_head_dim;
835 
836  for (int i = 0; i < T; ++i) {
837  const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
838  float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
839  FLASH_QUERY_IMPL(q_vec, k_head, v_head,
840  /*kv_tokens=*/i + 1,
841  head_dim, aligned_head_dim,
842  scale, out_vec);
843  }
844  }
845 
846 #undef FLASH_QUERY_IMPL
847 }
FLASH_QUERY_IMPL is a function-local dispatch macro: it is #defined inside the function body above to the best compile-time SIMD variant (AVX-512, AVX2, AVX, or scalar fallback) and #undef'd before the function returns.

References FLASH_QUERY_IMPL, and qkv_index().

Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().

◆ attention_forward_causal_head_major_gqa_flash_strided()

void attention_forward_causal_head_major_gqa_flash_strided ( const float *  q,
const float *  k,
const float *  v,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  kv_stride_tokens 
)

Flash attention forward with custom KV stride (for KV cache)

Test:

test_flash_attention.py::TestFlashAttention::test_flash_strided

test_kv_cache_attention.py::TestKVCacheAttention::test_flash_attention

Variant with configurable kv_stride_tokens for KV cache layouts where K/V may not be contiguous in memory.

After changes: make test

Definition at line 859 of file attention_kernels.c.

869 {
870  if (!q || !k || !v || !output) {
871  return;
872  }
873  if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
874  return;
875  }
876  if (kv_stride_tokens < num_tokens) {
877  return;
878  }
879 
880  const float scale = 1.0f / sqrtf((float)head_dim);
881  const int T = num_tokens;
882  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
883 
884  // Select SIMD implementation based on compile-time CPU features
885 #if defined(__AVX512F__)
886  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx512
887 #elif defined(__AVX2__)
888  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx2
889 #elif defined(__AVX__)
890  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx
891 #else
892  #define FLASH_QUERY_IMPL attention_flash_query_causal
893 #endif
894 
895  for (int h = 0; h < num_heads; ++h) {
896  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
897  const float *k_head = k + (size_t)kv_head * kv_head_stride;
898  const float *v_head = v + (size_t)kv_head * kv_head_stride;
899 
900  for (int i = 0; i < T; ++i) {
901  const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
902  float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
903  FLASH_QUERY_IMPL(q_vec, k_head, v_head,
904  /*kv_tokens=*/i + 1,
905  head_dim, aligned_head_dim,
906  scale, out_vec);
907  }
908  }
909 
910 #undef FLASH_QUERY_IMPL
911 }

References FLASH_QUERY_IMPL, and qkv_index().

Referenced by ck_test_attention_causal(), mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), model_layer_0_prefill(), model_layer_10_prefill(), model_layer_11_prefill(), model_layer_12_prefill(), model_layer_13_prefill(), model_layer_14_prefill(), model_layer_15_prefill(), model_layer_16_prefill(), model_layer_17_prefill(), model_layer_18_prefill(), model_layer_19_prefill(), model_layer_1_prefill(), model_layer_20_prefill(), model_layer_21_prefill(), model_layer_22_prefill(), model_layer_23_prefill(), model_layer_2_prefill(), model_layer_3_prefill(), model_layer_4_prefill(), model_layer_5_prefill(), model_layer_6_prefill(), model_layer_7_prefill(), model_layer_8_prefill(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().

◆ attention_forward_causal_head_major_gqa_flash_strided_sliding()

void attention_forward_causal_head_major_gqa_flash_strided_sliding ( const float *  q,
const float *  k,
const float *  v,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  kv_stride_tokens,
int  sliding_window 
)

Flash attention forward with sliding window (prefill)

Test:
test_attention.py::TestAttentionForward::test_sliding_window_prefill

Sliding-window attention for prefill: each token attends to at most the last W tokens (W = sliding_window), including itself. When sliding_window <= 0, this behaves like regular causal attention.

After changes: make test

Definition at line 1316 of file attention_kernels.c.

1328 {
1329  if (!q || !k || !v || !output) {
1330  return;
1331  }
1332  if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
1333  return;
1334  }
1335  if (kv_stride_tokens < num_tokens) {
1336  return;
1337  }
1338 
1339  const float scale = 1.0f / sqrtf((float)head_dim);
1340  const int T = num_tokens;
1341  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
1342 
1343 #if defined(__AVX512F__)
1344  #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx512
1345 #elif defined(__AVX2__)
1346  #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx2
1347 #elif defined(__AVX__)
1348  #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx
1349 #else
1350  #define SLIDING_FLASH_IMPL attention_flash_query_sliding
1351 #endif
1352 
1353  for (int h = 0; h < num_heads; ++h) {
1354  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1355  const float *k_head = k + (size_t)kv_head * kv_head_stride;
1356  const float *v_head = v + (size_t)kv_head * kv_head_stride;
1357 
1358  for (int i = 0; i < T; ++i) {
1359  const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
1360  float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
1361  SLIDING_FLASH_IMPL(q_vec, k_head, v_head,
1362  /*query_pos=*/i,
1363  /*kv_tokens=*/T,
1364  head_dim, aligned_head_dim,
1365  scale, out_vec,
1366  sliding_window);
1367  }
1368  }
1369 
1370 #undef SLIDING_FLASH_IMPL
1371 }
SLIDING_FLASH_IMPL is a function-local dispatch macro: it is #defined inside the function body above to the best compile-time SIMD variant (AVX-512, AVX2, AVX, or scalar fallback) and #undef'd before the function returns.

References qkv_index(), and SLIDING_FLASH_IMPL.

Referenced by ck_test_attention_sliding_window().

◆ attention_forward_decode_head_major_gqa_flash()

void attention_forward_decode_head_major_gqa_flash ( const float *  q_token,
const float *  k_cache,
const float *  v_cache,
float *  out_token,
int  num_heads,
int  num_kv_heads,
int  kv_tokens,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim 
)

Flash attention decode (single token attends to KV cache)

Test:

test_flash_attention.py::TestFlashAttention::test_flash_decode

test_kv_cache_attention.py::TestKVCacheAttention::test_flash_decode

test_fused_attention_decode.py::TestFusedAttentionDecode::test_flash_decode

test_attention.py::TestAttentionForward::test_flash_decode

Single query token attends to kv_tokens in KV cache. Uses true flash attention from attention_flash_true.c.

After changes: make test && make llamacpp-parity-full

Definition at line 1467 of file attention_kernels.c.

1477 {
1478  if (!q_token || !k_cache || !v_cache || !out_token) {
1479  return;
1480  }
1481  if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
1482  return;
1483  }
1484  if (kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
1485  return;
1486  }
1487 
1488  const float scale = 1.0f / sqrtf((float)head_dim);
1489  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
1490 
1491  for (int h = 0; h < num_heads; ++h) {
1492  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1493  const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
1494  const float *k_head = k_cache + (size_t)kv_head * head_stride;
1495  const float *v_head = v_cache + (size_t)kv_head * head_stride;
1496  float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;
1497 
1498  attention_flash_decode(out_head,
1499  q_head,
1500  k_head,
1501  v_head,
1502  1,
1503  kv_tokens,
1504  1,
1505  aligned_head_dim,
1506  scale);
1507  }
1508 }
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.

References attention_flash_decode().

Referenced by mega_fused_attention_decode_q5_0(), mega_fused_attention_decode_q5_0_parallel_simd(), model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().

◆ attention_forward_decode_head_major_gqa_flash_sliding()

void attention_forward_decode_head_major_gqa_flash_sliding ( const float *  q_token,
const float *  k_cache,
const float *  v_cache,
float *  out_token,
int  num_heads,
int  num_kv_heads,
int  kv_tokens,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim,
int  sliding_window 
)

Flash attention decode with sliding window

Test:
test_attention.py::TestAttentionForward::test_sliding_window_decode

Single query token attends to the last W tokens in the KV cache. For decode: effective_kv_tokens = min(kv_tokens, sliding_window)

After changes: make test

Definition at line 1382 of file attention_kernels.c.

1394 {
1395  if (!q_token || !k_cache || !v_cache || !out_token) {
1396  return;
1397  }
1398  if (num_heads <= 0 || num_kv_heads <= 0 || cache_capacity <= 0) {
1399  return;
1400  }
1401  if (kv_tokens <= 0 || kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
1402  return;
1403  }
1404 
1405  const float scale = 1.0f / sqrtf((float)head_dim);
1406  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
1407 
1408  // Compute effective KV tokens based on sliding window
1409  int effective_kv_tokens = kv_tokens;
1410  if (sliding_window > 0 && sliding_window < kv_tokens) {
1411  effective_kv_tokens = sliding_window;
1412  }
1413 
1414  // Guard against empty window (shouldn't happen with kv_tokens >= 1)
1415  if (effective_kv_tokens <= 0) {
1416  return;
1417  }
1418 
1419  // Offset to start reading from the last effective_kv_tokens entries
1420  int kv_start_offset = kv_tokens - effective_kv_tokens;
1421 
1422 #if defined(__AVX512F__)
1423  #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx512
1424 #elif defined(__AVX2__)
1425  #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx2
1426 #elif defined(__AVX__)
1427  #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx
1428 #else
1429  #define SLIDING_DECODE_IMPL attention_flash_query_sliding
1430 #endif
1431 
1432  for (int h = 0; h < num_heads; ++h) {
1433  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1434  const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
1435  // Offset K/V pointer to start from the first token in the sliding window
1436  const float *k_head = k_cache + (size_t)kv_head * head_stride
1437  + (size_t)kv_start_offset * (size_t)aligned_head_dim;
1438  const float *v_head = v_cache + (size_t)kv_head * head_stride
1439  + (size_t)kv_start_offset * (size_t)aligned_head_dim;
1440  float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;
1441 
1442  // Use query_pos relative to the windowed KV (last token = effective_kv_tokens - 1)
1443  // sliding_window = 0 since we've already windowed via K/V pointer offset
1444  SLIDING_DECODE_IMPL(q_head, k_head, v_head,
1445  /*query_pos=*/effective_kv_tokens - 1,
1446  /*kv_tokens=*/effective_kv_tokens,
1447  head_dim, aligned_head_dim,
1448  scale, out_head,
1449  /*sliding_window=*/0);
1450  }
1451 
1452 #undef SLIDING_DECODE_IMPL
1453 }
SLIDING_DECODE_IMPL is a function-local dispatch macro: it is #defined inside the function body above to the best compile-time SIMD variant (AVX-512, AVX2, AVX, or scalar fallback) and #undef'd before the function returns.

References SLIDING_DECODE_IMPL.

Referenced by ck_test_attention_decode_sliding().

◆ attention_forward_decode_head_major_gqa_regular()

void attention_forward_decode_head_major_gqa_regular ( const float *  q_token,
const float *  k_cache,
const float *  v_cache,
float *  out_token,
int  num_heads,
int  num_kv_heads,
int  kv_tokens,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim 
)

WARNING: This is NOT true flash attention!

Despite dispatching through helpers named "flash" (attention_flash_query_causal and its SIMD variants), this function implements regular attention decode with O(n) complexity over the KV length. It is kept for reference and as a fallback.

TRUE flash attention is implemented in attention_flash_true.c

Test:

test_kv_cache_attention.py::TestKVCacheAttention::test_regular_decode

test_attention.py::TestAttentionForward::test_regular_decode

Regular attention decode (score-matrix version) for fallback.

After changes: make test

Definition at line 1524 of file attention_kernels.c.

1534 {
1535  if (!q_token || !k_cache || !v_cache || !out_token) {
1536  return;
1537  }
1538  if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
1539  return;
1540  }
1541  if (kv_tokens > cache_capacity) {
1542  return;
1543  }
1544 
1545  const float scale = 1.0f / sqrtf((float)head_dim);
1546  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
1547 
1548  // Select SIMD implementation based on compile-time CPU features
1549 #if defined(__AVX512F__)
1550  #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx512
1551 #elif defined(__AVX2__)
1552  #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx2
1553 #elif defined(__AVX__)
1554  #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx
1555 #else
1556  #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal
1557 #endif
1558 
1559 #pragma omp parallel for schedule(static) if(num_heads > 1)
1560  for (int h = 0; h < num_heads; ++h) {
1561  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1562  const float *q_vec = q_token + (size_t)h * (size_t)aligned_head_dim;
1563  const float *k_head = k_cache + (size_t)kv_head * head_stride;
1564  const float *v_head = v_cache + (size_t)kv_head * head_stride;
1565  float *out_vec = out_token + (size_t)h * (size_t)aligned_head_dim;
1566 
1567  FLASH_QUERY_IMPL_DECODE(q_vec, k_head, v_head,
1568  kv_tokens, head_dim, aligned_head_dim,
1569  scale, out_vec);
1570  }
1571 
1572 #undef FLASH_QUERY_IMPL_DECODE
1573 }
FLASH_QUERY_IMPL_DECODE is a function-local dispatch macro: it is #defined inside the function body above to the best compile-time SIMD variant (AVX-512, AVX2, AVX, or scalar fallback) and #undef'd before the function returns.

References FLASH_QUERY_IMPL_DECODE.

Referenced by ck_attention_flash_decode_wrapper(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().

◆ convert_bf16_tensor_to_buf()

static void convert_bf16_tensor_to_buf ( const uint16_t *  src,
float *  dst,
size_t  count 
)
static

Definition at line 28 of file attention_kernels.c.

29 {
30  if (!dst || !src) return;
31  bf16_tensor_to_float(src, dst, count);
32 }
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
Definition: bf16_utils.h:250

References bf16_tensor_to_float().

Referenced by attention_backward_causal_head_major_gqa_bf16(), and attention_forward_causal_head_major_gqa_bf16().

◆ qkv_index()

static size_t qkv_index ( int  h,
int  t,
int  d,
int  num_tokens,
int  aligned_head_dim 
)
inlinestatic

◆ score_index()

static size_t score_index ( int  h,
int  i,
int  j,
int  aligned_context_window 
)
inlinestatic

Definition at line 48 of file attention_kernels.c.

52 {
53  return ((size_t)h * (size_t)aligned_context_window * (size_t)aligned_context_window)
54  + (size_t)i * (size_t)aligned_context_window
55  + (size_t)j;
56 }

Referenced by attention_backward_causal_head_major_gqa(), attention_forward_causal_head_major(), attention_forward_causal_head_major_exact(), attention_forward_causal_head_major_gqa(), and attention_forward_causal_head_major_gqa_exact().