← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention.h File Reference

Mega-Fused Attention Kernel. More...

#include <stdint.h>
#include "ckernel_dtype.h"

Go to the source code of this file.

Macros

#define MEGA_FUSE_KV_TILE   64
 
#define MEGA_FUSE_Q_TILE   64
 

Functions

void mega_fuse_get_optimal_tiles (int *q_tile, int *kv_tile, int head_dim)
 Get optimal tile sizes for current CPU. More...
 
void mega_fuse_report_stats (int hidden, int num_layers, int seq_len)
 Report memory savings from mega-fusion. More...
 
void mega_fuse_rmsnorm_qkv (float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps)
 Phase 1: Fused RMSNorm + QKV (intermediates in registers) More...
 
void mega_fuse_rmsnorm_qkv_rope (float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *W_qkv, const float *b_qkv, const float *rope_cos, const float *rope_sin, int pos, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps)
 Phase 2: Fused RMSNorm + QKV + RoPE. More...
 
void mega_fused_attention_decode (float *output, const float *input, const float *residual, const float *ln1_gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, const float *wo, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps)
 Mega-fused attention for decode mode (single token) More...
 
void mega_fused_attention_prefill (float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
 Mega-fused attention for prefill mode (multiple tokens) More...
 
void mega_fused_attention_prefill_q8_0 (float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
 Mega-fused prefill attention kernel (Q8_0 out-proj) More...
 
size_t mega_fused_attention_prefill_q8_0_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 Get scratch buffer size for mega_fused_attention_prefill_q8_0. More...
 
size_t mega_fused_attention_prefill_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 Get scratch buffer size for mega_fused_attention_prefill. More...
 
void mega_fused_outproj_mlp_prefill (float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, CKDataType wo_dt, const void *w1, const float *b1, CKDataType w1_dt, const void *w2, const float *b2, CKDataType w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch)
 Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill. More...
 
size_t mega_fused_outproj_mlp_prefill_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
 Get scratch buffer size for mega_fused_outproj_mlp_prefill. More...
 

Detailed Description

Mega-Fused Attention Kernel.

Holy grail fusion: RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual

All intermediates stay in registers/L1/L2. Single DRAM round-trip.

Memory Reduction — Before: ~32 KB of intermediates per layer (stack/heap). After: ~8 KB total (input + output only). Reduction: 4–5× per layer, ~100× for the full model.

Performance Target — Move from memory-bound to compute-bound. Expected speedup: 5–10× for attention-heavy workloads.

Definition in file mega_fused_attention.h.

Macro Definition Documentation

◆ MEGA_FUSE_KV_TILE

#define MEGA_FUSE_KV_TILE   64

Definition at line 36 of file mega_fused_attention.h.

◆ MEGA_FUSE_Q_TILE

#define MEGA_FUSE_Q_TILE   64

Definition at line 32 of file mega_fused_attention.h.

Function Documentation

◆ mega_fuse_get_optimal_tiles()

void mega_fuse_get_optimal_tiles ( int *  q_tile,
int *  kv_tile,
int  head_dim 
)

Get optimal tile sizes for current CPU.

◆ mega_fuse_report_stats()

void mega_fuse_report_stats ( int  hidden,
int  num_layers,
int  seq_len 
)

Report memory savings from mega-fusion.

◆ mega_fuse_rmsnorm_qkv()

void mega_fuse_rmsnorm_qkv ( float *  q_out,
float *  k_out,
float *  v_out,
const float *  input,
const float *  gamma,
const float *  W_qkv,
const float *  b_qkv,
int  hidden,
int  num_heads,
int  num_kv_heads,
int  head_dim,
float  eps 
)

Phase 1: Fused RMSNorm + QKV (intermediates in registers)

Simpler step: Just fuse RMSNorm with QKV projection. Q/K/V stay in stack buffers, not DRAM.

◆ mega_fuse_rmsnorm_qkv_rope()

void mega_fuse_rmsnorm_qkv_rope ( float *  q_out,
float *  k_out,
float *  v_out,
const float *  input,
const float *  gamma,
const float *  W_qkv,
const float *  b_qkv,
const float *  rope_cos,
const float *  rope_sin,
int  pos,
int  hidden,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  max_seq,
float  eps 
)

Phase 2: Fused RMSNorm + QKV + RoPE.

Q/K stay in output buffers, RoPE applied in-place.

◆ mega_fused_attention_decode()

void mega_fused_attention_decode ( float *  output,
const float *  input,
const float *  residual,
const float *  ln1_gamma,
const float *  wq,
const float *  bq,
const float *  wk,
const float *  bk,
const float *  wv,
const float *  bv,
const float *  wo,
const float *  bo,
float *  kv_cache_k,
float *  kv_cache_v,
const float *  rope_cos,
const float *  rope_sin,
int  pos,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
int  cache_capacity,
float  eps 
)

Mega-fused attention for decode mode (single token)

This is the "holy grail" - all operations fused, no intermediates to DRAM.

Parameters
output — Output [aligned_embed_dim] (includes residual add)
input — Input [aligned_embed_dim]
residual — Residual input [aligned_embed_dim] (or NULL)
ln1_gamma — RMSNorm gamma [embed_dim]
wq — Q weights (quantized) [num_heads * aligned_head_dim * aligned_embed_dim]
bq — Q bias [num_heads * aligned_head_dim] (or NULL)
wq_dt — Q weight dtype (CK_DT_Q5_0/CK_DT_Q8_0/CK_DT_FP32)
wk — K weights (quantized) [num_kv_heads * aligned_head_dim * aligned_embed_dim]
bk — K bias [num_kv_heads * aligned_head_dim] (or NULL)
wk_dt — K weight dtype (CK_DT_Q5_0/CK_DT_Q8_0/CK_DT_FP32)
wv — V weights (quantized) [num_kv_heads * aligned_head_dim * aligned_embed_dim]
bv — V bias [num_kv_heads * aligned_head_dim] (or NULL)
wv_dt — V weight dtype (CK_DT_Q5_0/CK_DT_Q8_0/CK_DT_FP32)
wo — Output projection weights (quantized) [aligned_embed_dim * aligned_embed_dim]
bo — Output bias [aligned_embed_dim] (or NULL)
wo_dt — Output weight dtype (CK_DT_Q5_0/CK_DT_FP32)
kv_cache_k — KV cache for K [num_kv_heads * cache_capacity * aligned_head_dim]
kv_cache_v — KV cache for V [num_kv_heads * cache_capacity * aligned_head_dim]
rope_cos — RoPE cos table [max_seq, head_dim/2]
rope_sin — RoPE sin table [max_seq, head_dim/2]
pos — Current position in sequence
embed_dim — Model hidden dimension (unpadded)
aligned_embed_dim — Aligned hidden dimension
num_heads — Number of attention heads
num_kv_heads — Number of KV heads (for GQA)
head_dim — Head dimension (unpadded)
aligned_head_dim — Aligned head dimension
cache_capacity — KV cache capacity (stride in tokens)
eps — RMSNorm epsilon
scratch — Scratch buffer from mega_fused_attention_prefill_scratch_size()

NOTE(review): wq_dt, wk_dt, wv_dt, wo_dt, and scratch are documented here but do
not appear in the mega_fused_attention_decode() signature above; this table looks
copied from the prefill variant — confirm against the header declaration.

Mega-fused attention for decode mode (single token)

RMSNorm → QKV → RoPE → Flash Attn → OutProj + Residual

Definition at line 589 of file mega_fused_attention_avx.c.

611 {
612  if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
613  !kv_cache_k || !kv_cache_v) {
614  return;
615  }
616  if (embed_dim <= 0 || aligned_embed_dim <= 0 || head_dim <= 0 || aligned_head_dim <= 0 ||
617  num_heads <= 0 || num_kv_heads <= 0 || cache_capacity <= 0) {
618  return;
619  }
620  if (pos < 0 || pos >= cache_capacity) {
621  return;
622  }
623  if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
624  return;
625  }
626 
627  const size_t q_elems = (size_t)num_heads * (size_t)aligned_head_dim;
628  const size_t kv_elems = (size_t)num_kv_heads * (size_t)aligned_head_dim;
629 
630  float q_stack[MEGA_STACK_MAX];
631  float k_stack[MEGA_STACK_MAX];
632  float v_stack[MEGA_STACK_MAX];
633  float o_stack[MEGA_STACK_MAX];
634 
635  float *q = q_stack;
636  float *k = k_stack;
637  float *v = v_stack;
638  float *o = o_stack;
639 
640  int free_q = 0;
641  int free_k = 0;
642  int free_v = 0;
643  int free_o = 0;
644 
645  if (q_elems > MEGA_STACK_MAX) {
646  q = (float *)malloc(q_elems * sizeof(float));
647  if (!q) {
648  return;
649  }
650  free_q = 1;
651  }
652  if (kv_elems > MEGA_STACK_MAX) {
653  k = (float *)malloc(kv_elems * sizeof(float));
654  if (!k) {
655  if (free_q) free(q);
656  return;
657  }
658  v = (float *)malloc(kv_elems * sizeof(float));
659  if (!v) {
660  if (free_q) free(q);
661  free(k);
662  return;
663  }
664  free_k = 1;
665  free_v = 1;
666  }
667  if (q_elems > MEGA_STACK_MAX) {
668  o = (float *)malloc(q_elems * sizeof(float));
669  if (!o) {
670  if (free_q) free(q);
671  if (free_k) free(k);
672  if (free_v) free(v);
673  return;
674  }
675  free_o = 1;
676  }
677 
678  mega_fuse_rmsnorm_qkv_avx(q, k, v, input, ln1_gamma,
679  wq, bq, wk, bk, wv, bv,
680  embed_dim, aligned_embed_dim,
681  num_heads, num_kv_heads,
682  head_dim, aligned_head_dim, eps);
683 
684  if (rope_cos && rope_sin) {
685  mega_fuse_rope_inplace_avx(q, k, rope_cos, rope_sin, pos,
686  num_heads, num_kv_heads,
687  head_dim, aligned_head_dim);
688  }
689 
691  kv_cache_k, kv_cache_v,
692  num_kv_heads, pos,
693  cache_capacity,
694  head_dim, aligned_head_dim);
695 
696  mega_fuse_flash_attention_avx(o, q, kv_cache_k, kv_cache_v,
697  num_heads, num_kv_heads,
698  pos + 1, cache_capacity,
699  head_dim, aligned_head_dim);
700 
701  mega_fuse_output_proj_residual(o, wo, bo, residual, output,
702  embed_dim, aligned_embed_dim,
703  num_heads, head_dim, aligned_head_dim);
704 
705  if (free_q) free(q);
706  if (free_k) free(k);
707  if (free_v) free(v);
708  if (free_o) free(o);
709 }
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
static void mega_fuse_output_proj_residual(const float *attn_token, const float *wo, const float *bo, const float *residual, float *output, int embed_dim, int aligned_embed_dim, int num_heads, int head_dim, int aligned_head_dim)
void mega_fuse_flash_attention_avx(float *o_out, const float *q, const float *kv_cache_k, const float *kv_cache_v, int num_heads, int num_kv_heads, int seq_len, int cache_capacity, int head_dim, int aligned_head_dim)
Flash attention with online softmax (AVX version)
void mega_fuse_rmsnorm_qkv_avx(float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps)
Fused RMSNorm + QKV for decode (single token)
void mega_fuse_rope_inplace_avx(float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim)
Apply RoPE to Q and K (in-place, from L1)
#define MEGA_STACK_MAX

