← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention_prefill.c File Reference

Mega-fused prefill attention kernel. More...

#include "ckernel_engine.h"
#include "ckernel_orchestration.h"
#include "ckernel_quant.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

Go to the source code of this file.

Functions

static size_t align_up_size (size_t value, size_t align)
 
static int ck_q8_0_outproj_enabled (void)
 
static void flatten_head_major (const float *attn_out, float *dst, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
void mega_fused_attention_prefill (float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
 Mega-fused attention for prefill mode (multiple tokens) More...
 
size_t mega_fused_attention_prefill_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 Get scratch buffer size for mega_fused_attention_prefill. More...
 
static void out_proj_head_major_q5_0_q8_0 (const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
static void quantize_attn_out_head_major_q8_0 (const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
 

Detailed Description

Mega-fused prefill attention kernel.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. NO memcpy for layout - use strided access, not copies
  4. API must define: inputs, outputs, workspace, and memory layouts
  5. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual. Writes K/V directly into the KV cache layout (stride = cache_capacity).

PERFORMANCE OPTIMIZATION:

Uses ck_gemm_nt_head_major_*() to read head-major attention output directly with strided access, eliminating the flatten_head_major() memcpy bottleneck (448 memcpy calls for 32 tokens × 14 heads)

TESTING

python3 scripts/bench_mega_fused_attention_prefill.py --q8-outproj --seq-lens 32,64 --iters 3 --warmup 1

Definition in file mega_fused_attention_prefill.c.

Function Documentation

◆ align_up_size()

static size_t align_up_size ( size_t  value,
size_t  align 
)
static

Definition at line 39 of file mega_fused_attention_prefill.c.

39  {
40  return (value + align - 1) & ~(align - 1);
41 }

Referenced by mega_fused_attention_prefill(), and mega_fused_attention_prefill_scratch_size().

◆ ck_q8_0_outproj_enabled()

static int ck_q8_0_outproj_enabled ( void  )
static

Definition at line 63 of file mega_fused_attention_prefill.c.

64 {
65  static int cached = -2;
66  if (cached != -2) {
67  return cached;
68  }
69 
70  const char *env = getenv("CK_Q8_0_OUTPROJ");
71  if (!env || !env[0]) {
72  cached = 0;
73  return cached;
74  }
75  if (env[0] == '0' || env[0] == 'n' || env[0] == 'N' ||
76  env[0] == 'f' || env[0] == 'F') {
77  cached = 0;
78  } else {
79  cached = 1;
80  }
81  return cached;
82 }

Referenced by mega_fused_attention_prefill().

◆ flatten_head_major()

static void flatten_head_major ( const float *  attn_out,
float *  dst,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 43 of file mega_fused_attention_prefill.c.

49 {
50  const size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
51  for (int t = 0; t < tokens; ++t) {
52  float *out_row = dst + (size_t)t * (size_t)aligned_embed_dim;
53  for (int h = 0; h < num_heads; ++h) {
54  const float *src = attn_out + (size_t)h * head_in_stride +
55  (size_t)t * (size_t)aligned_head_dim;
56  memcpy(out_row + (size_t)h * (size_t)aligned_head_dim,
57  src,
58  (size_t)aligned_head_dim * sizeof(float));
59  }
60  }
61 }

Referenced by mega_fused_attention_prefill().

◆ mega_fused_attention_prefill()

void mega_fused_attention_prefill ( float *  output,
const float *  input,
const float *  residual,
const float *  ln1_gamma,
const void *  wq,
const float *  bq,
CKDataType  wq_dt,
const void *  wk,
const float *  bk,
CKDataType  wk_dt,
const void *  wv,
const float *  bv,
CKDataType  wv_dt,
const void *  wo,
const float *  bo,
CKDataType  wo_dt,
float *  kv_cache_k,
float *  kv_cache_v,
const float *  rope_cos,
const float *  rope_sin,
int  start_pos,
int  tokens,
int  cache_capacity,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
float  eps,
void *  scratch 
)

Mega-fused attention for prefill mode (multiple tokens)

Parameters
output — Output [tokens, aligned_embed_dim] (includes residual add)
input — Input [tokens, aligned_embed_dim]
residual — Residual input [tokens, aligned_embed_dim] (or NULL)
ln1_gamma — RMSNorm gamma [embed_dim]
wq — Q weights [num_heads * aligned_head_dim * aligned_embed_dim]
bq — Q bias [num_heads * aligned_head_dim] (or NULL)
wk — K weights [num_kv_heads * aligned_head_dim * aligned_embed_dim]
bk — K bias [num_kv_heads * aligned_head_dim] (or NULL)
wv — V weights [num_kv_heads * aligned_head_dim * aligned_embed_dim]
bv — V bias [num_kv_heads * aligned_head_dim] (or NULL)
wo — Output projection weights [num_heads * aligned_embed_dim * aligned_head_dim]
bo — Output bias [aligned_embed_dim] (or NULL)
kv_cache_k — KV cache for K [num_kv_heads * cache_capacity * aligned_head_dim]
kv_cache_v — KV cache for V [num_kv_heads * cache_capacity * aligned_head_dim]
rope_cos — RoPE cos [max_seq, head_dim/2]
rope_sin — RoPE sin [max_seq, head_dim/2]
start_pos — Starting position in KV cache
tokens — Number of tokens to process
cache_capacity — KV cache capacity (stride in tokens)
embed_dim — Model hidden dimension (unpadded)
aligned_embed_dim — Aligned hidden dimension
num_heads — Number of attention heads
num_kv_heads — Number of KV heads
head_dim — Head dimension (unpadded)
aligned_head_dim — Aligned head dimension
eps — RMSNorm epsilon
scratch — Scratch workspace (size from mega_fused_attention_prefill_scratch_size())

Definition at line 160 of file mega_fused_attention_prefill.c.

184 {
185  if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
186  !kv_cache_k || !kv_cache_v || !scratch) {
187  return;
188  }
189  if (tokens <= 0 || cache_capacity <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
190  head_dim <= 0 || aligned_head_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0) {
191  return;
192  }
193  if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
194  return;
195  }
196  if (start_pos < 0 || start_pos + tokens > cache_capacity) {
197  return;
198  }
199 
200  const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
201  (size_t)aligned_head_dim * sizeof(float);
202  const size_t attn_bytes = q_bytes;
203  const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
204  const size_t qkv_scratch_bytes = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);
205 
206  uint8_t *scratch_bytes = (uint8_t *)scratch;
207  float *q = (float *)scratch_bytes;
208  scratch_bytes += align_up_size(q_bytes, 64);
209  float *attn_out = (float *)scratch_bytes;
210  scratch_bytes += align_up_size(attn_bytes, 64);
211  float *proj_scratch = (float *)scratch_bytes;
212  scratch_bytes += align_up_size(proj_bytes, 64);
213  void *qkv_scratch = (void *)scratch_bytes;
214  (void)qkv_scratch_bytes;
215 
216  float *k_ptr = kv_cache_k + (size_t)start_pos * (size_t)aligned_head_dim;
217  float *v_ptr = kv_cache_v + (size_t)start_pos * (size_t)aligned_head_dim;
218 
219  if (wq_dt == CK_DT_FP32 && wk_dt == CK_DT_FP32 && wv_dt == CK_DT_FP32) {
221  ln1_gamma,
222  (const float *)wq, bq,
223  (const float *)wk, bk,
224  (const float *)wv, bv,
225  q,
226  k_ptr,
227  v_ptr,
228  tokens,
229  embed_dim,
230  aligned_embed_dim,
231  num_heads,
232  num_kv_heads,
233  head_dim,
234  aligned_head_dim,
235  cache_capacity,
236  eps,
237  qkv_scratch);
238  } else {
240  ln1_gamma,
241  wq, bq, wq_dt,
242  wk, bk, wk_dt,
243  wv, bv, wv_dt,
244  q,
245  k_ptr,
246  v_ptr,
247  tokens,
248  embed_dim,
249  aligned_embed_dim,
250  num_heads,
251  num_kv_heads,
252  head_dim,
253  aligned_head_dim,
254  cache_capacity,
255  eps,
256  qkv_scratch);
257  }
258 
259  if (rope_cos && rope_sin) {
261  k_ptr,
262  rope_cos,
263  rope_sin,
264  num_heads,
265  num_kv_heads,
266  tokens,
267  head_dim,
268  aligned_head_dim,
269  start_pos,
270  tokens,
271  cache_capacity);
272  }
273 
274  if (start_pos == 0) {
276  k_ptr,
277  v_ptr,
278  attn_out,
279  num_heads,
280  num_kv_heads,
281  tokens,
282  head_dim,
283  aligned_head_dim,
284  cache_capacity);
285  } else {
286  const float scale = 1.0f / sqrtf((float)head_dim);
287  const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
288  const size_t kv_head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
289 
290  for (int h = 0; h < num_heads; ++h) {
291  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
292  const float *k_head = kv_cache_k + (size_t)kv_head * kv_head_stride;
293  const float *v_head = kv_cache_v + (size_t)kv_head * kv_head_stride;
294 
295  for (int i = 0; i < tokens; ++i) {
296  const float *q_vec = q + (size_t)h * q_head_stride + (size_t)i * (size_t)aligned_head_dim;
297  float *out_vec = attn_out + (size_t)h * q_head_stride + (size_t)i * (size_t)aligned_head_dim;
298  attention_flash_decode(out_vec,
299  q_vec,
300  k_head,
301  v_head,
302  1,
303  start_pos + i + 1,
304  1,
305  aligned_head_dim,
306  scale);
307  }
308  }
309  }
310 
311  if ((num_heads * aligned_head_dim) != aligned_embed_dim) {
312  return;
313  }
314 
315  if (wo_dt == CK_DT_Q5_0 &&
317  (aligned_head_dim % QK5_0) == 0 &&
318  (aligned_embed_dim % QK5_0) == 0) {
319  /* Quantized activations path: Q8_0 attn_out + Q5_0 weights. */
320  uint8_t *attn_q8 = (uint8_t *)q;
322  attn_q8,
323  tokens,
324  num_heads,
325  aligned_head_dim);
327  wo,
328  bo,
329  output,
330  tokens,
331  aligned_embed_dim,
332  num_heads,
333  aligned_head_dim);
334  } else if (wo_dt == CK_DT_Q5_0 &&
335  (aligned_head_dim % QK5_0) == 0 &&
336  (aligned_embed_dim % QK5_0) == 0) {
337  /* Head-major output projection with Q5_0 weights - no flatten needed */
339  wo,
340  bo,
341  output,
342  tokens,
343  aligned_embed_dim,
344  num_heads,
345  aligned_head_dim);
346  } else if (wo_dt == CK_DT_Q8_0 &&
347  (aligned_head_dim % QK8_0) == 0 &&
348  (aligned_embed_dim % QK8_0) == 0) {
349  /* Head-major output projection with Q8_0 weights - no flatten needed */
351  wo,
352  bo,
353  output,
354  tokens,
355  aligned_embed_dim,
356  num_heads,
357  aligned_head_dim);
358  } else {
359  /* Fallback: flatten then GEMM (slow path) */
360  flatten_head_major(attn_out,
361  proj_scratch,
362  tokens,
363  aligned_embed_dim,
364  num_heads,
365  aligned_head_dim);
366 
367  ck_gemm_nt_quant(proj_scratch,
368  wo,
369  bo,
370  output,
371  tokens,
372  aligned_embed_dim,
373  aligned_embed_dim,
374  wo_dt);
375  }
376 
377  if (residual) {
379  output,
380  output,
381  tokens,
382  aligned_embed_dim);
383  }
384 
385 }
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_Q5_0
Definition: ckernel_dtype.h:44
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
void ck_gemm_nt_head_major_q8_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q8_0 weights)
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
Definition: rope_kernels.c:472
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void ck_gemm_nt_head_major_q5_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (auto-dispatch)
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
void ck_gemm_nt_quant(const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
#define QK5_0
Definition: ckernel_quant.h:67
#define QK8_0
static size_t align_up_size(size_t value, size_t align)
static void flatten_head_major(const float *attn_out, float *dst, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void out_proj_head_major_q5_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
static int ck_q8_0_outproj_enabled(void)

References align_up_size(), attention_flash_decode(), attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_FP32, CK_DT_Q5_0, CK_DT_Q8_0, ck_gemm_nt_head_major_q5_0(), ck_gemm_nt_head_major_q8_0(), ck_gemm_nt_quant(), ck_q8_0_outproj_enabled(), ck_residual_add_token_major(), flatten_head_major(), fused_rmsnorm_qkv_prefill_head_major(), fused_rmsnorm_qkv_prefill_head_major_quant(), fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(), out_proj_head_major_q5_0_q8_0(), QK5_0, QK8_0, quantize_attn_out_head_major_q8_0(), and rope_forward_qk_strided().

◆ mega_fused_attention_prefill_scratch_size()

size_t mega_fused_attention_prefill_scratch_size ( int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)

Get scratch buffer size for mega_fused_attention_prefill.

Definition at line 139 of file mega_fused_attention_prefill.c.

143 {
144  if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 || aligned_head_dim <= 0) {
145  return 0;
146  }
147 
148  const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
149  (size_t)aligned_head_dim * sizeof(float);
150  const size_t attn_bytes = q_bytes;
151  const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
152  const size_t qkv_scratch = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);
153 
154  return align_up_size(q_bytes, 64) +
155  align_up_size(attn_bytes, 64) +
156  align_up_size(proj_bytes, 64) +
157  align_up_size(qkv_scratch, 64);
158 }

