← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention_decode_q5_0.h File Reference

Mega-fused attention decode with Q5_0 weights - Header. More...

Go to the source code of this file.

Functions

void mega_fused_attention_decode_q5_0 (float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch)
 Serial mega-fused attention decode kernel. More...
 
void mega_fused_attention_decode_q5_0_parallel_simd (float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch, int ith, int nth)
 Parallel SIMD mega-fused attention decode kernel (threadpool-aware) More...
 
int mega_fused_attention_decode_scratch_size (int AE, int H, int KV, int AD)
 Calculate scratch buffer size needed for the kernel. More...
 

Detailed Description

Mega-fused attention decode with Q5_0 weights - Header.

This header declares the mega-fused attention decode kernel that combines 9 separate operations into a single fused kernel call:

  1. RMSNorm
  2. Q projection (Q5_0) with bias
  3. K projection (Q5_0) with bias
  4. V projection (Q8_0) with bias
  5. RoPE application
  6. KV cache store
  7. Flash attention decode (GQA-aware)
  8. O projection (Q5_0) with bias
  9. Residual add

Definition in file mega_fused_attention_decode_q5_0.h.

Function Documentation

◆ mega_fused_attention_decode_q5_0()

void mega_fused_attention_decode_q5_0 ( float *  output,
const float *  input,
const float *  residual,
const void *  wq_q5_0,
const void *  wk_q5_0,
const void *  wv_q8_0,
const void *  wo_q5_0,
const float *  ln_gamma,
const float *  bq,
const float *  bk,
const float *  bv,
const float *  bo,
float *  kv_cache_k,
float *  kv_cache_v,
const float *  rope_cos,
const float *  rope_sin,
int  pos,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
int  cache_capacity,
float  eps,
void *  scratch 
)

Serial mega-fused attention decode kernel.

Parameters
output — Output [AE] (final result, after residual add)
input — Input activation [AE]
residual — Residual input for add [AE]
wq_q5_0 — Q projection weights [H*AD, AE] Q5_0
wk_q5_0 — K projection weights [KV*AD, AE] Q5_0
wv_q8_0 — V projection weights [KV*AD, AE] Q8_0
wo_q5_0 — O projection weights [AE, H*AD] Q5_0
ln_gamma — RMSNorm gamma [AE]
bq — Q bias [H*AD] or NULL
bk — K bias [KV*AD] or NULL
bv — V bias [KV*AD] or NULL
bo — O bias [AE] or NULL
kv_cache_k — K cache [KV, max_T, AD]
kv_cache_v — V cache [KV, max_T, AD]
rope_cos — RoPE cos [max_T, D]
rope_sin — RoPE sin [max_T, D]
pos — Current position (0-indexed)
embed_dim — Original embedding dimension E
aligned_embed_dim — Aligned embedding dimension AE
num_heads — Number of query heads H
num_kv_heads — Number of key/value heads KV
head_dim — Head dimension AD
aligned_head_dim — Aligned head dimension AAD
cache_capacity — Maximum cache capacity max_T
eps — RMSNorm epsilon
scratch — Scratch buffer (>= scratch_size bytes)

Definition at line 222 of file mega_fused_attention_decode_q5_0.c.