References kv_cache_write_head_major(), mega_fuse_flash_attention_avx(), mega_fuse_output_proj_residual(), mega_fuse_rmsnorm_qkv_avx(), mega_fuse_rope_inplace_avx(), and MEGA_STACK_MAX.

◆ mega_fused_attention_prefill()

void mega_fused_attention_prefill ( float *  output,
const float *  input,
const float *  residual,
const float *  ln1_gamma,
const void *  wq,
const float *  bq,
CKDataType  wq_dt,
const void *  wk,
const float *  bk,
CKDataType  wk_dt,
const void *  wv,
const float *  bv,
CKDataType  wv_dt,
const void *  wo,
const float *  bo,
CKDataType  wo_dt,
float *  kv_cache_k,
float *  kv_cache_v,
const float *  rope_cos,
const float *  rope_sin,
int  start_pos,
int  tokens,
int  cache_capacity,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
float  eps,
void *  scratch 
)

Mega-fused attention for prefill mode (multiple tokens)

Parameters
output — Output [tokens, aligned_embed_dim] (includes residual add)
input — Input [tokens, aligned_embed_dim]
residual — Residual input [tokens, aligned_embed_dim] (or NULL)
ln1_gamma — RMSNorm gamma [embed_dim]
wq — Q weights [num_heads * aligned_head_dim * aligned_embed_dim]
bq — Q bias [num_heads * aligned_head_dim] (or NULL)
wq_dt — Q weight dtype (CKDataType)
wk — K weights [num_kv_heads * aligned_head_dim * aligned_embed_dim]
bk — K bias [num_kv_heads * aligned_head_dim] (or NULL)
wk_dt — K weight dtype (CKDataType)
wv — V weights [num_kv_heads * aligned_head_dim * aligned_embed_dim]
bv — V bias [num_kv_heads * aligned_head_dim] (or NULL)
wv_dt — V weight dtype (CKDataType)
wo — Output projection weights [num_heads * aligned_embed_dim * aligned_head_dim]
bo — Output bias [aligned_embed_dim] (or NULL)
wo_dt — Output projection weight dtype (CKDataType)
kv_cache_k — KV cache for K [num_kv_heads * cache_capacity * aligned_head_dim]
kv_cache_v — KV cache for V [num_kv_heads * cache_capacity * aligned_head_dim]
rope_cos — RoPE cos table [max_seq, head_dim/2]
rope_sin — RoPE sin table [max_seq, head_dim/2]
start_pos — Starting position in KV cache
tokens — Number of tokens to process
cache_capacity — KV cache capacity (stride in tokens)
embed_dim — Model hidden dimension (unpadded)
aligned_embed_dim — Aligned hidden dimension
num_heads — Number of attention heads
num_kv_heads — Number of KV heads
head_dim — Head dimension (unpadded)
aligned_head_dim — Aligned head dimension
eps — RMSNorm epsilon
scratch — Scratch buffer from mega_fused_attention_prefill_scratch_size()

