← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention_decode_q5_0.c
Go to the documentation of this file.
1 /**
2  * @file mega_fused_attention_decode_q5_0.c
3  * @brief Mega-fused attention decode with Q5_0 weights
4  *
5  * STATUS: Serial kernel complete and correct. Parallel variant is a prototype
6  * that requires threadpool barrier support (not yet available).
7  * The non-fused decode path (ck_parallel_decode.h) already parallelizes
8  * each GEMV via row-splitting, so this fused kernel is not on the
9  * critical path. It can be enabled once threadpool parallelization
10  * is resolved (see PARALLELIZATION NOTES below).
11  *
12  * FUSION: Combines 9 operations to minimize memory traffic.
13  * All intermediate data stays in scratch buffer (L1/L2 cache).
14  *
15  * Operations fused:
16  * 1. RMSNorm
17  * 2. Q projection (Q5_0) with bias
18  * 3. K projection (Q5_0) with bias
19  * 4. V projection (Q8_0) with bias
20  * 5. RoPE application
21  * 6. KV cache store
22  * 7. Flash attention decode (GQA-aware)
23  * 8. O projection (Q5_0) with bias
24  * 9. Residual add
25  *
26  * PARALLELIZATION NOTES:
27  * The parallel_simd variant below documents the intended threading model
28  * but cannot run with the current threadpool (single dispatch, no mid-dispatch
29  * barrier). Three approaches were evaluated:
30  *
31  * (A) Multi-dispatch (RECOMMENDED):
32  * Break into 3 ck_threadpool_dispatch() calls per layer:
33  * Dispatch 1: Row-split Q proj across threads.
34  * Thread 0 also does RMSNorm, K/V proj, RoPE, KV store
35  * (small ops that fit within Q proj wall time).
36  * Dispatch 2: Split attention across heads (h_start..h_end per thread).
37  * Dispatch 3: Row-split O proj across threads.
38  * Thread 0 does residual add after its rows.
39  * Cost: ~1us total for 2 extra barrier round-trips (negligible vs ~100us GEMV).
40  * Intermediates stay in shared scratch — cache benefit preserved.
41  *
42  * (B) Redundant compute (single dispatch, no barrier):
43  * All threads redundantly compute RMSNorm + K/V proj + RoPE (~4us wasted
44  * per thread). Avoids barrier but wastes cycles on small ops.
45  * Only viable if Q/O proj dominate (true for short contexts).
46  *
47  * (C) Skip fusion, use existing parallel GEMV:
48  * The non-fused decode path already parallelizes each GEMV call via
49  * ck_parallel_decode.h. For decode (M=1), intermediates are small
50  * (~3.5KB), so DRAM bandwidth savings from fusion are minimal.
51  * This is the current production path.
52  *
53  * TESTING:
54  * make test-mega-fused-parity # Numerical parity
55  * make test-mega-fused-speed # Performance benchmark
56  */
57 
58 #include <stdio.h>
59 #include <stdlib.h>
60 #include <string.h>
61 #include <stdint.h>
62 #include <math.h>
63 
64 #include "ckernel_quant.h"
65 
66 /* Declare functions from other kernel files */
extern void rmsnorm_forward(const float *input, const float *gamma, float *output,
                            float *rstd, int T, int D, int AD, float eps);
/* GQA-aware flash-attention decode: maps query head h -> kv head h * KV / H.
 * (Declaration restored; first line was dropped in an earlier edit/export.) */
extern void attention_forward_decode_head_major_gqa_flash(
    const float *q_token, const float *k_cache, const float *v_cache,
    float *out_token, int num_heads, int num_kv_heads, int kv_tokens,
    int cache_capacity, int head_dim, int aligned_head_dim);