249 {
250  const int H = num_heads;
251  const int KV = num_kv_heads;
252  const int AD = head_dim;
253  const int AE = aligned_embed_dim;
254  (void)embed_dim; /* Unused but kept for API consistency */
255 
256  /* Parse scratch buffer - all allocations from scratch, no VLAs */
257  float *scratch_ptr = (float *)scratch;
258 
259  float *rmsnorm_out = scratch_ptr;
260  scratch_ptr += AE;
261 
262  float *rstd_scratch = scratch_ptr; /* For rmsnorm rstd output - avoids VLA */
263  scratch_ptr += AE;
264 
265  float *q = scratch_ptr;
266  scratch_ptr += H * AD;
267 
268  float *k = scratch_ptr;
269  scratch_ptr += KV * AD;
270 
271  float *v = scratch_ptr;
272  scratch_ptr += KV * AD;
273 
274  float *attn_out = scratch_ptr;
275  scratch_ptr += H * AD;
276 
277  block_q8_0 *x_q8_scratch = (block_q8_0 *)scratch_ptr;
278 
279  const int q_size = H * AD;
280  const int k_size = KV * AD;
281  const int v_size = KV * AD;
282 
283  /* ========================================================================
284  * STEP 1: RMSNorm
285  * Correct signature: rmsnorm_forward(in, gamma, out, rstd, T, D, AD, eps)
286  * T=1 (single token), D=AE (full embed dim for norm)
287  * ======================================================================== */
288  rmsnorm_forward(input, ln_gamma, rmsnorm_out, rstd_scratch, 1, AE, AD, eps);
289 
290  /* ========================================================================
291  * STEP 2-4: Q, K, V projections (fused with quantization)
292  * Use scratch buffer for quantized input
293  * ======================================================================== */
294  gemv_q5_0_from_fp32(q, wq_q5_0, rmsnorm_out, bq, q_size, AE, x_q8_scratch);
295  gemv_q5_0_from_fp32(k, wk_q5_0, rmsnorm_out, bk, k_size, AE, x_q8_scratch);
296  gemv_q8_0_from_fp32(v, wv_q8_0, rmsnorm_out, bv, v_size, AE, x_q8_scratch);
297 
298  /* ========================================================================
299  * STEP 5: Apply RoPE
300  * ======================================================================== */
301  apply_rope_inline(q, k, rope_cos, rope_sin, pos, H, KV, AD);
302 
303  /* ========================================================================
304  * STEP 6: Store K and V to cache
305  * Cache layout: [KV, cache_capacity, AD]
306  * ======================================================================== */
307  const size_t kv_stride = (size_t)cache_capacity * AD;
308  for (int kv = 0; kv < KV; kv++) {
309  float *k_cache = &kv_cache_k[kv * kv_stride];
310  float *v_cache = &kv_cache_v[kv * kv_stride];
311  const float *k_src = &k[kv * AD];
312  const float *v_src = &v[kv * AD];
313  const int offset = pos * AD;
314  for (int d = 0; d < AD; d++) {
315  k_cache[offset + d] = k_src[d];
316  v_cache[offset + d] = v_src[d];
317  }
318  }
319 
320  /* ========================================================================
321  * STEP 7: Flash attention decode (GQA-aware variant)
322  * attention_forward_decode_head_major_gqa_flash handles H != KV correctly
323  * It maps each of H heads to one of KV KV heads via: kv_head = h * KV / H
324  * ======================================================================== */
325  attention_forward_decode_head_major_gqa_flash(
326  q, kv_cache_k, kv_cache_v,
327  attn_out, H, KV, pos + 1, cache_capacity, AD, aligned_head_dim);
328 
329  /* ========================================================================
330  * STEP 8: O projection (Q5_0 weights) with bias and residual add
331  *
332  * attn_out layout: [H * AD] flattened
333  * wo_q5_0 layout: [AE, H*AD] - row e has H*AD input features
334  *
335  * O projection: output[e] = dot(wo[e], attn_out) + bias[e] + residual[e]
336  *
337  * Using vec_dot_q5_0_q8_0 for efficient quantized dot product.
338  * ======================================================================== */
339 
340  /* Quantize attention output to Q8_0 for GEMV */
341  quantize_row_q8_0(attn_out, x_q8_scratch, H * AD);
342 
343  const block_q5_0 *wo = (const block_q5_0 *)wo_q5_0;
344  const int blocks_per_row = (H * AD) / QK5_0;
345 
346  for (int e = 0; e < AE; e++) {
347  float dot;
348  vec_dot_q5_0_q8_0(H * AD, &dot, &wo[e * blocks_per_row], x_q8_scratch);
349  output[e] = dot + (bo ? bo[e] : 0.0f) + residual[e];
350  }
351 }
#define QK5_0
Definition: ckernel_quant.h:67
static void gemv_q5_0_from_fp32(float *out, const void *W_q5_0, const float *x_fp32, const float *bias, int M, int K, block_q8_0 *x_q8_scratch)
static void apply_rope_inline(float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int H, int KV, int AD)
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd, int T, int D, int AD, float eps)
static void gemv_q8_0_from_fp32(float *out, const void *W_q8_0, const float *x_fp32, const float *bias, int M, int K, block_q8_0 *x_q8_scratch)
void quantize_row_q8_0(const float *x, void *vy, int k)
Quantize FP32 to Q8_0 format (scalar reference)