Definition at line 160 of file mega_fused_attention_prefill.c.

184 {
185  if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
186  !kv_cache_k || !kv_cache_v || !scratch) {
187  return;
188  }
189  if (tokens <= 0 || cache_capacity <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
190  head_dim <= 0 || aligned_head_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0) {
191  return;
192  }
193  if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
194  return;
195  }
196  if (start_pos < 0 || start_pos + tokens > cache_capacity) {
197  return;
198  }
199 
200  const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
201  (size_t)aligned_head_dim * sizeof(float);
202  const size_t attn_bytes = q_bytes;
203  const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
204  const size_t qkv_scratch_bytes = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);
205 
206  uint8_t *scratch_bytes = (uint8_t *)scratch;
207  float *q = (float *)scratch_bytes;
208  scratch_bytes += align_up_size(q_bytes, 64);
209  float *attn_out = (float *)scratch_bytes;
210  scratch_bytes += align_up_size(attn_bytes, 64);
211  float *proj_scratch = (float *)scratch_bytes;
212  scratch_bytes += align_up_size(proj_bytes, 64);
213  void *qkv_scratch = (void *)scratch_bytes;
214  (void)qkv_scratch_bytes;
215 
216  float *k_ptr = kv_cache_k + (size_t)start_pos * (size_t)aligned_head_dim;
217  float *v_ptr = kv_cache_v + (size_t)start_pos * (size_t)aligned_head_dim;
218 
219  if (wq_dt == CK_DT_FP32 && wk_dt == CK_DT_FP32 && wv_dt == CK_DT_FP32) {
221  ln1_gamma,
222  (const float *)wq, bq,
223  (const float *)wk, bk,
224  (const float *)wv, bv,
225  q,
226  k_ptr,
227  v_ptr,
228  tokens,
229  embed_dim,
230  aligned_embed_dim,
231  num_heads,
232  num_kv_heads,
233  head_dim,
234  aligned_head_dim,
235  cache_capacity,
236  eps,
237  qkv_scratch);
238  } else {
240  ln1_gamma,
241  wq, bq, wq_dt,
242  wk, bk, wk_dt,
243  wv, bv, wv_dt,
244  q,
245  k_ptr,
246  v_ptr,
247  tokens,
248  embed_dim,
249  aligned_embed_dim,
250  num_heads,
251  num_kv_heads,
252  head_dim,
253  aligned_head_dim,
254  cache_capacity,
255  eps,
256  qkv_scratch);
257  }
258 
259  if (rope_cos && rope_sin) {
261  k_ptr,
262  rope_cos,
263  rope_sin,
264  num_heads,
265  num_kv_heads,
266  tokens,
267  head_dim,
268  aligned_head_dim,
269  start_pos,
270  tokens,
271  cache_capacity);
272  }
273 
274  if (start_pos == 0) {
276  k_ptr,
277  v_ptr,
278  attn_out,
279  num_heads,
280  num_kv_heads,
281  tokens,
282  head_dim,
283  aligned_head_dim,
284  cache_capacity);
285  } else {
286  const float scale = 1.0f / sqrtf((float)head_dim);
287  const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
288  const size_t kv_head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
289 
290  for (int h = 0; h < num_heads; ++h) {
291  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
292  const float *k_head = kv_cache_k + (size_t)kv_head * kv_head_stride;
293  const float *v_head = kv_cache_v + (size_t)kv_head * kv_head_stride;
294 
295  for (int i = 0; i < tokens; ++i) {
296  const float *q_vec = q + (size_t)h * q_head_stride + (size_t)i * (size_t)aligned_head_dim;
297  float *out_vec = attn_out + (size_t)h * q_head_stride + (size_t)i * (size_t)aligned_head_dim;
298  attention_flash_decode(out_vec,
299  q_vec,
300  k_head,
301  v_head,
302  1,
303  start_pos + i + 1,
304  1,
305  aligned_head_dim,
306  scale);
307  }
308  }
309  }
310 
311  if ((num_heads * aligned_head_dim) != aligned_embed_dim) {
312  return;
313  }
314 
315  if (wo_dt == CK_DT_Q5_0 &&
317  (aligned_head_dim % QK5_0) == 0 &&
318  (aligned_embed_dim % QK5_0) == 0) {
319  /* Quantized activations path: Q8_0 attn_out + Q5_0 weights. */
320  uint8_t *attn_q8 = (uint8_t *)q;
322  attn_q8,
323  tokens,
324  num_heads,
325  aligned_head_dim);
327  wo,
328  bo,
329  output,
330  tokens,
331  aligned_embed_dim,
332  num_heads,
333  aligned_head_dim);
334  } else if (wo_dt == CK_DT_Q5_0 &&
335  (aligned_head_dim % QK5_0) == 0 &&
336  (aligned_embed_dim % QK5_0) == 0) {
337  /* Head-major output projection with Q5_0 weights - no flatten needed */
339  wo,
340  bo,
341  output,
342  tokens,
343  aligned_embed_dim,
344  num_heads,
345  aligned_head_dim);
346  } else if (wo_dt == CK_DT_Q8_0 &&
347  (aligned_head_dim % QK8_0) == 0 &&
348  (aligned_embed_dim % QK8_0) == 0) {
349  /* Head-major output projection with Q8_0 weights - no flatten needed */
351  wo,
352  bo,
353  output,
354  tokens,
355  aligned_embed_dim,
356  num_heads,
357  aligned_head_dim);
358  } else {
359  /* Fallback: flatten then GEMM (slow path) */
360  flatten_head_major(attn_out,
361  proj_scratch,
362  tokens,
363  aligned_embed_dim,
364  num_heads,
365  aligned_head_dim);
366 
367  ck_gemm_nt_quant(proj_scratch,
368  wo,
369  bo,
370  output,
371  tokens,
372  aligned_embed_dim,
373  aligned_embed_dim,
374  wo_dt);
375  }
376 
377  if (residual) {
379  output,
380  output,
381  tokens,
382  aligned_embed_dim);
383  }
384 
385 }
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_Q5_0
Definition: ckernel_dtype.h:44
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
void ck_gemm_nt_head_major_q8_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q8_0 weights)
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
Definition: rope_kernels.c:472
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void ck_gemm_nt_head_major_q5_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (auto-dispatch)
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
void ck_gemm_nt_quant(const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
#define QK5_0
Definition: ckernel_quant.h:67
#define QK8_0
static size_t align_up_size(size_t value, size_t align)
static void flatten_head_major(const float *attn_out, float *dst, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void out_proj_head_major_q5_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
static int ck_q8_0_outproj_enabled(void)

References align_up_size(), attention_flash_decode(), attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_FP32, CK_DT_Q5_0, CK_DT_Q8_0, ck_gemm_nt_head_major_q5_0(), ck_gemm_nt_head_major_q8_0(), ck_gemm_nt_quant(), ck_q8_0_outproj_enabled(), ck_residual_add_token_major(), flatten_head_major(), fused_rmsnorm_qkv_prefill_head_major(), fused_rmsnorm_qkv_prefill_head_major_quant(), fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(), out_proj_head_major_q5_0_q8_0(), QK5_0, QK8_0, quantize_attn_out_head_major_q8_0(), and rope_forward_qk_strided().

◆ mega_fused_attention_prefill_q8_0()

void mega_fused_attention_prefill_q8_0 ( float *  output,
const float *  input,
const float *  residual,
const float *  ln1_gamma,
const void *  wq,
const float *  bq,
CKDataType  wq_dt,
const void *  wk,
const float *  bk,
CKDataType  wk_dt,
const void *  wv,
const float *  bv,
CKDataType  wv_dt,
const void *  wo,
const float *  bo,
CKDataType  wo_dt,
float *  kv_cache_k,
float *  kv_cache_v,
const float *  rope_cos,
const float *  rope_sin,
int  start_pos,
int  tokens,
int  cache_capacity,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
float  eps,
void *  scratch 
)

Mega-fused prefill attention kernel (Q8_0 out-proj)

Same layout and scratch requirements as mega_fused_attention_prefill.

Definition at line 105 of file mega_fused_attention_prefill_q8_0.c.

129 {
130  if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
131  !kv_cache_k || !kv_cache_v || !scratch) {
132  return;
133  }
134  if (tokens <= 0 || cache_capacity <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
135  head_dim <= 0 || aligned_head_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0) {
136  return;
137  }
138  if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
139  return;
140  }
141  if (start_pos < 0 || start_pos + tokens > cache_capacity) {
142  return;
143  }
144  if (wo_dt != CK_DT_Q8_0) {
145  return;
146  }
147  if ((aligned_head_dim % QK8_0) != 0 || (aligned_embed_dim % QK8_0) != 0) {
148  return;
149  }
150 
151  const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
152  (size_t)aligned_head_dim * sizeof(float);
153  const size_t attn_bytes = q_bytes;
154  const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
155  const size_t qkv_scratch_bytes = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);
156 
157  uint8_t *scratch_bytes = (uint8_t *)scratch;
158  float *q = (float *)scratch_bytes;
159  scratch_bytes += align_up_size(q_bytes, 64);
160  float *attn_out = (float *)scratch_bytes;
161  scratch_bytes += align_up_size(attn_bytes, 64);
162  float *proj_scratch = (float *)scratch_bytes;
163  scratch_bytes += align_up_size(proj_bytes, 64);
164  void *qkv_scratch = (void *)scratch_bytes;
165  (void)qkv_scratch_bytes;
166  (void)proj_scratch;
167 
168  float *k_ptr = kv_cache_k + (size_t)start_pos * (size_t)aligned_head_dim;
169  float *v_ptr = kv_cache_v + (size_t)start_pos * (size_t)aligned_head_dim;
170 
171  if (wq_dt == CK_DT_FP32 && wk_dt == CK_DT_FP32 && wv_dt == CK_DT_FP32) {
173  ln1_gamma,
174  (const float *)wq, bq,
175  (const float *)wk, bk,
176  (const float *)wv, bv,
177  q,
178  k_ptr,
179  v_ptr,
180  tokens,
181  embed_dim,
182  aligned_embed_dim,
183  num_heads,
184  num_kv_heads,
185  head_dim,
186  aligned_head_dim,
187  cache_capacity,
188  eps,
189  qkv_scratch);
190  } else {
192  ln1_gamma,
193  wq, bq, wq_dt,
194  wk, bk, wk_dt,
195  wv, bv, wv_dt,
196  q,
197  k_ptr,
198  v_ptr,
199  tokens,
200  embed_dim,
201  aligned_embed_dim,
202  num_heads,
203  num_kv_heads,
204  head_dim,
205  aligned_head_dim,
206  cache_capacity,
207  eps,
208  qkv_scratch);
209  }
210 
211  if (rope_cos && rope_sin) {
213  k_ptr,
214  rope_cos,
215  rope_sin,
216  num_heads,
217  num_kv_heads,
218  tokens,
219  head_dim,
220  aligned_head_dim,
221  start_pos,
222  tokens,
223  cache_capacity);
224  }
225 
226  if (start_pos == 0) {
228  k_ptr,
229  v_ptr,
230  attn_out,
231  num_heads,
232  num_kv_heads,
233  tokens,
234  head_dim,
235  aligned_head_dim,
236  cache_capacity);
237  } else {
238  const float scale = 1.0f / sqrtf((float)head_dim);
239  const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
240  const size_t kv_head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
241 
242  for (int h = 0; h < num_heads; ++h) {
243  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
244  const float *k_head = kv_cache_k + (size_t)kv_head * kv_head_stride;
245  const float *v_head = kv_cache_v + (size_t)kv_head * kv_head_stride;
246 
247  for (int i = 0; i < tokens; ++i) {
248  const float *q_vec = q + (size_t)h * q_head_stride + (size_t)i * (size_t)aligned_head_dim;
249  float *out_vec = attn_out + (size_t)h * q_head_stride + (size_t)i * (size_t)aligned_head_dim;
250  attention_flash_decode(out_vec,
251  q_vec,
252  k_head,
253  v_head,
254  1,
255  start_pos + i + 1,
256  1,
257  aligned_head_dim,
258  scale);
259  }
260  }
261  }
262 
263  if ((num_heads * aligned_head_dim) != aligned_embed_dim) {
264  return;
265  }
266 
267  /* Quantized activations path: Q8_0 attn_out + Q8_0 weights. */
268  {
269  uint8_t *attn_q8 = (uint8_t *)q;
271  attn_q8,
272  tokens,
273  num_heads,
274  aligned_head_dim);
276  wo,
277  bo,
278  output,
279  tokens,
280  aligned_embed_dim,
281  num_heads,
282  aligned_head_dim);
283  }
284 
285  if (residual) {
287  output,
288  output,
289  tokens,
290  aligned_embed_dim);
291  }
292 }
static void out_proj_head_major_q8_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static size_t align_up_size(size_t value, size_t align)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)