References align_up_size(), and fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size().

◆ out_proj_head_major_q5_0_q8_0()

static void out_proj_head_major_q5_0_q8_0 ( const uint8_t *  attn_q8,
const void *  wo,
const float *  bias,
float *  output,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 104 of file mega_fused_attention_prefill.c.

112 {
113  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
114  (size_t)aligned_head_dim);
115  const int blocks_per_head = aligned_head_dim / QK5_0;
116  const int blocks_per_row = aligned_embed_dim / QK5_0;
117  const block_q5_0 *weights = (const block_q5_0 *)wo;
118 
119  for (int t = 0; t < tokens; ++t) {
120  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
121  for (int n = 0; n < aligned_embed_dim; ++n) {
122  float sum = bias ? bias[n] : 0.0f;
123  const block_q5_0 *w_row = weights + (size_t)n * (size_t)blocks_per_row;
124 
125  for (int h = 0; h < num_heads; ++h) {
126  const uint8_t *a_row = attn_q8 +
127  ((size_t)h * (size_t)tokens + (size_t)t) *
128  q8_row_bytes;
129  const block_q5_0 *w_head = w_row + (size_t)h * (size_t)blocks_per_head;
130  float partial = 0.0f;
131  vec_dot_q5_0_q8_0(aligned_head_dim, &partial, w_head, a_row);
132  sum += partial;
133  }
134  out_row[n] = sum;
135  }
136  }
137 }
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.

References CK_DT_Q8_0, ck_dtype_row_bytes(), QK5_0, and vec_dot_q5_0_q8_0().

Referenced by mega_fused_attention_prefill().

◆ quantize_attn_out_head_major_q8_0()

static void quantize_attn_out_head_major_q8_0 ( const float *  attn_out,
uint8_t *  dst,
int  tokens,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 84 of file mega_fused_attention_prefill.c.

89 {
90  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
91  (size_t)aligned_head_dim);
92  const size_t head_stride = (size_t)tokens * (size_t)aligned_head_dim;
93  for (int h = 0; h < num_heads; ++h) {
94  const float *head = attn_out + (size_t)h * head_stride;
95  for (int t = 0; t < tokens; ++t) {
96  const float *row = head + (size_t)t * (size_t)aligned_head_dim;
97  uint8_t *out = dst + ((size_t)h * (size_t)tokens + (size_t)t) *
98  q8_row_bytes;
99  quantize_row_q8_0(row, out, aligned_head_dim);
100  }
101  }
102 }
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)

References CK_DT_Q8_0, ck_dtype_row_bytes(), and quantize_row_q8_0().

Referenced by mega_fused_attention_prefill().