← Back to C-Kernel-Engine Docs Doxygen Source Documentation
attention_decode_fused.c File Reference

Fused attention decode kernel (legacy v6/v6.5) More...

#include "ckernel_orchestration.h"
#include "ckernel_engine.h"
#include <stddef.h>

Go to the source code of this file.

Functions

void ck_attention_project_head_major_decode_token (const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
static void ck_attention_project_head_major_decode_token_residual (const float *attn_token, const float *wo, const float *bo, const float *residual_in, float *proj_out, float *residual_out, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
static float ck_dot_f32 (const float *a, const float *b, int len)
 
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn (const CKLayerForwardParams *p, int token_index, int cache_capacity)
 
static void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl (const CKLayerForwardParams *p, int token_index, int cache_capacity, int fuse_mlp)
 
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp (const CKLayerForwardParams *p, int token_index, int cache_capacity)
 
void ck_qkv_project_head_major_token (const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 

Detailed Description

Fused attention decode kernel (legacy v6/v6.5)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

LEGACY: This file is from v6/v6.5 and kept for backward compatibility.

Definition in file attention_decode_fused.c.

Function Documentation

◆ ck_attention_project_head_major_decode_token()

/*
 * Output (Wo) projection for a single decode token, head-major layout.
 *
 * Computes, for each output feature j in [0, embed_dim):
 *   out_token[j] = bo[j] + sum over heads h of
 *                  dot(attn_token[h * aligned_head_dim ..], wo_row(h, j))
 * where wo_row(h, j) is the j-th row of head h's weight slab, and slabs are
 * laid out head-major with stride aligned_embed_dim * aligned_head_dim.
 * The alignment padding tail [embed_dim, aligned_embed_dim) is zeroed so
 * downstream kernels can safely read the full aligned row.
 *
 * attn_token        per-head attention output for one token (head-major,
 *                   num_heads * aligned_head_dim floats)
 * wo                output projection weights (head-major, padded)
 * bo                optional bias (NULL => zero bias)
 * out_token         destination row, aligned_embed_dim floats
 *
 * FIX: removed the "#pragma omp parallel for" that was here — kernel rule 2
 * for this engine forbids OpenMP inside kernels (parallelization belongs to
 * the orchestrator/codegen layer).
 */
void ck_attention_project_head_major_decode_token(const float *attn_token,
                                                  const float *wo,
                                                  const float *bo,
                                                  float *out_token,
                                                  int embed_dim,
                                                  int aligned_embed_dim,
                                                  int num_heads,
                                                  int aligned_head_dim)
{
    const size_t head_in_stride = (size_t)aligned_head_dim;
    const size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;

    for (int j = 0; j < embed_dim; ++j) {
        float sum = bo ? bo[j] : 0.0f; /* bias is optional */
        for (int h = 0; h < num_heads; ++h) {
            const float *head_in = attn_token + (size_t)h * head_in_stride;
            const float *wo_row = wo + (size_t)h * head_weight_stride
                                     + (size_t)j * (size_t)aligned_head_dim;
            sum += ck_dot_f32(head_in, wo_row, aligned_head_dim);
        }
        out_token[j] = sum;
    }

    /* Zero the alignment padding tail. */
    for (int j = embed_dim; j < aligned_embed_dim; ++j) {
        out_token[j] = 0.0f;
    }
}
static float ck_dot_f32(const float *a, const float *b, int len)

References ck_dot_f32().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_quant().

◆ ck_attention_project_head_major_decode_token_residual()

/*
 * Output (Wo) projection for one decode token, fused with the residual add.
 *
 * Identical math to ck_attention_project_head_major_decode_token, but also
 * writes residual_out[j] = projection[j] + residual_in[j] in the same pass,
 * avoiding a second sweep over the row. proj_out is optional: when non-NULL
 * it receives the raw projection as well. The padding tail
 * [embed_dim, aligned_embed_dim) of both outputs is zeroed.
 *
 * FIX: removed the "#pragma omp parallel for" that was here — kernel rule 2
 * for this engine forbids OpenMP inside kernels (parallelization belongs to
 * the orchestrator/codegen layer).
 */
static void ck_attention_project_head_major_decode_token_residual(const float *attn_token,
                                                                  const float *wo,
                                                                  const float *bo,
                                                                  const float *residual_in,
                                                                  float *proj_out,
                                                                  float *residual_out,
                                                                  int embed_dim,
                                                                  int aligned_embed_dim,
                                                                  int num_heads,
                                                                  int aligned_head_dim)
{
    const size_t head_in_stride = (size_t)aligned_head_dim;
    const size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;

    for (int j = 0; j < embed_dim; ++j) {
        float sum = bo ? bo[j] : 0.0f; /* bias is optional */
        for (int h = 0; h < num_heads; ++h) {
            const float *head_in = attn_token + (size_t)h * head_in_stride;
            const float *wo_row = wo + (size_t)h * head_weight_stride
                                     + (size_t)j * (size_t)aligned_head_dim;
            sum += ck_dot_f32(head_in, wo_row, aligned_head_dim);
        }
        if (proj_out) {
            proj_out[j] = sum;
        }
        residual_out[j] = sum + residual_in[j];
    }

    /* Zero the alignment padding tail of both outputs. */
    for (int j = embed_dim; j < aligned_embed_dim; ++j) {
        if (proj_out) {
            proj_out[j] = 0.0f;
        }
        residual_out[j] = 0.0f;
    }
}

References ck_dot_f32().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().

◆ ck_dot_f32()

/*
 * Dense f32 dot product over len elements.
 *
 * Dispatches at compile time: AVX-512 FMA path (16-wide), AVX path
 * (8-wide, via ck_hsum256_ps for the horizontal reduce), or a plain
 * scalar loop. Each SIMD path finishes the remainder scalar-wise, so
 * the accumulation order — and therefore the float result — matches
 * the original implementation exactly.
 */
static inline float ck_dot_f32(const float *a, const float *b, int len)
{
#if defined(__AVX512F__)
    __m512 vacc = _mm512_setzero_ps();
    int idx = 0;
    while (idx + 16 <= len) {
        __m512 lhs = _mm512_loadu_ps(a + idx);
        __m512 rhs = _mm512_loadu_ps(b + idx);
        vacc = _mm512_fmadd_ps(lhs, rhs, vacc);
        idx += 16;
    }
    float total = _mm512_reduce_add_ps(vacc);
    while (idx < len) {
        total += a[idx] * b[idx];
        ++idx;
    }
    return total;
#elif defined(__AVX__)
    __m256 vacc = _mm256_setzero_ps();
    int idx = 0;
    while (idx + 8 <= len) {
        __m256 lhs = _mm256_loadu_ps(a + idx);
        __m256 rhs = _mm256_loadu_ps(b + idx);
        vacc = _mm256_add_ps(vacc, _mm256_mul_ps(lhs, rhs));
        idx += 8;
    }
    float total = ck_hsum256_ps(vacc);
    while (idx < len) {
        total += a[idx] * b[idx];
        ++idx;
    }
    return total;
#else
    float total = 0.0f;
    for (int idx = 0; idx < len; ++idx) {
        total += a[idx] * b[idx];
    }
    return total;
#endif
}

Referenced by ck_attention_project_head_major_decode_token(), and ck_attention_project_head_major_decode_token_residual().

◆ ck_layer_forward_rmsnorm_swiglu_decode_fused_attn()

/*
 * Public entry point: fused-attention decode layer forward, MLP not fused.
 * SwiGLU activations are staged through p->fc1_out / p->swiglu_out.
 * Thin wrapper that delegates to the shared implementation with
 * fuse_mlp disabled.
 */
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(const CKLayerForwardParams *p,
                                                       int token_index,
                                                       int cache_capacity)
{
    ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(p,
                                                           token_index,
                                                           cache_capacity,
                                                           /*fuse_mlp=*/0);
}

References ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().

◆ ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl()

static void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl ( const CKLayerForwardParams p,
int  token_index,
int  cache_capacity,
int  fuse_mlp 
)
static

Definition at line 181 of file attention_decode_fused.c.

185 {
186  if (!p) {
187  return;
188  }
189  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->wq || !p->wk || !p->wv || !p->wo ||
190  !p->w1 || !p->w2 || !p->k || !p->v || !p->residual1 || !p->ln2_out || !p->swiglu_out ||
191  !p->mlp_out || !p->output) {
192  return;
193  }
194  if (!fuse_mlp && !p->fc1_out) {
195  return;
196  }
197  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
198  return;
199  }
200  if (p->num_heads <= 0 || p->num_kv_heads <= 0 || p->aligned_head_dim <= 0) {
201  return;
202  }
203 
204  const int D = p->embed_dim;
205  const int aligned_D = p->aligned_embed_dim;
206  const int H = p->num_heads;
207  const int H_kv = p->num_kv_heads;
208  const int hd = p->head_dim;
209  const int ad = p->aligned_head_dim;
210  const int aligned_intermediate = p->aligned_intermediate_dim;
211 
212  /* Decode buffers are single-token; token_index only applies to KV cache. */
213  const size_t token_slot = 0;
214  const float *input_row = p->input + token_slot * (size_t)aligned_D;
215  float *proj_row = NULL;
216  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
217  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
218  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
219  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
220  float *out_row = p->output + token_slot * (size_t)aligned_D;
221 
222  float ln1_rstd_tmp = 0.0f;
223  float ln2_rstd_tmp = 0.0f;
224  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
225 
226  float ln1_row[aligned_D];
227 
228  rmsnorm_forward(input_row,
229  p->ln1_gamma,
230  ln1_row,
231  &ln1_rstd_tmp,
232  /*tokens=*/1,
233  D,
234  aligned_D,
235  p->eps);
236 
237  size_t q_elems = (size_t)H * (size_t)ad;
238  size_t kv_elems = (size_t)H_kv * (size_t)ad;
239  float q_token[q_elems];
240  float k_token[kv_elems];
241  float v_token[kv_elems];
242  float attn_token[q_elems];
243 
245  p->wq, p->bq,
246  p->wk, p->bk,
247  p->wv, p->bv,
248  q_token, k_token, v_token,
249  aligned_D,
250  H,
251  H_kv,
252  ad);
253 
254  if (p->rope_cos && p->rope_sin) {
255  rope_forward_qk(q_token,
256  k_token,
257  p->rope_cos,
258  p->rope_sin,
259  H,
260  H_kv,
261  /*num_tokens=*/1,
262  hd,
263  ad,
264  p->rope_pos_offset);
265  }
266 
268  v_token,
269  p->k,
270  p->v,
271  H_kv,
272  token_index,
273  cache_capacity,
274  hd,
275  ad);
276 
278  p->k,
279  p->v,
280  attn_token,
281  H,
282  H_kv,
283  /*kv_tokens=*/token_index + 1,
284  cache_capacity,
285  hd,
286  ad);
287 
289  p->wo,
290  p->bo,
291  input_row,
292  proj_row,
293  residual_row,
294  D,
295  aligned_D,
296  H,
297  ad);
298 
299  rmsnorm_forward(residual_row,
300  p->ln2_gamma,
301  ln2_row,
302  ln2_rstd,
303  /*tokens=*/1,
304  D,
305  aligned_D,
306  p->eps);
307 
308  if (fuse_mlp) {
309  // Fully fused MLP avoids writing SwiGLU activations to DRAM (aligned dims
310  // match padded weight layout).
312  p->w1,
313  p->b1,
314  p->w2,
315  p->b2,
316  mlp_row,
317  aligned_D,
318  aligned_intermediate);
319  } else {
320  int up_dim = 2 * aligned_intermediate;
321  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
322 
323  ck_mlp_swiglu_forward(ln2_row,
324  p->w1,
325  p->b1,
326  p->w2,
327  p->b2,
328  fc1_row,
329  swiglu_row,
330  mlp_row,
331  /*tokens=*/1,
332  aligned_D,
333  aligned_intermediate);
334  }
335 
336  ck_residual_add_token_major(residual_row,
337  mlp_row,
338  out_row,
339  /*tokens=*/1,
340  aligned_D);
341 }
static void ck_attention_project_head_major_decode_token_residual(const float *attn_token, const float *wo, const float *bo, const float *residual_in, float *proj_out, float *residual_out, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_qkv_project_head_major_token(const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING (from the attention_forward_decode_head_major_gqa_regular documentation): this kernel is a regular decode attention implementation, NOT true flash attention.
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:448
void ck_mlp_swiglu_forward(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)

References CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, attention_forward_decode_head_major_gqa_regular(), CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_project_head_major_decode_token_residual(), ck_mlp_swiglu_forward(), ck_mlp_swiglu_forward_fully_fused_token(), ck_qkv_project_head_major_token(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, kv_cache_write_head_major(), CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp().

◆ ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp()

/*
 * Public entry point: fused-attention decode layer forward with the MLP
 * fully fused as well (SwiGLU activations never hit DRAM).
 * Thin wrapper that delegates to the shared implementation with
 * fuse_mlp enabled.
 */
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp(const CKLayerForwardParams *p,
                                                           int token_index,
                                                           int cache_capacity)
{
    ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(p,
                                                           token_index,
                                                           cache_capacity,
                                                           /*fuse_mlp=*/1);
}

References ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().

◆ ck_qkv_project_head_major_token()

/*
 * Project one normalized token row into Q, K and V, head-major layout.
 *
 * Q: all num_heads query heads in a single GEMM into q_token
 *    (num_heads * aligned_head_dim floats, head-major contiguous).
 * K/V: one GEMM per KV head (GQA: num_kv_heads may be < num_heads), each
 *    head's weight slab strided by aligned_head_dim * aligned_embed_dim.
 * Biases bq/bk/bv are optional (NULL => zero bias). Weights are padded to
 * the aligned dims, so every GEMM runs over aligned_embed_dim columns.
 * Returns silently if any required pointer is NULL.
 *
 * FIX: removed the "#pragma omp parallel for ... if(num_kv_heads > 1)" over
 * KV heads — kernel rule 2 for this engine forbids OpenMP inside kernels
 * (parallelization belongs to the orchestrator/codegen layer).
 */
void ck_qkv_project_head_major_token(const float *input_row,
                                     const float *wq, const float *bq,
                                     const float *wk, const float *bk,
                                     const float *wv, const float *bv,
                                     float *q_token, float *k_token, float *v_token,
                                     int aligned_embed_dim,
                                     int num_heads,
                                     int num_kv_heads,
                                     int aligned_head_dim)
{
    if (!input_row || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
        return;
    }

    /* Q projection: all query heads at once (weights are head-major contiguous). */
    const int q_out = num_heads * aligned_head_dim;
    gemm_blocked_serial(input_row, wq, bq, q_token,
                        /*tokens=*/1, q_out, aligned_embed_dim);

    /* K/V projections: one GEMM per KV head. */
    size_t head_weight_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    for (int h = 0; h < num_kv_heads; ++h) {
        const float *wk_h = wk + (size_t)h * head_weight_stride;
        const float *wv_h = wv + (size_t)h * head_weight_stride;
        const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k_token + (size_t)h * (size_t)aligned_head_dim;
        float *v_h = v_token + (size_t)h * (size_t)aligned_head_dim;

        gemm_blocked_serial(input_row, wk_h, bk_h, k_h,
                            /*tokens=*/1, aligned_head_dim, aligned_embed_dim);
        gemm_blocked_serial(input_row, wv_h, bv_h, v_h,
                            /*tokens=*/1, aligned_head_dim, aligned_embed_dim);
    }
}
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:661

References gemm_blocked_serial().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().