← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention_prefill_q8_0.c File Reference

Mega-fused prefill attention kernel with Q8_0 out-proj. More...

#include "ckernel_engine.h"
#include "ckernel_orchestration.h"
#include "ckernel_quant.h"
#include <math.h>
#include <string.h>

Go to the source code of this file.

Functions

static size_t align_up_size (size_t value, size_t align)
 
void mega_fused_attention_prefill_q8_0 (float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
 Mega-fused prefill attention kernel (Q8_0 out-proj) More...
 
size_t mega_fused_attention_prefill_q8_0_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 Get scratch buffer size for mega_fused_attention_prefill_q8_0. More...
 
static void out_proj_head_major_q8_0_q8_0 (const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
static void quantize_attn_out_head_major_q8_0 (const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
 

Detailed Description

Mega-fused prefill attention kernel with Q8_0 out-proj.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. NO memcpy for layout - use strided access, not copies
  4. API must define: inputs, outputs, workspace, and memory layouts
  5. Pure computation - deterministic, no side effects

RMSNorm → QKV → RoPE → Flash Attention → Q8_0 OutProj + Residual. Writes K/V directly into the KV cache layout (stride = cache_capacity).

Definition in file mega_fused_attention_prefill_q8_0.c.

Function Documentation

◆ align_up_size()

/* Round value up to the next multiple of align.
 * align must be a power of two (bitmask trick). */
static size_t align_up_size(size_t value, size_t align)
{
    const size_t mask = align - 1;
    return (value + mask) & ~mask;
}

Referenced by mega_fused_attention_prefill_q8_0(), and mega_fused_attention_prefill_q8_0_scratch_size().

◆ mega_fused_attention_prefill_q8_0()

/**
 * Mega-fused prefill attention kernel (Q8_0 out-proj).
 *
 * Pipeline: RMSNorm -> QKV projection -> RoPE -> flash attention ->
 * Q8_0 out-projection -> optional residual add.  K and V are written
 * directly into the KV cache layout (token stride = cache_capacity).
 * Same layout and scratch requirements as mega_fused_attention_prefill.
 *
 * Scratch layout (each section 64-byte aligned): Q buffer, attention
 * output, reserved projection buffer, fused-QKV scratch.
 *
 * Returns silently on invalid arguments (kernel convention: pure
 * computation, no error channel).
 *
 * NOTE(review): the Doxygen rendering dropped the call-name lines of the
 * inner kernel calls; they are reconstructed here from the References
 * list and the visible prototypes — confirm against the original source.
 */
void mega_fused_attention_prefill_q8_0(float *output,
                                       const float *input,
                                       const float *residual,
                                       const float *ln1_gamma,
                                       const void *wq, const float *bq, CKDataType wq_dt,
                                       const void *wk, const float *bk, CKDataType wk_dt,
                                       const void *wv, const float *bv, CKDataType wv_dt,
                                       const void *wo, const float *bo, CKDataType wo_dt,
                                       float *kv_cache_k,
                                       float *kv_cache_v,
                                       const float *rope_cos,
                                       const float *rope_sin,
                                       int start_pos,
                                       int tokens,
                                       int cache_capacity,
                                       int embed_dim,
                                       int aligned_embed_dim,
                                       int num_heads,
                                       int num_kv_heads,
                                       int head_dim,
                                       int aligned_head_dim,
                                       float eps,
                                       void *scratch)
{
    /* Required pointers (residual, biases and RoPE tables are optional). */
    if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
        !kv_cache_k || !kv_cache_v || !scratch) {
        return;
    }
    if (tokens <= 0 || cache_capacity <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
        head_dim <= 0 || aligned_head_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0) {
        return;
    }
    if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
        return;
    }
    /* New tokens must fit in the cache window. */
    if (start_pos < 0 || start_pos + tokens > cache_capacity) {
        return;
    }
    /* This variant only supports Q8_0 out-projection weights. */
    if (wo_dt != CK_DT_Q8_0) {
        return;
    }
    /* Q8_0 rows must consist of whole quantization blocks. */
    if ((aligned_head_dim % QK8_0) != 0 || (aligned_embed_dim % QK8_0) != 0) {
        return;
    }

    /* Carve the caller-provided scratch (kernel rule: no malloc/free). */
    const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
                           (size_t)aligned_head_dim * sizeof(float);
    const size_t attn_bytes = q_bytes;
    const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
    const size_t qkv_scratch_bytes =
        fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);

    uint8_t *scratch_bytes = (uint8_t *)scratch;
    float *q = (float *)scratch_bytes;
    scratch_bytes += align_up_size(q_bytes, 64);
    float *attn_out = (float *)scratch_bytes;
    scratch_bytes += align_up_size(attn_bytes, 64);
    float *proj_scratch = (float *)scratch_bytes;
    scratch_bytes += align_up_size(proj_bytes, 64);
    void *qkv_scratch = (void *)scratch_bytes;
    (void)qkv_scratch_bytes;   /* size accounted for by the scratch-size helper */
    (void)proj_scratch;        /* reserved; the Q8_0 path reuses the q buffer instead */

    /* K/V land directly in the cache at start_pos; per-head stride is
     * cache_capacity rows of aligned_head_dim floats. */
    float *k_ptr = kv_cache_k + (size_t)start_pos * (size_t)aligned_head_dim;
    float *v_ptr = kv_cache_v + (size_t)start_pos * (size_t)aligned_head_dim;

    if (wq_dt == CK_DT_FP32 && wk_dt == CK_DT_FP32 && wv_dt == CK_DT_FP32) {
        /* All-FP32 QKV weights: plain fused RMSNorm + QKV. */
        fused_rmsnorm_qkv_prefill_head_major(input,
                                             ln1_gamma,
                                             (const float *)wq, bq,
                                             (const float *)wk, bk,
                                             (const float *)wv, bv,
                                             q,
                                             k_ptr,
                                             v_ptr,
                                             tokens,
                                             embed_dim,
                                             aligned_embed_dim,
                                             num_heads,
                                             num_kv_heads,
                                             head_dim,
                                             aligned_head_dim,
                                             cache_capacity,
                                             eps,
                                             qkv_scratch);
    } else {
        /* Mixed/quantized QKV weights: dtype-dispatching variant. */
        fused_rmsnorm_qkv_prefill_head_major_quant(input,
                                                   ln1_gamma,
                                                   wq, bq, wq_dt,
                                                   wk, bk, wk_dt,
                                                   wv, bv, wv_dt,
                                                   q,
                                                   k_ptr,
                                                   v_ptr,
                                                   tokens,
                                                   embed_dim,
                                                   aligned_embed_dim,
                                                   num_heads,
                                                   num_kv_heads,
                                                   head_dim,
                                                   aligned_head_dim,
                                                   cache_capacity,
                                                   eps,
                                                   qkv_scratch);
    }

    /* Apply rotary embeddings in place to Q and to K in the cache. */
    if (rope_cos && rope_sin) {
        rope_forward_qk_strided(q,
                                k_ptr,
                                rope_cos,
                                rope_sin,
                                num_heads,
                                num_kv_heads,
                                tokens,
                                head_dim,
                                aligned_head_dim,
                                start_pos,         /* position offset */
                                tokens,            /* Q token stride */
                                cache_capacity);   /* K token stride (cache layout) */
    }

    if (start_pos == 0) {
        /* Empty cache: one fused causal flash pass over the new tokens. */
        attention_forward_causal_head_major_gqa_flash_strided(q,
                                                              k_ptr,
                                                              v_ptr,
                                                              attn_out,
                                                              num_heads,
                                                              num_kv_heads,
                                                              tokens,
                                                              head_dim,
                                                              aligned_head_dim,
                                                              cache_capacity);
    } else {
        /* Continued prefill: decode each new token against the cached
         * history plus itself (causal K length = start_pos + i + 1). */
        const float scale = 1.0f / sqrtf((float)head_dim);
        const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
        const size_t kv_head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;

        for (int h = 0; h < num_heads; ++h) {
            /* GQA: map the query head onto its shared KV head. */
            int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
            const float *k_head = kv_cache_k + (size_t)kv_head * kv_head_stride;
            const float *v_head = kv_cache_v + (size_t)kv_head * kv_head_stride;

            for (int i = 0; i < tokens; ++i) {
                const float *q_vec = q + (size_t)h * q_head_stride +
                                     (size_t)i * (size_t)aligned_head_dim;
                float *out_vec = attn_out + (size_t)h * q_head_stride +
                                 (size_t)i * (size_t)aligned_head_dim;
                attention_flash_decode(out_vec,
                                       q_vec,
                                       k_head,
                                       v_head,
                                       1,                  /* T_q: one query row */
                                       start_pos + i + 1,  /* T_k: causal window */
                                       1,                  /* single head per call */
                                       aligned_head_dim,
                                       scale);
            }
        }
    }

    /* The packed out-projection requires heads to tile the embedding exactly. */
    if ((num_heads * aligned_head_dim) != aligned_embed_dim) {
        return;
    }

    /* Quantized activations path: Q8_0 attn_out x Q8_0 weights.
     * The q buffer is dead at this point, so it is reused to hold the
     * quantized activations (q8 rows are smaller than the FP32 rows). */
    {
        uint8_t *attn_q8 = (uint8_t *)q;
        quantize_attn_out_head_major_q8_0(attn_out,
                                          attn_q8,
                                          tokens,
                                          num_heads,
                                          aligned_head_dim);
        out_proj_head_major_q8_0_q8_0(attn_q8,
                                      wo,
                                      bo,
                                      output,
                                      tokens,
                                      aligned_embed_dim,
                                      num_heads,
                                      aligned_head_dim);
    }

    /* Residual add (in place into output). */
    if (residual) {
        ck_residual_add_token_major(residual,
                                    output,
                                    output,
                                    tokens,
                                    aligned_embed_dim);
    }
}
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
Definition: rope_kernels.c:472
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
#define QK8_0
static void out_proj_head_major_q8_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static size_t align_up_size(size_t value, size_t align)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)