References apply_rope_inline(), attention_forward_decode_head_major_gqa_flash(), gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), QK5_0, quantize_row_q8_0(), rmsnorm_forward(), and vec_dot_q5_0_q8_0().

◆ mega_fused_attention_decode_q5_0_parallel_simd()

void mega_fused_attention_decode_q5_0_parallel_simd ( float *  output,
const float *  input,
const float *  residual,
const void *  wq_q5_0,
const void *  wk_q5_0,
const void *  wv_q8_0,
const void *  wo_q5_0,
const float *  ln_gamma,
const float *  bq,
const float *  bk,
const float *  bv,
const float *  bo,
float *  kv_cache_k,
float *  kv_cache_v,
const float *  rope_cos,
const float *  rope_sin,
int  pos,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
int  cache_capacity,
float  eps,
void *  scratch,
int  ith,
int  nth 
)

Parallel SIMD mega-fused attention decode kernel (threadpool-aware)

Parallelizes across attention heads using (ith, nth) pattern. Each thread processes a subset of heads.

IMPORTANT: Caller must ensure barrier sync between phases: Phase 1 (ith==0 only): RMSNorm, Q/K/V projection, RoPE, KV cache store – BARRIER – Phase 2 (all threads): Attention for assigned heads – BARRIER – Phase 3 (ith==0 only): O projection and residual add

Parameters
ith — Thread index (0 to nth-1)
nth — Total number of threads (other parameters same as serial version)

Definition at line 367 of file mega_fused_attention_decode_q5_0.c.