References align_up_size(), attention_flash_decode(), attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_FP32, CK_DT_Q8_0, ck_residual_add_token_major(), fused_rmsnorm_qkv_prefill_head_major(), fused_rmsnorm_qkv_prefill_head_major_quant(), fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(), out_proj_head_major_q8_0_q8_0(), QK8_0, quantize_attn_out_head_major_q8_0(), and rope_forward_qk_strided().

◆ mega_fused_attention_prefill_q8_0_scratch_size()

size_t mega_fused_attention_prefill_q8_0_scratch_size ( int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)

Get scratch buffer size for mega_fused_attention_prefill_q8_0.

Definition at line 84 of file mega_fused_attention_prefill_q8_0.c.

88 {
89  if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 || aligned_head_dim <= 0) {
90  return 0;
91  }
92 
93  const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
94  (size_t)aligned_head_dim * sizeof(float);
95  const size_t attn_bytes = q_bytes;
96  const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
97  const size_t qkv_scratch = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);
98 
99  return align_up_size(q_bytes, 64) +
100  align_up_size(attn_bytes, 64) +
101  align_up_size(proj_bytes, 64) +
102  align_up_size(qkv_scratch, 64);
103 }

References align_up_size(), and fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size().

◆ mega_fused_attention_prefill_scratch_size()