References align_up_size(), attention_flash_decode(), attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_FP32, CK_DT_Q8_0, ck_residual_add_token_major(), fused_rmsnorm_qkv_prefill_head_major(), fused_rmsnorm_qkv_prefill_head_major_quant(), fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(), out_proj_head_major_q8_0_q8_0(), QK8_0, quantize_attn_out_head_major_q8_0(), and rope_forward_qk_strided().

◆ mega_fused_attention_prefill_q8_0_scratch_size()

/**
 * Get scratch buffer size for mega_fused_attention_prefill_q8_0.
 *
 * Must mirror the carve-up inside the kernel: Q buffer, attention
 * output, projection buffer, then fused-QKV scratch, each section
 * rounded up to 64 bytes.  Returns 0 for non-positive dimensions.
 */
size_t mega_fused_attention_prefill_q8_0_scratch_size(int tokens,
                                                      int aligned_embed_dim,
                                                      int num_heads,
                                                      int aligned_head_dim)
{
    if (tokens <= 0 || aligned_embed_dim <= 0 ||
        num_heads <= 0 || aligned_head_dim <= 0) {
        return 0;
    }

    const size_t per_head_bytes = (size_t)tokens *
                                  (size_t)aligned_head_dim * sizeof(float);
    const size_t q_bytes = (size_t)num_heads * per_head_bytes;
    const size_t proj_bytes = (size_t)tokens *
                              (size_t)aligned_embed_dim * sizeof(float);
    const size_t qkv_bytes =
        fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);

    size_t total = align_up_size(q_bytes, 64);   /* Q (head-major) */
    total += align_up_size(q_bytes, 64);         /* attention output */
    total += align_up_size(proj_bytes, 64);      /* projection buffer */
    total += align_up_size(qkv_bytes, 64);       /* fused-QKV scratch */
    return total;
}