extern void quantize_row_q8_0(const float *x, void *vy, int k);
extern void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy);
extern void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy);
76 
77 /* ============================================================================
78  * Q5_0 GEMV with inline processing
79  *
80  * For true fusion: quantize input to Q8_0, then use efficient vec_dot
81  * Uses scratch buffer instead of malloc (kernel rule)
82  * ============================================================================ */
83 
84 static inline void gemv_q5_0_from_fp32(
85  float *out, /* Output [M] */
86  const void *W_q5_0, /* Q5_0 weights [M, K] */
87  const float *x_fp32, /* FP32 input [K] */
88  const float *bias, /* Bias [M] or NULL */
89  int M,
90  int K,
91  block_q8_0 *x_q8_scratch) /* Scratch buffer for quantized input */
92 {
93  const block_q5_0 *w_blocks = (const block_q5_0 *)W_q5_0;
94  const int blocks_per_row = K / QK5_0;
95 
96  /* Quantize input to Q8_0 (reuse existing kernel, scratch buffer) */
97  quantize_row_q8_0(x_fp32, x_q8_scratch, K);
98 
99  /* Compute dot products using optimized kernel */
100  for (int row = 0; row < M; row++) {
101  float dot;
102  vec_dot_q5_0_q8_0(K, &dot, &w_blocks[row * blocks_per_row], x_q8_scratch);
103  out[row] = dot + (bias ? bias[row] : 0.0f);
104  }
105 }
106 
107 /* Q8_0 GEMV - input is already FP32, weights are Q8_0 */
108 static inline void gemv_q8_0_from_fp32(
109  float *out,
110  const void *W_q8_0,
111  const float *x_fp32,
112  const float *bias,
113  int M,
114  int K,
115  block_q8_0 *x_q8_scratch)
116 {
117  const block_q8_0 *w_blocks = (const block_q8_0 *)W_q8_0;
118  const int blocks_per_row = K / QK8_0;
119 
120  /* Quantize input to Q8_0 (reuse existing kernel, scratch buffer) */
121  quantize_row_q8_0(x_fp32, x_q8_scratch, K);
122 
123  /* Compute dot products */
124  for (int row = 0; row < M; row++) {
125  float dot;
126  vec_dot_q8_0_q8_0(K, &dot, &w_blocks[row * blocks_per_row], x_q8_scratch);
127  out[row] = dot + (bias ? bias[row] : 0.0f);
128  }
129 }
130 
131 /* ============================================================================
132  * RoPE application (inline)
133  * ============================================================================ */
134 
/* Rotate-half RoPE for one group of heads: pairs element i with element
 * i + AD/2 within each head vector and applies the 2D rotation for the
 * current position's cos/sin rows. */
static inline void rope_rotate_heads(
    float *heads, int n_heads, int AD,
    const float *cos_row, const float *sin_row)
{
    const int half = AD / 2;
    for (int h = 0; h < n_heads; h++) {
        float *vec = heads + h * AD;
        for (int i = 0; i < half; i++) {
            const float a = vec[i];
            const float b = vec[i + half];
            vec[i]        = a * cos_row[i] - b * sin_row[i];
            vec[i + half] = a * sin_row[i] + b * cos_row[i];
        }
    }
}

/* Apply RoPE in place to Q (H heads) and K (KV heads) for position `pos`.
 * rope_cos/rope_sin are [max_T, AD/2] tables indexed by position. */