size_t mega_fused_attention_prefill_scratch_size ( int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)

Get scratch buffer size for mega_fused_attention_prefill.

Definition at line 139 of file mega_fused_attention_prefill.c.

143 {
144  if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 || aligned_head_dim <= 0) {
145  return 0;
146  }
147 
148  const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
149  (size_t)aligned_head_dim * sizeof(float);
150  const size_t attn_bytes = q_bytes;
151  const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
152  const size_t qkv_scratch = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);
153 
154  return align_up_size(q_bytes, 64) +
155  align_up_size(attn_bytes, 64) +
156  align_up_size(proj_bytes, 64) +
157  align_up_size(qkv_scratch, 64);
158 }

References align_up_size(), and fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size().

◆ mega_fused_outproj_mlp_prefill()

void mega_fused_outproj_mlp_prefill ( float *  output,
const float *  attn_out,
const float *  residual,
const float *  ln2_gamma,
const void *  wo,
const float *  bo,
CKDataType  wo_dt,
const void *  w1,
const float *  b1,
CKDataType  w1_dt,
const void *  w2,
const float *  b2,
CKDataType  w2_dt,
int  tokens,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim,
int  intermediate_dim,
int  aligned_intermediate_dim,
float  eps,
void *  scratch 
)

Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill.

Uses head-major attention output with quantized weights: the out-projection (wo) and MLP gate/up projection (w1) accept Q5_0 or Q8_0, while the MLP down-projection (w2) accepts Q4_K or Q6_K.