396 {
397  const int H = num_heads;
398  const int KV = num_kv_heads;
399  const int AD = head_dim;
400  const int AE = aligned_embed_dim;
401  (void)embed_dim;
402 
403  /* Each thread handles a subset of heads */
404  const int heads_per_thread = (H + nth - 1) / nth;
405  const int h_start = ith * heads_per_thread;
406  const int h_end = (h_start + heads_per_thread < H) ? h_start + heads_per_thread : H;
407  const int my_heads = h_end - h_start;
408 
409  if (h_start >= H) return;
410 
411  /* Parse scratch buffer (shared across threads) */
412  float *scratch_ptr = (float *)scratch;
413 
414  float *rmsnorm_out = scratch_ptr;
415  scratch_ptr += AE;
416 
417  float *rstd_scratch = scratch_ptr;
418  scratch_ptr += AE;
419 
420  float *q = scratch_ptr;
421  scratch_ptr += H * AD;
422 
423  float *k = scratch_ptr;
424  scratch_ptr += KV * AD;
425 
426  float *v = scratch_ptr;
427  scratch_ptr += KV * AD;
428 
429  float *attn_out = scratch_ptr;
430  scratch_ptr += H * AD;
431 
432  block_q8_0 *x_q8_scratch = (block_q8_0 *)scratch_ptr;
433 
434  /* ========================================================================
435  * PHASE 1: Only thread 0 does RMSNorm and K/V projections
436  * These are shared across all heads.
437  * CALLER MUST BARRIER AFTER THIS PHASE.
438  * ======================================================================== */
439  if (ith == 0) {
440  rmsnorm_forward(input, ln_gamma, rmsnorm_out, rstd_scratch, 1, AE, AD, eps);
441 
442  gemv_q5_0_from_fp32(q, wq_q5_0, rmsnorm_out, bq, H * AD, AE, x_q8_scratch);
443  gemv_q5_0_from_fp32(k, wk_q5_0, rmsnorm_out, bk, KV * AD, AE, x_q8_scratch);
444  gemv_q8_0_from_fp32(v, wv_q8_0, rmsnorm_out, bv, KV * AD, AE, x_q8_scratch);
445 
446  apply_rope_inline(q, k, rope_cos, rope_sin, pos, H, KV, AD);
447 
448  /* Store K/V to cache */
449  const size_t kv_stride = (size_t)cache_capacity * AD;
450  for (int kv_idx = 0; kv_idx < KV; kv_idx++) {
451  float *k_cache = &kv_cache_k[kv_idx * kv_stride];
452  float *v_cache = &kv_cache_v[kv_idx * kv_stride];
453  const int offset = pos * AD;
454  for (int d = 0; d < AD; d++) {
455  k_cache[offset + d] = k[kv_idx * AD + d];
456  v_cache[offset + d] = v[kv_idx * AD + d];
457  }
458  }
459  }
460 
461  /* ========================================================================
462  * CALLER MUST BARRIER HERE
463  * All threads need to wait for thread 0 to finish projections
464  * ======================================================================== */
465 
466  /* ========================================================================
467  * PHASE 2: Each thread does attention for its heads only
468  * attention_forward_decode_head_major_gqa_flash expects:
469  * - q_token: pointer to start of Q for these heads
470  * - out_token: pointer to start of output for these heads
471  * - num_heads: number of heads THIS THREAD is processing
472  * ======================================================================== */
473  if (my_heads > 0) {
474  attention_forward_decode_head_major_gqa_flash(
475  &q[h_start * AD], /* Q for this thread's heads */
476  kv_cache_k, kv_cache_v,
477  &attn_out[h_start * AD], /* Output for this thread's heads */
478  my_heads, /* Only my_heads, not H */
479  KV, /* Still need all KV heads for GQA */
480  pos + 1, cache_capacity, AD, aligned_head_dim);
481  }
482 
483  /* ========================================================================
484  * CALLER MUST BARRIER HERE
485  * Thread 0 needs all threads to finish attention before O projection
486  * ======================================================================== */
487 
488  /* ========================================================================
489  * PHASE 3: Thread 0 does O projection and residual add
490  * ======================================================================== */
491  if (ith == 0) {
492  /* Quantize full attention output for O projection */
493  quantize_row_q8_0(attn_out, x_q8_scratch, H * AD);
494 
495  const block_q5_0 *wo = (const block_q5_0 *)wo_q5_0;
496  const int blocks_per_row = (H * AD) / QK5_0;
497 
498  for (int e = 0; e < AE; e++) {
499  float dot;
500  vec_dot_q5_0_q8_0(H * AD, &dot, &wo[e * blocks_per_row], x_q8_scratch);
501  output[e] = dot + (bo ? bo[e] : 0.0f) + residual[e];
502  }
503  }
504 }

References apply_rope_inline(), attention_forward_decode_head_major_gqa_flash(), gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), QK5_0, quantize_row_q8_0(), rmsnorm_forward(), and vec_dot_q5_0_q8_0().

◆ mega_fused_attention_decode_scratch_size()

int mega_fused_attention_decode_scratch_size ( int  AE,
int  H,
int  KV,
int  AD 
)

Calculate scratch buffer size needed for the kernel.

Parameters
AE — Aligned embedding dimension (multiple of 64)
H — Number of query heads
KV — Number of key/value heads
AD — Head dimension
Returns
Size in bytes needed for scratch buffer

Definition at line 176 of file mega_fused_attention_decode_q5_0.c.

176  {
177  /* Need: 1x AE for RMSNorm output
178  1x AE for RMSNorm rstd (avoid VLA)
179  1x H*AD for Q
180  1x KV*AD for K
181  1x KV*AD for V
182  1x H*AD for attention output
183  1x max(AE, H*AD)/QK8_0 * sizeof(block_q8_0) for GEMV scratch
184  */
185  int max_input_dim = (AE > H * AD) ? AE : H * AD;
186  int q8_blocks = (max_input_dim + QK8_0 - 1) / QK8_0;
187  return (int)(sizeof(float) * (AE + AE + H * AD + 2 * KV * AD + H * AD)
188  + q8_blocks * sizeof(block_q8_0));
189 }
#define QK8_0

References QK8_0.