static inline void apply_rope_inline(
    float *q,
    float *k,
    const float *rope_cos,
    const float *rope_sin,
    int pos,
    int H,
    int KV,
    int AD)
{
    const int half = AD / 2;
    const float *cos_row = rope_cos + pos * half;
    const float *sin_row = rope_sin + pos * half;

    rope_rotate_heads(q, H, AD, cos_row, sin_row);
    rope_rotate_heads(k, KV, AD, cos_row, sin_row);
}
171 
172 /* ============================================================================
173  * Calculate scratch size needed (with all required parameters)
174  * ============================================================================ */
175 
176 int mega_fused_attention_decode_scratch_size(int AE, int H, int KV, int AD) {
177  /* Need: 1x AE for RMSNorm output
178  1x AE for RMSNorm rstd (avoid VLA)
179  1x H*AD for Q
180  1x KV*AD for K
181  1x KV*AD for V
182  1x H*AD for attention output
183  1x max(AE, H*AD)/QK8_0 * sizeof(block_q8_0) for GEMV scratch
184  */
185  int max_input_dim = (AE > H * AD) ? AE : H * AD;
186  int q8_blocks = (max_input_dim + QK8_0 - 1) / QK8_0;
187  return (int)(sizeof(float) * (AE + AE + H * AD + 2 * KV * AD + H * AD)
188  + q8_blocks * sizeof(block_q8_0));
189 }
190 
191 /* ============================================================================
192  * MAIN KERNEL
193  *
194  * @param output Output [AE] (final result, after residual add)
195  * @param input Input activation [AE]
196  * @param residual Residual input for add [AE]
197  * @param wq_q5_0 Q projection weights [H*AD, AE] Q5_0
198  * @param wk_q5_0 K projection weights [KV*AD, AE] Q5_0
199  * @param wv_q8_0 V projection weights [KV*AD, AE] Q8_0
200  * @param wo_q5_0 O projection weights [AE, H*AD] Q5_0 (row e has H*AD elements)
201  * @param ln_gamma RMSNorm gamma [AE]
202  * @param bq Q bias [H*AD] or NULL
203  * @param bk K bias [KV*AD] or NULL
204  * @param bv V bias [KV*AD] or NULL
205  * @param bo O bias [AE] or NULL
206  * @param kv_cache_k K cache [KV, max_T, AD]
207  * @param kv_cache_v V cache [KV, max_T, AD]
208  * @param rope_cos RoPE cos [max_T, D]
209  * @param rope_sin RoPE sin [max_T, D]
210  * @param pos Current position
211  * @param embed_dim Original embed dim E
212  * @param aligned_embed_dim Aligned embed dim AE (multiple of 64)
213  * @param num_heads H
214  * @param num_kv_heads KV
215  * @param head_dim AD
216  * @param aligned_head_dim AAD (multiple of 64)
217  * @param cache_capacity max_T
218  * @param eps RMSNorm epsilon
219  * @param scratch Scratch buffer (must be >= scratch_size bytes)
220  * ============================================================================ */
221 
223  float *output,
224  const float *input,
225  const float *residual,
226  const void *wq_q5_0,
227  const void *wk_q5_0,
228  const void *wv_q8_0,
229  const void *wo_q5_0,
230  const float *ln_gamma,
231  const float *bq,
232  const float *bk,
233  const float *bv,
234  const float *bo,
235  float *kv_cache_k,
236  float *kv_cache_v,
237  const float *rope_cos,
238  const float *rope_sin,
239  int pos,
240  int embed_dim,
241  int aligned_embed_dim,
242  int num_heads,
243  int num_kv_heads,
244  int head_dim,
245  int aligned_head_dim,
246  int cache_capacity,
247  float eps,
248  void *scratch)
249 {
250  const int H = num_heads;
251  const int KV = num_kv_heads;
252  const int AD = head_dim;
253  const int AE = aligned_embed_dim;
254  (void)embed_dim; /* Unused but kept for API consistency */
255 
256  /* Parse scratch buffer - all allocations from scratch, no VLAs */
257  float *scratch_ptr = (float *)scratch;
258 
259  float *rmsnorm_out = scratch_ptr;
260  scratch_ptr += AE;
261 
262  float *rstd_scratch = scratch_ptr; /* For rmsnorm rstd output - avoids VLA */
263  scratch_ptr += AE;
264 
265  float *q = scratch_ptr;
266  scratch_ptr += H * AD;
267 
268  float *k = scratch_ptr;
269  scratch_ptr += KV * AD;
270 
271  float *v = scratch_ptr;
272  scratch_ptr += KV * AD;
273 
274  float *attn_out = scratch_ptr;
275  scratch_ptr += H * AD;
276 
277  block_q8_0 *x_q8_scratch = (block_q8_0 *)scratch_ptr;
278 
279  const int q_size = H * AD;
280  const int k_size = KV * AD;
281  const int v_size = KV * AD;
282 
283  /* ========================================================================
284  * STEP 1: RMSNorm
285  * Correct signature: rmsnorm_forward(in, gamma, out, rstd, T, D, AD, eps)
286  * T=1 (single token), D=AE (full embed dim for norm)
287  * ======================================================================== */
288  rmsnorm_forward(input, ln_gamma, rmsnorm_out, rstd_scratch, 1, AE, AD, eps);
289 
290  /* ========================================================================
291  * STEP 2-4: Q, K, V projections (fused with quantization)
292  * Use scratch buffer for quantized input
293  * ======================================================================== */
294  gemv_q5_0_from_fp32(q, wq_q5_0, rmsnorm_out, bq, q_size, AE, x_q8_scratch);
295  gemv_q5_0_from_fp32(k, wk_q5_0, rmsnorm_out, bk, k_size, AE, x_q8_scratch);
296  gemv_q8_0_from_fp32(v, wv_q8_0, rmsnorm_out, bv, v_size, AE, x_q8_scratch);
297 
298  /* ========================================================================
299  * STEP 5: Apply RoPE
300  * ======================================================================== */
301  apply_rope_inline(q, k, rope_cos, rope_sin, pos, H, KV, AD);
302 
303  /* ========================================================================
304  * STEP 6: Store K and V to cache
305  * Cache layout: [KV, cache_capacity, AD]
306  * ======================================================================== */
307  const size_t kv_stride = (size_t)cache_capacity * AD;
308  for (int kv = 0; kv < KV; kv++) {
309  float *k_cache = &kv_cache_k[kv * kv_stride];
310  float *v_cache = &kv_cache_v[kv * kv_stride];
311  const float *k_src = &k[kv * AD];
312  const float *v_src = &v[kv * AD];
313  const int offset = pos * AD;
314  for (int d = 0; d < AD; d++) {
315  k_cache[offset + d] = k_src[d];
316  v_cache[offset + d] = v_src[d];
317  }
318  }
319 
320  /* ========================================================================
321  * STEP 7: Flash attention decode (GQA-aware variant)
322  * attention_forward_decode_head_major_gqa_flash handles H != KV correctly
323  * It maps each of H heads to one of KV KV heads via: kv_head = h * KV / H
324  * ======================================================================== */
326  q, kv_cache_k, kv_cache_v,
327  attn_out, H, KV, pos + 1, cache_capacity, AD, aligned_head_dim);
328 
329  /* ========================================================================
330  * STEP 8: O projection (Q5_0 weights) with bias and residual add
331  *
332  * attn_out layout: [H * AD] flattened
333  * wo_q5_0 layout: [AE, H*AD] - row e has H*AD input features
334  *
335  * O projection: output[e] = dot(wo[e], attn_out) + bias[e] + residual[e]
336  *
337  * Using vec_dot_q5_0_q8_0 for efficient quantized dot product.
338  * ======================================================================== */
339 
340  /* Quantize attention output to Q8_0 for GEMV */
341  quantize_row_q8_0(attn_out, x_q8_scratch, H * AD);
342 
343  const block_q5_0 *wo = (const block_q5_0 *)wo_q5_0;
344  const int blocks_per_row = (H * AD) / QK5_0;
345 
346  for (int e = 0; e < AE; e++) {
347  float dot;
348  vec_dot_q5_0_q8_0(H * AD, &dot, &wo[e * blocks_per_row], x_q8_scratch);
349  output[e] = dot + (bo ? bo[e] : 0.0f) + residual[e];
350  }
351 }
352 
353 /* ============================================================================
354  * PARALLEL SIMD VARIANT (threadpool-aware)
355  *
356  * Parallelizes across attention heads using (ith, nth) pattern.
357  * Each thread processes a subset of heads.
358  *
359  * IMPORTANT: Caller must ensure barrier sync between phases:
360  * Phase 1 (ith==0 only): RMSNorm, Q/K/V projection, RoPE, KV cache store
361  * -- BARRIER --
362  * Phase 2 (all threads): Attention for assigned heads
363  * -- BARRIER --
364  * Phase 3 (ith==0 only): O projection and residual add
365  * ======================================================================== */
366 
368  float *output,
369  const float *input,
370  const float *residual,
371  const void *wq_q5_0,
372  const void *wk_q5_0,
373  const void *wv_q8_0,
374  const void *wo_q5_0,
375  const float *ln_gamma,
376  const float *bq,
377  const float *bk,
378  const float *bv,
379  const float *bo,
380  float *kv_cache_k,
381  float *kv_cache_v,
382  const float *rope_cos,
383  const float *rope_sin,
384  int pos,
385  int embed_dim,
386  int aligned_embed_dim,
387  int num_heads,
388  int num_kv_heads,
389  int head_dim,
390  int aligned_head_dim,
391  int cache_capacity,
392  float eps,
393  void *scratch,
394  int ith,
395  int nth)
396 {
397  const int H = num_heads;
398  const int KV = num_kv_heads;
399  const int AD = head_dim;
400  const int AE = aligned_embed_dim;
401  (void)embed_dim;
402 
403  /* Each thread handles a subset of heads */
404  const int heads_per_thread = (H + nth - 1) / nth;
405  const int h_start = ith * heads_per_thread;
406  const int h_end = (h_start + heads_per_thread < H) ? h_start + heads_per_thread : H;
407  const int my_heads = h_end - h_start;
408 
409  if (h_start >= H) return;
410 
411  /* Parse scratch buffer (shared across threads) */
412  float *scratch_ptr = (float *)scratch;
413 
414  float *rmsnorm_out = scratch_ptr;
415  scratch_ptr += AE;
416 
417  float *rstd_scratch = scratch_ptr;
418  scratch_ptr += AE;
419 
420  float *q = scratch_ptr;
421  scratch_ptr += H * AD;
422 
423  float *k = scratch_ptr;
424  scratch_ptr += KV * AD;
425 
426  float *v = scratch_ptr;
427  scratch_ptr += KV * AD;
428 
429  float *attn_out = scratch_ptr;
430  scratch_ptr += H * AD;
431 
432  block_q8_0 *x_q8_scratch = (block_q8_0 *)scratch_ptr;
433 
434  /* ========================================================================
435  * PHASE 1: Only thread 0 does RMSNorm and K/V projections
436  * These are shared across all heads.
437  * CALLER MUST BARRIER AFTER THIS PHASE.
438  * ======================================================================== */
439  if (ith == 0) {
440  rmsnorm_forward(input, ln_gamma, rmsnorm_out, rstd_scratch, 1, AE, AD, eps);
441 
442  gemv_q5_0_from_fp32(q, wq_q5_0, rmsnorm_out, bq, H * AD, AE, x_q8_scratch);
443  gemv_q5_0_from_fp32(k, wk_q5_0, rmsnorm_out, bk, KV * AD, AE, x_q8_scratch);
444  gemv_q8_0_from_fp32(v, wv_q8_0, rmsnorm_out, bv, KV * AD, AE, x_q8_scratch);
445 
446  apply_rope_inline(q, k, rope_cos, rope_sin, pos, H, KV, AD);
447 
448  /* Store K/V to cache */
449  const size_t kv_stride = (size_t)cache_capacity * AD;
450  for (int kv_idx = 0; kv_idx < KV; kv_idx++) {
451  float *k_cache = &kv_cache_k[kv_idx * kv_stride];
452  float *v_cache = &kv_cache_v[kv_idx * kv_stride];
453  const int offset = pos * AD;
454  for (int d = 0; d < AD; d++) {
455  k_cache[offset + d] = k[kv_idx * AD + d];
456  v_cache[offset + d] = v[kv_idx * AD + d];
457  }
458  }
459  }
460 
461  /* ========================================================================
462  * CALLER MUST BARRIER HERE
463  * All threads need to wait for thread 0 to finish projections
464  * ======================================================================== */
465 
466  /* ========================================================================
467  * PHASE 2: Each thread does attention for its heads only
468  * attention_forward_decode_head_major_gqa_flash expects:
469  * - q_token: pointer to start of Q for these heads
470  * - out_token: pointer to start of output for these heads
471  * - num_heads: number of heads THIS THREAD is processing
472  * ======================================================================== */
473  if (my_heads > 0) {
475  &q[h_start * AD], /* Q for this thread's heads */
476  kv_cache_k, kv_cache_v,
477  &attn_out[h_start * AD], /* Output for this thread's heads */
478  my_heads, /* Only my_heads, not H */
479  KV, /* Still need all KV heads for GQA */
480  pos + 1, cache_capacity, AD, aligned_head_dim);
481  }
482 
483  /* ========================================================================
484  * CALLER MUST BARRIER HERE
485  * Thread 0 needs all threads to finish attention before O projection
486  * ======================================================================== */
487 
488  /* ========================================================================
489  * PHASE 3: Thread 0 does O projection and residual add
490  * ======================================================================== */
491  if (ith == 0) {
492  /* Quantize full attention output for O projection */
493  quantize_row_q8_0(attn_out, x_q8_scratch, H * AD);
494 
495  const block_q5_0 *wo = (const block_q5_0 *)wo_q5_0;
496  const int blocks_per_row = (H * AD) / QK5_0;
497 
498  for (int e = 0; e < AE; e++) {
499  float dot;
500  vec_dot_q5_0_q8_0(H * AD, &dot, &wo[e * blocks_per_row], x_q8_scratch);
501  output[e] = dot + (bo ? bo[e] : 0.0f) + residual[e];
502  }
503  }
504 }
Quantization block structures for weight-only quantization.
#define QK5_0
Definition: ckernel_quant.h:67
#define QK8_0
void mega_fused_attention_decode_q5_0_parallel_simd(float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch, int ith, int nth)
Parallel SIMD mega-fused attention decode kernel (threadpool-aware)
void mega_fused_attention_decode_q5_0(float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch)
Serial mega-fused attention decode kernel.
static void gemv_q5_0_from_fp32(float *out, const void *W_q5_0, const float *x_fp32, const float *bias, int M, int K, block_q8_0 *x_q8_scratch)
static void apply_rope_inline(float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int H, int KV, int AD)
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd, int T, int D, int AD, float eps)
static void gemv_q8_0_from_fp32(float *out, const void *W_q8_0, const float *x_fp32, const float *bias, int M, int K, block_q8_0 *x_q8_scratch)
int mega_fused_attention_decode_scratch_size(int AE, int H, int KV, int AD)
Calculate scratch buffer size needed for the kernel.
void quantize_row_q8_0(const float *x, void *vy, int k)
Quantize FP32 to Q8_0 format (scalar reference)
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.