← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention_prefill_q8_0.c
Go to the documentation of this file.
1 /**
2  * @file mega_fused_attention_prefill_q8_0.c
3  * @brief Mega-fused prefill attention kernel with Q8_0 out-proj
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. NO memcpy for layout - use strided access, not copies
10  * 4. API must define: inputs, outputs, workspace, and memory layouts
11  * 5. Pure computation - deterministic, no side effects
12  *
13  * RMSNorm → QKV → RoPE → Flash Attention → Q8_0 OutProj + Residual
14  * Writes K/V directly into the KV cache layout (stride = cache_capacity).
15  */
16 
17 #include "ckernel_engine.h"
18 #include "ckernel_orchestration.h"
19 #include "ckernel_quant.h"
20 
21 #include <math.h>
22 #include <string.h>
23 
/* Round `value` up to the next multiple of `align`.
 * Uses the bitmask trick, so `align` must be a power of two. */
static size_t align_up_size(size_t value, size_t align)
{
    const size_t mask = align - 1;
    return (value + mask) & ~mask;
}
28 
29 static void quantize_attn_out_head_major_q8_0(const float *attn_out,
30  uint8_t *dst,
31  int tokens,
32  int num_heads,
33  int aligned_head_dim)
34 {
35  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
36  (size_t)aligned_head_dim);
37  const size_t head_stride = (size_t)tokens * (size_t)aligned_head_dim;
38  for (int h = 0; h < num_heads; ++h) {
39  const float *head = attn_out + (size_t)h * head_stride;
40  for (int t = 0; t < tokens; ++t) {
41  const float *row = head + (size_t)t * (size_t)aligned_head_dim;
42  uint8_t *out = dst + ((size_t)h * (size_t)tokens + (size_t)t) *
43  q8_row_bytes;
44  quantize_row_q8_0(row, out, aligned_head_dim);
45  }
46  }
47 }
48 
49 static void out_proj_head_major_q8_0_q8_0(const uint8_t *attn_q8,
50  const void *wo,
51  const float *bias,
52  float *output,
53  int tokens,
54  int aligned_embed_dim,
55  int num_heads,
56  int aligned_head_dim)
57 {
58  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
59  (size_t)aligned_head_dim);
60  const int blocks_per_head = aligned_head_dim / QK8_0;
61  const int blocks_per_row = aligned_embed_dim / QK8_0;
62  const block_q8_0 *weights = (const block_q8_0 *)wo;
63 
64  for (int t = 0; t < tokens; ++t) {
65  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
66  for (int n = 0; n < aligned_embed_dim; ++n) {
67  float sum = bias ? bias[n] : 0.0f;
68  const block_q8_0 *w_row = weights + (size_t)n * (size_t)blocks_per_row;
69 
70  for (int h = 0; h < num_heads; ++h) {
71  const uint8_t *a_row = attn_q8 +
72  ((size_t)h * (size_t)tokens + (size_t)t) *
73  q8_row_bytes;
74  const block_q8_0 *w_head = w_row + (size_t)h * (size_t)blocks_per_head;
75  float partial = 0.0f;
76  vec_dot_q8_0_q8_0(aligned_head_dim, &partial, w_head, a_row);
77  sum += partial;
78  }
79  out_row[n] = sum;
80  }
81  }
82 }
83 
/**
 * Compute the scratch-buffer size, in bytes, required by
 * mega_fused_attention_prefill_q8_0().
 *
 * NOTE(review): the extracted listing dropped the declarator head (original
 * line 84); it is reconstructed here from this file's declaration index —
 * verify against the repository copy.
 *
 * Layout (each region rounded up to a 64-byte boundary):
 *   Q buffer   : num_heads * tokens * aligned_head_dim floats
 *   attn out   : same size as the Q buffer
 *   projection : tokens * aligned_embed_dim floats (reserved)
 *   QKV scratch: as reported by the fused RMSNorm+QKV kernel
 *
 * @return total bytes needed, or 0 when any dimension is non-positive.
 */