References align_up_size(), and fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size().

◆ out_proj_head_major_q8_0_q8_0()

static void out_proj_head_major_q8_0_q8_0 ( const uint8_t *  attn_q8,
const void *  wo,
const float *  bias,
float *  output,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 49 of file mega_fused_attention_prefill_q8_0.c.

57 {
58  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
59  (size_t)aligned_head_dim);
60  const int blocks_per_head = aligned_head_dim / QK8_0;
61  const int blocks_per_row = aligned_embed_dim / QK8_0;
62  const block_q8_0 *weights = (const block_q8_0 *)wo;
63 
64  for (int t = 0; t < tokens; ++t) {
65  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
66  for (int n = 0; n < aligned_embed_dim; ++n) {
67  float sum = bias ? bias[n] : 0.0f;
68  const block_q8_0 *w_row = weights + (size_t)n * (size_t)blocks_per_row;
69 
70  for (int h = 0; h < num_heads; ++h) {
71  const uint8_t *a_row = attn_q8 +
72  ((size_t)h * (size_t)tokens + (size_t)t) *
73  q8_row_bytes;
74  const block_q8_0 *w_head = w_row + (size_t)h * (size_t)blocks_per_head;
75  float partial = 0.0f;
76  vec_dot_q8_0_q8_0(aligned_head_dim, &partial, w_head, a_row);
77  sum += partial;
78  }
79  out_row[n] = sum;
80  }
81  }
82 }
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.

References CK_DT_Q8_0, ck_dtype_row_bytes(), QK8_0, and vec_dot_q8_0_q8_0().

Referenced by mega_fused_attention_prefill_q8_0().

◆ quantize_attn_out_head_major_q8_0()

static void quantize_attn_out_head_major_q8_0 ( const float *  attn_out,
uint8_t *  dst,
int  tokens,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 29 of file mega_fused_attention_prefill_q8_0.c.

34 {
35  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
36  (size_t)aligned_head_dim);
37  const size_t head_stride = (size_t)tokens * (size_t)aligned_head_dim;
38  for (int h = 0; h < num_heads; ++h) {
39  const float *head = attn_out + (size_t)h * head_stride;
40  for (int t = 0; t < tokens; ++t) {
41  const float *row = head + (size_t)t * (size_t)aligned_head_dim;
42  uint8_t *out = dst + ((size_t)h * (size_t)tokens + (size_t)t) *
43  q8_row_bytes;
44  quantize_row_q8_0(row, out, aligned_head_dim);
45  }
46  }
47 }
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)

References CK_DT_Q8_0, ck_dtype_row_bytes(), and quantize_row_q8_0().

Referenced by mega_fused_attention_prefill_q8_0().