← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention_prefill.c
Go to the documentation of this file.
1 /**
2  * @file mega_fused_attention_prefill.c
3  * @brief Mega-fused prefill attention kernel
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. NO memcpy for layout - use strided access, not copies
10  * 4. API must define: inputs, outputs, workspace, and memory layouts
11  * 5. Pure computation - deterministic, no side effects
12  *
13  * After changes: make test && make llamacpp-parity-full
14  *
15  * RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual
16  * Writes K/V directly into the KV cache layout (stride = cache_capacity).
17  *
18  * PERFORMANCE OPTIMIZATION:
19  * =========================
20  * Uses ck_gemm_nt_head_major_*() to read head-major attention output
21  * directly with strided access, eliminating the flatten_head_major()
22  * memcpy bottleneck (448 memcpy calls for 32 tokens × 14 heads)
23  *
24  * TESTING
25  * =======
26  * python3 scripts/bench_mega_fused_attention_prefill.py --q8-outproj --seq-lens 32,64 --iters 3 --warmup 1
27  *
28  */
29 
30 #include "ckernel_engine.h"
31 #include "ckernel_orchestration.h"
32 #include "ckernel_quant.h"
33 
34 #include <math.h>
35 #include <stdlib.h>
36 #include <string.h>
37 #include <stdio.h>
38 
/* Round value up to the nearest multiple of align.
 * align must be a power of two (all call sites use 64-byte alignment). */
static size_t align_up_size(size_t value, size_t align) {
    const size_t mask = align - 1;
    return (value + mask) & ~mask;
}
42 
/* Repack head-major attention output [head][token][aligned_head_dim] into
 * token-major rows [token][aligned_embed_dim] so a plain GEMM can consume it.
 * Slow fallback only; the head-major GEMM paths avoid this copy entirely. */
static void flatten_head_major(const float *attn_out,
                               float *dst,
                               int tokens,
                               int aligned_embed_dim,
                               int num_heads,
                               int aligned_head_dim)
{
    const size_t row_bytes = (size_t)aligned_head_dim * sizeof(float);
    const size_t src_head_stride = (size_t)tokens * (size_t)aligned_head_dim;

    for (int h = 0; h < num_heads; ++h) {
        const float *src_head = attn_out + (size_t)h * src_head_stride;
        /* Destination column slot for this head inside every token row. */
        float *dst_head = dst + (size_t)h * (size_t)aligned_head_dim;
        for (int t = 0; t < tokens; ++t) {
            memcpy(dst_head + (size_t)t * (size_t)aligned_embed_dim,
                   src_head + (size_t)t * (size_t)aligned_head_dim,
                   row_bytes);
        }
    }
}
62 
/* Returns 1 when the CK_Q8_0_OUTPROJ environment variable opts in to the
 * quantized-activation output projection, 0 otherwise.  Unset/empty, or a
 * value starting with '0', 'n'/'N', or 'f'/'F', disables it.  The result is
 * computed once and cached for the process lifetime. */
static int ck_q8_0_outproj_enabled(void)
{
    static int cached = -2; /* -2 == not yet resolved */

    if (cached == -2) {
        const char *value = getenv("CK_Q8_0_OUTPROJ");
        if (!value || value[0] == '\0') {
            cached = 0;
        } else {
            switch (value[0]) {
            case '0':
            case 'n':
            case 'N':
            case 'f':
            case 'F':
                cached = 0;
                break;
            default:
                cached = 1;
                break;
            }
        }
    }
    return cached;
}
83 
84 static void quantize_attn_out_head_major_q8_0(const float *attn_out,
85  uint8_t *dst,
86  int tokens,
87  int num_heads,
88  int aligned_head_dim)
89 {
90  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
91  (size_t)aligned_head_dim);
92  const size_t head_stride = (size_t)tokens * (size_t)aligned_head_dim;
93  for (int h = 0; h < num_heads; ++h) {
94  const float *head = attn_out + (size_t)h * head_stride;
95  for (int t = 0; t < tokens; ++t) {
96  const float *row = head + (size_t)t * (size_t)aligned_head_dim;
97  uint8_t *out = dst + ((size_t)h * (size_t)tokens + (size_t)t) *
98  q8_row_bytes;
99  quantize_row_q8_0(row, out, aligned_head_dim);
100  }
101  }
102 }
103 
104 static void out_proj_head_major_q5_0_q8_0(const uint8_t *attn_q8,
105  const void *wo,
106  const float *bias,
107  float *output,
108  int tokens,
109  int aligned_embed_dim,
110  int num_heads,
111  int aligned_head_dim)
112 {
113  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
114  (size_t)aligned_head_dim);
115  const int blocks_per_head = aligned_head_dim / QK5_0;
116  const int blocks_per_row = aligned_embed_dim / QK5_0;
117  const block_q5_0 *weights = (const block_q5_0 *)wo;
118 
119  for (int t = 0; t < tokens; ++t) {
120  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
121  for (int n = 0; n < aligned_embed_dim; ++n) {
122  float sum = bias ? bias[n] : 0.0f;
123  const block_q5_0 *w_row = weights + (size_t)n * (size_t)blocks_per_row;
124 
125  for (int h = 0; h < num_heads; ++h) {
126  const uint8_t *a_row = attn_q8 +
127  ((size_t)h * (size_t)tokens + (size_t)t) *
128  q8_row_bytes;
129  const block_q5_0 *w_head = w_row + (size_t)h * (size_t)blocks_per_head;
130  float partial = 0.0f;
131  vec_dot_q5_0_q8_0(aligned_head_dim, &partial, w_head, a_row);
132  sum += partial;
133  }
134  out_row[n] = sum;
135  }
136  }
137 }
138 
/**
 * @brief Scratch bytes required by mega_fused_attention_prefill().
 *
 * Layout, each region rounded up to a 64-byte boundary:
 *   Q buffer | attention output | flatten/projection scratch | fused-QKV scratch
 *
 * @return 0 when any dimension is non-positive, otherwise the total size.
 */
