← Back to C-Kernel-Engine Docs Doxygen Source Documentation
attention_flash_true.c File Reference

Flash-style attention (online softmax, causal, streaming) More...

#include <math.h>
#include <stddef.h>
#include <stdint.h>

Go to the source code of this file.

Macros

#define CK_FLASH_ATTN_FAST_EXP   0
 
#define CK_FLASH_ATTN_TILE_K   32
 

Functions

void attention_flash_cleanup (void)
 Clean up flash attention resources. More...
 
void attention_flash_decode (float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
 Main flash attention function with SIMD dispatch. More...
 
static void attention_flash_decode_scalar (float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
 Scalar flash-style attention (online softmax) More...
 
void attention_flash_init (int max_context, int max_heads, int max_head_dim)
 Initialize flash attention buffers. More...
 
static float ck_expf (float x)
 
static float ck_fast_expf (float x)
 
int ck_flash_attn_choose_tile_k (int D_h)
 
int ck_flash_attn_fast_exp_kind (void)
 
static int ck_flash_attn_tile_k (int D_h)
 
static int max_k_for_query (int t_q, int T_q, int T_k)
 

Detailed Description

Flash-style attention (online softmax, causal, streaming)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Layout: Q/K/V/Out: [T, H, D_h] contiguous

Causal alignment: Queries are assumed to correspond to the last T_q positions in the KV cache. This makes T_q == T_k behave like standard causal prefill, and T_q == 1 behave like decode over a full KV cache.

Notes:

  • This is O(T_k) per query head; it avoids materializing the score matrix.
  • SIMD paths are provided for AVX-512 and AVX.

Definition in file attention_flash_true.c.

Macro Definition Documentation

◆ CK_FLASH_ATTN_FAST_EXP

#define CK_FLASH_ATTN_FAST_EXP   0

Definition at line 44 of file attention_flash_true.c.

◆ CK_FLASH_ATTN_TILE_K

#define CK_FLASH_ATTN_TILE_K   32

Definition at line 40 of file attention_flash_true.c.

Function Documentation

◆ attention_flash_cleanup()

/**
 * @brief Clean up flash attention resources.
 *
 * Currently a no-op: this module owns no heap buffers (kernel rule:
 * memory comes from the caller / bump allocator). The hook exists so
 * callers have a stable teardown point if scratch buffers are ever
 * pre-allocated in attention_flash_init().
 */
void attention_flash_cleanup(void)
{
    /* Nothing to release yet. */
}

◆ attention_flash_decode()

/**
 * @brief Main flash attention entry point with SIMD dispatch.
 *
 * Validates arguments, then routes to the widest SIMD implementation the
 * compile target supports (AVX-512 > AVX > scalar). Invalid input is a
 * silent no-op: @p out is left untouched.
 *
 * @param out   Output [T_q, H, D_h]
 * @param q     Query  [T_q, H, D_h]
 * @param k     Key    [T_k, H, D_h]
 * @param v     Value  [T_k, H, D_h]
 * @param T_q   Number of query tokens (1 for decode)
 * @param T_k   Number of key/value tokens (context length)
 * @param H     Number of heads
 * @param D_h   Head dimension
 * @param scale 1/sqrt(D_h)
 */
void attention_flash_decode(float *out, const float *q, const float *k,
                            const float *v, int T_q, int T_k, int H, int D_h,
                            float scale)
{
    /* Guard clauses: reject null buffers and non-positive dimensions. */
    if (out == NULL || q == NULL || k == NULL || v == NULL) {
        return;
    }
    if (T_q < 1 || T_k < 1 || H < 1 || D_h < 1) {
        return;
    }

    /* Compile-time dispatch on the available CPU feature set. */
#if defined(__AVX512F__)
    attention_flash_decode_avx512(out, q, k, v, T_q, T_k, H, D_h, scale);
#elif defined(__AVX__) && !defined(__AVX512F__)
    attention_flash_decode_avx(out, q, k, v, T_q, T_k, H, D_h, scale);
#else
    attention_flash_decode_scalar(out, q, k, v, T_q, T_k, H, D_h, scale);
#endif
}
static void attention_flash_decode_scalar(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Scalar flash-style attention (online softmax)

References attention_flash_decode_scalar().

Referenced by attention_forward_decode_head_major_gqa_flash(), ck_attention_flash_decode_wrapper(), mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().

◆ attention_flash_decode_scalar()

static void attention_flash_decode_scalar ( float *  out,
const float *  q,
const float *  k,
const float *  v,
int  T_q,
int  T_k,
int  H,
int  D_h,
float  scale 
)
static

Scalar flash-style attention (online softmax)

Definition at line 142 of file attention_flash_true.c.

152 {
153  const int total = T_q * H;
154  const size_t stride = (size_t)H * (size_t)D_h;
155  const int tile_k = ck_flash_attn_tile_k(D_h);
156 
157  for (int idx = 0; idx < total; ++idx) {
158  const int t_q = idx / H;
159  const int h = idx - t_q * H;
160  const int max_k = max_k_for_query(t_q, T_q, T_k);
161 
162  const float *q_head = q + (size_t)t_q * stride + (size_t)h * (size_t)D_h;
163  float *out_head = out + (size_t)t_q * stride + (size_t)h * (size_t)D_h;
164  const float *k_base = k + (size_t)h * (size_t)D_h;
165  const float *v_base = v + (size_t)h * (size_t)D_h;
166 
167  for (int d = 0; d < D_h; ++d) {
168  out_head[d] = 0.0f;
169  }
170 
171  float m = -INFINITY;
172  float s = 0.0f;
173 
174  float scores[CK_FLASH_ATTN_TILE_K];
175 
176  for (int t_k0 = 0; t_k0 <= max_k; t_k0 += tile_k) {
177  int blk_len = max_k - t_k0 + 1;
178  if (blk_len > tile_k) {
179  blk_len = tile_k;
180  }
181 
182  float m_block = -INFINITY;
183  for (int bi = 0; bi < blk_len; ++bi) {
184  const int t_k = t_k0 + bi;
185  const float *k_head = k_base + (size_t)t_k * stride;
186 
187  float dot = 0.0f;
188  for (int d = 0; d < D_h; ++d) {
189  dot += q_head[d] * k_head[d];
190  }
191 
192  float score = dot * scale;
193  scores[bi] = score;
194  if (score > m_block) {
195  m_block = score;
196  }
197  }
198 
199  if (m_block > m) {
200  float scale_old = (m == -INFINITY) ? 0.0f : ck_expf(m - m_block);
201  s *= scale_old;
202  for (int d = 0; d < D_h; ++d) {
203  out_head[d] *= scale_old;
204  }
205  m = m_block;
206  }
207 
208  for (int bi = 0; bi < blk_len; ++bi) {
209  const int t_k = t_k0 + bi;
210  const float *v_head = v_base + (size_t)t_k * stride;
211  float w = ck_expf(scores[bi] - m);
212  s += w;
213  for (int d = 0; d < D_h; ++d) {
214  out_head[d] += w * v_head[d];
215  }
216  }
217  }
218 
219  if (s > 0.0f) {
220  float inv_s = 1.0f / s;
221  for (int d = 0; d < D_h; ++d) {
222  out_head[d] *= inv_s;
223  }
224  } else {
225  for (int d = 0; d < D_h; ++d) {
226  out_head[d] = 0.0f;
227  }
228  }
229  }
230 }
static int max_k_for_query(int t_q, int T_q, int T_k)
static int ck_flash_attn_tile_k(int D_h)
static float ck_expf(float x)
#define CK_FLASH_ATTN_TILE_K
int32_t float * score
Definition: tokenizer.h:327

References ck_expf(), CK_FLASH_ATTN_TILE_K, ck_flash_attn_tile_k(), max_k_for_query(), and score.

Referenced by attention_flash_decode().

◆ attention_flash_init()

/**
 * @brief Initialize flash attention buffers.
 *
 * Currently a no-op placeholder: per kernel rules memory is provided by
 * the caller, so there is nothing to pre-allocate here yet. Parameters
 * describe the worst-case workload a future implementation would size
 * scratch buffers for.
 *
 * @param max_context  Largest KV context length expected
 * @param max_heads    Largest head count expected
 * @param max_head_dim Largest head dimension expected
 */
void attention_flash_init(int max_context, int max_heads, int max_head_dim)
{
    /* Intentionally unused until scratch pre-allocation lands. */
    (void)max_context;
    (void)max_heads;
    (void)max_head_dim;
}

◆ ck_expf()

static float ck_expf ( float  x)
inlinestatic

Definition at line 80 of file attention_flash_true.c.

80  {
81 #if CK_FLASH_ATTN_FAST_EXP
82  return ck_fast_expf(x);
83 #else
84  return expf(x);
85 #endif
86 }
static float ck_fast_expf(float x)

References ck_fast_expf().

Referenced by attention_flash_decode_scalar().

◆ ck_fast_expf()

static float ck_fast_expf ( float  x)
inlinestatic

Definition at line 47 of file attention_flash_true.c.

47  {
48  const float max_val = 88.0f;
49  const float min_val = -88.0f;
50  if (x > max_val) {
51  x = max_val;
52  } else if (x < min_val) {
53  x = min_val;
54  }
55 
56  const float log2e = 1.4426950408889634f;
57  float z = x * log2e;
58  float zf = nearbyintf(z);
59  float f = z - zf;
60 
61  const float c0 = 1.0f;
62  const float c1 = 0.6931471805599453f;
63  const float c2 = 0.2402265069591007f;
64  const float c3 = 0.05550410866482158f;
65  const float c4 = 0.009618129107628478f;
66 
67  float poly = ((c4 * f + c3) * f + c2) * f + c1;
68  poly = poly * f + c0;
69 
70  int32_t zi = (int32_t)zf + 127;
71  uint32_t bits = (uint32_t)zi << 23;
72  union {
73  uint32_t i;
74  float f;
75  } u;
76  u.i = bits;
77  return poly * u.f;
78 }

Referenced by ck_expf().

◆ ck_flash_attn_choose_tile_k()

/**
 * @brief Public probe for the KV tile size the kernels will use.
 *
 * Thin wrapper exposing the internal heuristic so tests and the
 * orchestrator can see the tile choice for a given head dimension.
 *
 * @param D_h Head dimension
 * @return Tile length in KV positions, in [1, CK_FLASH_ATTN_TILE_K]
 */
int ck_flash_attn_choose_tile_k(int D_h)
{
    const int tile = ck_flash_attn_tile_k(D_h);
    return tile;
}

References ck_flash_attn_tile_k().

◆ ck_flash_attn_fast_exp_kind()

/**
 * @brief Report which fast-exp path this build uses.
 *
 * @return 512 when fast-exp is enabled on an AVX-512 build, 256 on an
 *         AVX build, and 0 when fast-exp is disabled or no SIMD fast
 *         path exists (libm expf is used instead).
 */
int ck_flash_attn_fast_exp_kind(void)
{
#if CK_FLASH_ATTN_FAST_EXP && defined(__AVX512F__)
    return 512; /* fast-exp, AVX-512 lanes */
#elif CK_FLASH_ATTN_FAST_EXP && defined(__AVX__)
    return 256; /* fast-exp, AVX lanes */
#else
    return 0;   /* libm expf */
#endif
}

◆ ck_flash_attn_tile_k()

static int ck_flash_attn_tile_k ( int  D_h)
inlinestatic

Definition at line 88 of file attention_flash_true.c.

88  {
89  int tile = CK_FLASH_ATTN_TILE_K;
90  if (D_h > 128) {
91  tile = CK_FLASH_ATTN_TILE_K / 4;
92  } else if (D_h > 64) {
93  tile = CK_FLASH_ATTN_TILE_K / 2;
94  }
95 
96  if (CK_FLASH_ATTN_TILE_K >= 8 && tile < 8) {
97  tile = 8;
98  }
99  if (tile > CK_FLASH_ATTN_TILE_K) {
100  tile = CK_FLASH_ATTN_TILE_K;
101  }
102  if (tile < 1) {
103  tile = 1;
104  }
105  return tile;
106 }

References CK_FLASH_ATTN_TILE_K.

Referenced by attention_flash_decode_scalar(), and ck_flash_attn_choose_tile_k().

◆ max_k_for_query()

/**
 * @brief Highest KV index query @p t_q may attend to (causal mask).
 *
 * Queries occupy the LAST T_q positions of the KV cache, so query t_q
 * sits at absolute position (T_k - T_q) + t_q when the cache is longer
 * than the query batch. The result is clamped to the valid KV range.
 *
 * @param t_q Query index in [0, T_q)
 * @param T_q Number of query tokens
 * @param T_k Number of KV tokens
 * @return Inclusive upper bound on attendable KV indices, < T_k
 */
static inline int max_k_for_query(int t_q, int T_q, int T_k)
{
    const int offset = (T_k > T_q) ? (T_k - T_q) : 0;
    const int candidate = offset + t_q;
    return (candidate < T_k) ? candidate : (T_k - 1);
}

Referenced by attention_flash_decode_scalar().