Definition at line 184 of file mega_fused_outproj_mlp_prefill.c.

201 {
     /* Mandatory pointers; note the bias pointers (bo/b1/b2) are NOT
        checked here, so callers may pass NULL biases. */
202  if (!output || !attn_out || !residual || !ln2_gamma ||
203  !wo || !w1 || !w2 || !scratch) {
204  return;
205  }
     /* All logical and aligned (padded) dimensions must be positive. */
206  if (tokens <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
207  num_heads <= 0 || aligned_head_dim <= 0 ||
208  intermediate_dim <= 0 || aligned_intermediate_dim <= 0) {
209  return;
210  }
     /* Aligned sizes must be at least the logical sizes they pad. */
211  if (aligned_embed_dim < embed_dim || aligned_head_dim <= 0 ||
212  aligned_intermediate_dim < intermediate_dim) {
213  return;
214  }
     /* Head-major layout requires aligned_embed_dim == heads * head_dim. */
215  if (aligned_embed_dim != num_heads * aligned_head_dim) {
216  return;
217  }
     /* Dimensions must be multiples of 32 — presumably the Q5_0/Q8_0
        quantization block width; TODO confirm against the quant kernels. */
218  if ((aligned_embed_dim % 32) != 0 || (aligned_head_dim % 32) != 0) {
219  return;
220  }
     /* The K-quant down-projection needs QK_K-sized super-blocks. */
221  if ((aligned_intermediate_dim % QK_K) != 0) {
222  return;
223  }
     /* Supported weight dtypes: wo and w1 in Q5_0/Q8_0, w2 in Q4_K/Q6_K. */
224  if (wo_dt != CK_DT_Q5_0 && wo_dt != CK_DT_Q8_0) {
225  return;
226  }
227  if (w1_dt != CK_DT_Q5_0 && w1_dt != CK_DT_Q8_0) {
228  return;
229  }
230  if (w2_dt != CK_DT_Q4_K && w2_dt != CK_DT_Q6_K) {
231  return;
232  }
233 
     /* Scratch carve-out — must stay in lockstep with
        mega_fused_outproj_mlp_prefill_scratch_size():
        [attn_q8 | h1 | ln2_out | mlp_scratch], each 64-byte aligned. */
234  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
235  (size_t)aligned_head_dim);
236  const size_t attn_q8_bytes = (size_t)num_heads * (size_t)tokens * q8_row_bytes;
237  const size_t h1_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
238  const size_t ln2_bytes = h1_bytes;
239 
240  uint8_t *scratch_bytes = (uint8_t *)scratch;
241  uint8_t *attn_q8 = scratch_bytes;
242  scratch_bytes += align_up_size(attn_q8_bytes, 64);
243  float *h1 = (float *)scratch_bytes;
244  scratch_bytes += align_up_size(h1_bytes, 64);
245  float *ln2_out = (float *)scratch_bytes;
246  scratch_bytes += align_up_size(ln2_bytes, 64);
247  void *mlp_scratch = (void *)scratch_bytes;
248 
     /* Quantize the head-major attention output to Q8_0 for the quantized
        out-projection.
        NOTE(review): the leading call line — presumably
        `quantize_attn_out_head_major_q8_0(attn_out,` (orig. line 249) — was
        dropped by the doc extraction; confirm against the original source. */