size_t mega_fused_attention_prefill_scratch_size(int tokens,
                                                 int aligned_embed_dim,
                                                 int num_heads,
                                                 int aligned_head_dim)
{
    if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 || aligned_head_dim <= 0) {
        return 0;
    }

    /* Q buffer: head-major [num_heads][tokens][aligned_head_dim] floats. */
    const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
                           (size_t)aligned_head_dim * sizeof(float);
    /* Attention output has the same head-major shape as Q. */
    const size_t attn_bytes = q_bytes;
    /* Token-major rows used only by the flatten fallback path. */
    const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
    const size_t qkv_scratch = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);

    return align_up_size(q_bytes, 64) +
           align_up_size(attn_bytes, 64) +
           align_up_size(proj_bytes, 64) +
           align_up_size(qkv_scratch, 64);
}
159 
161  float *output,
162  const float *input,
163  const float *residual,
164  const float *ln1_gamma,
165  const void *wq, const float *bq, CKDataType wq_dt,
166  const void *wk, const float *bk, CKDataType wk_dt,
167  const void *wv, const float *bv, CKDataType wv_dt,
168  const void *wo, const float *bo, CKDataType wo_dt,
169  float *kv_cache_k,
170  float *kv_cache_v,
171  const float *rope_cos,
172  const float *rope_sin,
173  int start_pos,
174  int tokens,
175  int cache_capacity,
176  int embed_dim,
177  int aligned_embed_dim,
178  int num_heads,
179  int num_kv_heads,
180  int head_dim,
181  int aligned_head_dim,
182  float eps,
183  void *scratch)
184 {
185  if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
186  !kv_cache_k || !kv_cache_v || !scratch) {
187  return;
188  }
189  if (tokens <= 0 || cache_capacity <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
190  head_dim <= 0 || aligned_head_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0) {
191  return;
192  }
193  if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
194  return;
195  }
196  if (start_pos < 0 || start_pos + tokens > cache_capacity) {
197  return;
198  }
199 
200  const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
201  (size_t)aligned_head_dim * sizeof(float);
202  const size_t attn_bytes = q_bytes;
203  const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
204  const size_t qkv_scratch_bytes = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);
205 
206  uint8_t *scratch_bytes = (uint8_t *)scratch;
207  float *q = (float *)scratch_bytes;
208  scratch_bytes += align_up_size(q_bytes, 64);
209  float *attn_out = (float *)scratch_bytes;
210  scratch_bytes += align_up_size(attn_bytes, 64);
211  float *proj_scratch = (float *)scratch_bytes;
212  scratch_bytes += align_up_size(proj_bytes, 64);
213  void *qkv_scratch = (void *)scratch_bytes;
214  (void)qkv_scratch_bytes;
215 
216  float *k_ptr = kv_cache_k + (size_t)start_pos * (size_t)aligned_head_dim;
217  float *v_ptr = kv_cache_v + (size_t)start_pos * (size_t)aligned_head_dim;
218 
219  if (wq_dt == CK_DT_FP32 && wk_dt == CK_DT_FP32 && wv_dt == CK_DT_FP32) {
221  ln1_gamma,
222  (const float *)wq, bq,
223  (const float *)wk, bk,
224  (const float *)wv, bv,
225  q,
226  k_ptr,
227  v_ptr,
228  tokens,
229  embed_dim,
230  aligned_embed_dim,
231  num_heads,
232  num_kv_heads,
233  head_dim,
234  aligned_head_dim,
235  cache_capacity,
236  eps,
237  qkv_scratch);
238  } else {
240  ln1_gamma,
241  wq, bq, wq_dt,
242  wk, bk, wk_dt,
243  wv, bv, wv_dt,
244  q,
245  k_ptr,
246  v_ptr,
247  tokens,
248  embed_dim,
249  aligned_embed_dim,
250  num_heads,
251  num_kv_heads,
252  head_dim,
253  aligned_head_dim,
254  cache_capacity,
255  eps,
256  qkv_scratch);
257  }
258 
259  if (rope_cos && rope_sin) {
261  k_ptr,
262  rope_cos,
263  rope_sin,
264  num_heads,
265  num_kv_heads,
266  tokens,
267  head_dim,
268  aligned_head_dim,
269  start_pos,
270  tokens,
271  cache_capacity);
272  }
273 
274  if (start_pos == 0) {
276  k_ptr,
277  v_ptr,
278  attn_out,
279  num_heads,
280  num_kv_heads,
281  tokens,
282  head_dim,
283  aligned_head_dim,
284  cache_capacity);
285  } else {
286  const float scale = 1.0f / sqrtf((float)head_dim);
287  const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
288  const size_t kv_head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
289 
290  for (int h = 0; h < num_heads; ++h) {
291  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
292  const float *k_head = kv_cache_k + (size_t)kv_head * kv_head_stride;
293  const float *v_head = kv_cache_v + (size_t)kv_head * kv_head_stride;
294 
295  for (int i = 0; i < tokens; ++i) {
296  const float *q_vec = q + (size_t)h * q_head_stride + (size_t)i * (size_t)aligned_head_dim;
297  float *out_vec = attn_out + (size_t)h * q_head_stride + (size_t)i * (size_t)aligned_head_dim;
298  attention_flash_decode(out_vec,
299  q_vec,
300  k_head,
301  v_head,
302  1,
303  start_pos + i + 1,
304  1,
305  aligned_head_dim,
306  scale);
307  }
308  }
309  }
310 
311  if ((num_heads * aligned_head_dim) != aligned_embed_dim) {
312  return;
313  }
314 
315  if (wo_dt == CK_DT_Q5_0 &&
317  (aligned_head_dim % QK5_0) == 0 &&
318  (aligned_embed_dim % QK5_0) == 0) {
319  /* Quantized activations path: Q8_0 attn_out + Q5_0 weights. */
320  uint8_t *attn_q8 = (uint8_t *)q;
322  attn_q8,
323  tokens,
324  num_heads,
325  aligned_head_dim);
327  wo,
328  bo,
329  output,
330  tokens,
331  aligned_embed_dim,
332  num_heads,
333  aligned_head_dim);
334  } else if (wo_dt == CK_DT_Q5_0 &&
335  (aligned_head_dim % QK5_0) == 0 &&
336  (aligned_embed_dim % QK5_0) == 0) {
337  /* Head-major output projection with Q5_0 weights - no flatten needed */
339  wo,
340  bo,
341  output,
342  tokens,
343  aligned_embed_dim,
344  num_heads,
345  aligned_head_dim);
346  } else if (wo_dt == CK_DT_Q8_0 &&
347  (aligned_head_dim % QK8_0) == 0 &&
348  (aligned_embed_dim % QK8_0) == 0) {
349  /* Head-major output projection with Q8_0 weights - no flatten needed */
351  wo,
352  bo,
353  output,
354  tokens,
355  aligned_embed_dim,
356  num_heads,
357  aligned_head_dim);
358  } else {
359  /* Fallback: flatten then GEMM (slow path) */
360  flatten_head_major(attn_out,
361  proj_scratch,
362  tokens,
363  aligned_embed_dim,
364  num_heads,
365  aligned_head_dim);
366 
367  ck_gemm_nt_quant(proj_scratch,
368  wo,
369  bo,
370  output,
371  tokens,
372  aligned_embed_dim,
373  aligned_embed_dim,
374  wo_dt);
375  }
376 
377  if (residual) {
379  output,
380  output,
381  tokens,
382  aligned_embed_dim);
383  }
384 
385 }
CKDataType
Supported data types in C-Kernel-Engine.
Definition: ckernel_dtype.h:27
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_Q5_0
Definition: ckernel_dtype.h:44
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void ck_gemm_nt_head_major_q8_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q8_0 weights)
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
Definition: rope_kernels.c:472
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void ck_gemm_nt_head_major_q5_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (auto-dispatch)
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
void ck_gemm_nt_quant(const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
Quantization block structures for weight-only quantization.
#define QK5_0
Definition: ckernel_quant.h:67
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
#define QK8_0
size_t mega_fused_attention_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Get scratch buffer size for mega_fused_attention_prefill.
void mega_fused_attention_prefill(float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
Mega-fused attention for prefill mode (multiple tokens)
static size_t align_up_size(size_t value, size_t align)
static void flatten_head_major(const float *attn_out, float *dst, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void out_proj_head_major_q5_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
static int ck_q8_0_outproj_enabled(void)