size_t mega_fused_attention_prefill_q8_0_scratch_size(int tokens,
                                                      int aligned_embed_dim,
                                                      int num_heads,
                                                      int aligned_head_dim)
{
    if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 ||
        aligned_head_dim <= 0) {
        return 0;
    }

    const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
                           (size_t)aligned_head_dim * sizeof(float);
    const size_t attn_bytes = q_bytes;
    const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim *
                              sizeof(float);
    const size_t qkv_scratch =
        fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);

    return align_up_size(q_bytes, 64) +
           align_up_size(attn_bytes, 64) +
           align_up_size(proj_bytes, 64) +
           align_up_size(qkv_scratch, 64);
}
104 
106  float *output,
107  const float *input,
108  const float *residual,
109  const float *ln1_gamma,
110  const void *wq, const float *bq, CKDataType wq_dt,
111  const void *wk, const float *bk, CKDataType wk_dt,
112  const void *wv, const float *bv, CKDataType wv_dt,
113  const void *wo, const float *bo, CKDataType wo_dt,
114  float *kv_cache_k,
115  float *kv_cache_v,
116  const float *rope_cos,
117  const float *rope_sin,
118  int start_pos,
119  int tokens,
120  int cache_capacity,
121  int embed_dim,
122  int aligned_embed_dim,
123  int num_heads,
124  int num_kv_heads,
125  int head_dim,
126  int aligned_head_dim,
127  float eps,
128  void *scratch)
129 {
130  if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
131  !kv_cache_k || !kv_cache_v || !scratch) {
132  return;
133  }
134  if (tokens <= 0 || cache_capacity <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
135  head_dim <= 0 || aligned_head_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0) {
136  return;
137  }
138  if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
139  return;
140  }
141  if (start_pos < 0 || start_pos + tokens > cache_capacity) {
142  return;
143  }
144  if (wo_dt != CK_DT_Q8_0) {
145  return;
146  }
147  if ((aligned_head_dim % QK8_0) != 0 || (aligned_embed_dim % QK8_0) != 0) {
148  return;
149  }
150 
151  const size_t q_bytes = (size_t)num_heads * (size_t)tokens *
152  (size_t)aligned_head_dim * sizeof(float);
153  const size_t attn_bytes = q_bytes;
154  const size_t proj_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
155  const size_t qkv_scratch_bytes = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);
156 
157  uint8_t *scratch_bytes = (uint8_t *)scratch;
158  float *q = (float *)scratch_bytes;
159  scratch_bytes += align_up_size(q_bytes, 64);
160  float *attn_out = (float *)scratch_bytes;
161  scratch_bytes += align_up_size(attn_bytes, 64);
162  float *proj_scratch = (float *)scratch_bytes;
163  scratch_bytes += align_up_size(proj_bytes, 64);
164  void *qkv_scratch = (void *)scratch_bytes;
165  (void)qkv_scratch_bytes;
166  (void)proj_scratch;
167 
168  float *k_ptr = kv_cache_k + (size_t)start_pos * (size_t)aligned_head_dim;
169  float *v_ptr = kv_cache_v + (size_t)start_pos * (size_t)aligned_head_dim;
170 
171  if (wq_dt == CK_DT_FP32 && wk_dt == CK_DT_FP32 && wv_dt == CK_DT_FP32) {
173  ln1_gamma,
174  (const float *)wq, bq,
175  (const float *)wk, bk,
176  (const float *)wv, bv,
177  q,
178  k_ptr,
179  v_ptr,
180  tokens,
181  embed_dim,
182  aligned_embed_dim,
183  num_heads,
184  num_kv_heads,
185  head_dim,
186  aligned_head_dim,
187  cache_capacity,
188  eps,
189  qkv_scratch);
190  } else {
192  ln1_gamma,
193  wq, bq, wq_dt,
194  wk, bk, wk_dt,
195  wv, bv, wv_dt,
196  q,
197  k_ptr,
198  v_ptr,
199  tokens,
200  embed_dim,
201  aligned_embed_dim,
202  num_heads,
203  num_kv_heads,
204  head_dim,
205  aligned_head_dim,
206  cache_capacity,
207  eps,
208  qkv_scratch);
209  }
210 
211  if (rope_cos && rope_sin) {
213  k_ptr,
214  rope_cos,
215  rope_sin,
216  num_heads,
217  num_kv_heads,
218  tokens,
219  head_dim,
220  aligned_head_dim,
221  start_pos,
222  tokens,
223  cache_capacity);
224  }
225 
226  if (start_pos == 0) {
228  k_ptr,
229  v_ptr,
230  attn_out,
231  num_heads,
232  num_kv_heads,
233  tokens,
234  head_dim,
235  aligned_head_dim,
236  cache_capacity);
237  } else {
238  const float scale = 1.0f / sqrtf((float)head_dim);
239  const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
240  const size_t kv_head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
241 
242  for (int h = 0; h < num_heads; ++h) {
243  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
244  const float *k_head = kv_cache_k + (size_t)kv_head * kv_head_stride;
245  const float *v_head = kv_cache_v + (size_t)kv_head * kv_head_stride;
246 
247  for (int i = 0; i < tokens; ++i) {
248  const float *q_vec = q + (size_t)h * q_head_stride + (size_t)i * (size_t)aligned_head_dim;
249  float *out_vec = attn_out + (size_t)h * q_head_stride + (size_t)i * (size_t)aligned_head_dim;
250  attention_flash_decode(out_vec,
251  q_vec,
252  k_head,
253  v_head,
254  1,
255  start_pos + i + 1,
256  1,
257  aligned_head_dim,
258  scale);
259  }
260  }
261  }
262 
263  if ((num_heads * aligned_head_dim) != aligned_embed_dim) {
264  return;
265  }
266 
267  /* Quantized activations path: Q8_0 attn_out + Q8_0 weights. */
268  {
269  uint8_t *attn_q8 = (uint8_t *)q;
271  attn_q8,
272  tokens,
273  num_heads,
274  aligned_head_dim);
276  wo,
277  bo,
278  output,
279  tokens,
280  aligned_embed_dim,
281  num_heads,
282  aligned_head_dim);
283  }
284 
285  if (residual) {
287  output,
288  output,
289  tokens,
290  aligned_embed_dim);
291  }
292 }
CKDataType
Supported data types in C-Kernel-Engine.
Definition: ckernel_dtype.h:27
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
Definition: rope_kernels.c:472
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
Quantization block structures for weight-only quantization.
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.
#define QK8_0
size_t mega_fused_attention_prefill_q8_0_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Get scratch buffer size for mega_fused_attention_prefill_q8_0.
static void out_proj_head_major_q8_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static size_t align_up_size(size_t value, size_t align)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
void mega_fused_attention_prefill_q8_0(float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
Mega-fused prefill attention kernel (Q8_0 out-proj)