250  attn_q8,
251  tokens,
252  num_heads,
253  aligned_head_dim);
254 
     /* Out-projection attn_q8 x wo (+bo) -> h1, dispatched on wo dtype. */
255  if (wo_dt == CK_DT_Q8_0) {
     /* NOTE(review): dropped call line, presumably
        `out_proj_head_major_q8_0_q8_0(attn_q8,` (orig. line 256). */
257  wo,
258  bo,
259  h1,
260  tokens,
261  aligned_embed_dim,
262  num_heads,
263  aligned_head_dim);
264  } else {
     /* NOTE(review): dropped call line, presumably
        `out_proj_head_major_q5_0_q8_0(attn_q8,` (orig. line 265). */
266  wo,
267  bo,
268  h1,
269  tokens,
270  aligned_embed_dim,
271  num_heads,
272  aligned_head_dim);
273  }
274 
     /* First residual connection: h1 += residual, one row per token. */
275  for (int t = 0; t < tokens; ++t) {
276  const float *res_row = residual + (size_t)t * (size_t)aligned_embed_dim;
277  float *h1_row = h1 + (size_t)t * (size_t)aligned_embed_dim;
278  add_inplace_f32(h1_row, res_row, aligned_embed_dim);
279  }
280 
     /* Second RMSNorm: h1 -> ln2_out (rstd_cache passed as NULL). */
281  rmsnorm_forward(h1,
282  ln2_gamma,
283  ln2_out,
284  NULL,
285  tokens,
286  embed_dim,
287  aligned_embed_dim,
288  eps);
289 
     /* Quantized SwiGLU MLP into `output`.
        NOTE(review): dropped call line, presumably
        `fused_mlp_swiglu_prefill_w1w2_quant(ln2_out,` (orig. line 290) —
        confirm that ln2_out is the x argument. */
291  w1,
292  b1,
293  w1_dt,
294  w2,
295  b2,
296  w2_dt,
297  output,
298  tokens,
299  embed_dim,
300  aligned_embed_dim,
301  intermediate_dim,
302  aligned_intermediate_dim,
303  mlp_scratch);
304 
     /* Second residual connection: output += h1 (pre-MLP hidden state). */
305  for (int t = 0; t < tokens; ++t) {
306  const float *h1_row = h1 + (size_t)t * (size_t)aligned_embed_dim;
307  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
308  add_inplace_f32(out_row, h1_row, aligned_embed_dim);
309  }
310 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void add_inplace_f32(float *a, const float *b, size_t n)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void fused_mlp_swiglu_prefill_w1w2_quant(const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch)
Quantized fused MLP for prefill (W1=gate+up, W2=down)
#define QK_K
static void out_proj_head_major_q8_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static size_t align_up_size(size_t value, size_t align)
static void out_proj_head_major_q5_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)

References add_inplace_f32(), align_up_size(), CK_DT_Q4_K, CK_DT_Q5_0, CK_DT_Q6_K, CK_DT_Q8_0, ck_dtype_row_bytes(), fused_mlp_swiglu_prefill_w1w2_quant(), out_proj_head_major_q5_0_q8_0(), out_proj_head_major_q8_0_q8_0(), QK_K, quantize_attn_out_head_major_q8_0(), and rmsnorm_forward().

◆ mega_fused_outproj_mlp_prefill_scratch_size()

size_t mega_fused_outproj_mlp_prefill_scratch_size ( int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim,
int  aligned_intermediate_dim 
)

Get scratch buffer size for mega_fused_outproj_mlp_prefill.

Definition at line 159 of file mega_fused_outproj_mlp_prefill.c.

164 {
     /* Any non-positive dimension -> no scratch required. */
165  if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 ||
166  aligned_head_dim <= 0 || aligned_intermediate_dim <= 0) {
167  return 0;
168  }
169 
     /* Mirrors the carve-out in mega_fused_outproj_mlp_prefill():
        Q8_0-quantized attention output, post-out-proj hidden state (h1),
        RMSNorm output (same size as h1), and the fused-MLP scratch. */
170  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
171  (size_t)aligned_head_dim);
172  const size_t attn_q8_bytes = (size_t)num_heads * (size_t)tokens * q8_row_bytes;
173  const size_t h1_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
174  const size_t ln2_bytes = h1_bytes;
175  const size_t mlp_scratch = fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(
176  aligned_embed_dim, aligned_intermediate_dim);
177 
     /* Each region is rounded to 64 bytes, matching the kernel's carve-out. */
178  return align_up_size(attn_q8_bytes, 64) +
179  align_up_size(h1_bytes, 64) +
180  align_up_size(ln2_bytes, 64) +
181  align_up_size(mlp_scratch, 64);
182 }
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim, int aligned_intermediate_dim)
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.

References align_up_size(), CK_DT_Q8_0, ck_dtype_row_bytes(), and fused_mlp_swiglu_prefill_w1w2_quant_scratch_size().