← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_outproj_mlp_prefill.c
Go to the documentation of this file.
1 /**
2  * @file mega_fused_outproj_mlp_prefill.c
3  * @brief Mega-fused post-attention block for prefill
4  *
5  * OutProj → Residual → RMSNorm2 → MLP → Residual
6  *
7  * Plan summary:
8  * 1) Quantize head-major attn_out to Q8_0
9  * 2) Out-proj with Q5_0/Q8_0 weights → h1 (post-attn) in scratch
10  * 3) Add residual (input) into h1
11  * 4) RMSNorm2(h1) → ln2_out (scratch)
12  * 5) Fused MLP (quant W1/W2) → output
13  * 6) Add h1 residual into output
14  *
15  * Goal: avoid DRAM writes between attention out-proj and MLP output.
16  * All intermediates live in scratch buffers from the bump allocator.
17  */
18 
19 #include "ckernel_engine.h"
20 #include "ckernel_quant.h"
21 
22 #include <math.h>
23 #include <stddef.h>
24 #include <stdint.h>
25 
26 #if defined(__AVX__)
27 #include <immintrin.h>
28 #endif
29 
/* Round `value` up to the next multiple of `align`.
 * `align` must be a power of two (all callers use 64). */
static size_t align_up_size(size_t value, size_t align)
{
    const size_t mask = align - 1;
    return (value + mask) & ~mask;
}
34 
35 /* Note: add_inplace_f32 is declared in ckernel_engine.h and defined elsewhere */
36 
37 static void quantize_attn_out_head_major_q8_0(const float *attn_out,
38  uint8_t *dst,
39  int tokens,
40  int num_heads,
41  int aligned_head_dim)
42 {
43  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
44  (size_t)aligned_head_dim);
45  const size_t head_stride = (size_t)tokens * (size_t)aligned_head_dim;
46  for (int h = 0; h < num_heads; ++h) {
47  const float *head = attn_out + (size_t)h * head_stride;
48  for (int t = 0; t < tokens; ++t) {
49  const float *row = head + (size_t)t * (size_t)aligned_head_dim;
50  uint8_t *out = dst + ((size_t)h * (size_t)tokens + (size_t)t) *
51  q8_row_bytes;
52  quantize_row_q8_0(row, out, aligned_head_dim);
53  }
54  }
55 }
56 
57 static void out_proj_head_major_q5_0_q8_0(const uint8_t *attn_q8,
58  const void *wo,
59  const float *bias,
60  float *output,
61  int tokens,
62  int aligned_embed_dim,
63  int num_heads,
64  int aligned_head_dim)
65 {
66 #define OUTPROJ_TILE_N 8
67  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
68  (size_t)aligned_head_dim);
69  const int blocks_per_head = aligned_head_dim / QK5_0;
70  const int blocks_per_row = aligned_embed_dim / QK5_0;
71  const block_q5_0 *weights = (const block_q5_0 *)wo;
72 
73  for (int t = 0; t < tokens; ++t) {
74  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
75  for (int n = 0; n < aligned_embed_dim; n += OUTPROJ_TILE_N) {
76  const int tile = (n + OUTPROJ_TILE_N <= aligned_embed_dim)
78  : (aligned_embed_dim - n);
79  float sum[OUTPROJ_TILE_N];
80  for (int i = 0; i < tile; ++i) {
81  sum[i] = bias ? bias[n + i] : 0.0f;
82  }
83 
84  for (int h = 0; h < num_heads; ++h) {
85  const uint8_t *a_row = attn_q8 +
86  ((size_t)h * (size_t)tokens + (size_t)t) *
87  q8_row_bytes;
88  const block_q5_0 *w_row_base = weights +
89  (size_t)n * (size_t)blocks_per_row +
90  (size_t)h * (size_t)blocks_per_head;
91  for (int i = 0; i < tile; ++i) {
92  const block_q5_0 *w_head = w_row_base +
93  (size_t)i * (size_t)blocks_per_row;
94  float partial = 0.0f;
95  vec_dot_q5_0_q8_0(aligned_head_dim, &partial, w_head, a_row);
96  sum[i] += partial;
97  }
98  }
99 
100  for (int i = 0; i < tile; ++i) {
101  out_row[n + i] = sum[i];
102  }
103  }
104  }
105 #undef OUTPROJ_TILE_N
106 }
107 
108 static void out_proj_head_major_q8_0_q8_0(const uint8_t *attn_q8,
109  const void *wo,
110  const float *bias,
111  float *output,
112  int tokens,
113  int aligned_embed_dim,
114  int num_heads,
115  int aligned_head_dim)
116 {
117 #define OUTPROJ_TILE_N 8
118  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
119  (size_t)aligned_head_dim);
120  const int blocks_per_head = aligned_head_dim / QK8_0;
121  const int blocks_per_row = aligned_embed_dim / QK8_0;
122  const block_q8_0 *weights = (const block_q8_0 *)wo;
123 
124  for (int t = 0; t < tokens; ++t) {
125  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
126  for (int n = 0; n < aligned_embed_dim; n += OUTPROJ_TILE_N) {
127  const int tile = (n + OUTPROJ_TILE_N <= aligned_embed_dim)
129  : (aligned_embed_dim - n);
130  float sum[OUTPROJ_TILE_N];
131  for (int i = 0; i < tile; ++i) {
132  sum[i] = bias ? bias[n + i] : 0.0f;
133  }
134 
135  for (int h = 0; h < num_heads; ++h) {
136  const uint8_t *a_row = attn_q8 +
137  ((size_t)h * (size_t)tokens + (size_t)t) *
138  q8_row_bytes;
139  const block_q8_0 *w_row_base = weights +
140  (size_t)n * (size_t)blocks_per_row +
141  (size_t)h * (size_t)blocks_per_head;
142  for (int i = 0; i < tile; ++i) {
143  const block_q8_0 *w_head = w_row_base +
144  (size_t)i * (size_t)blocks_per_row;
145  float partial = 0.0f;
146  vec_dot_q8_0_q8_0(aligned_head_dim, &partial, w_head, a_row);
147  sum[i] += partial;
148  }
149  }
150 
151  for (int i = 0; i < tile; ++i) {
152  out_row[n + i] = sum[i];
153  }
154  }
155  }
156 #undef OUTPROJ_TILE_N
157 }
158 
160  int aligned_embed_dim,
161  int num_heads,
162  int aligned_head_dim,
163  int aligned_intermediate_dim)
164 {
165  if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 ||
166  aligned_head_dim <= 0 || aligned_intermediate_dim <= 0) {
167  return 0;
168  }
169 
170  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
171  (size_t)aligned_head_dim);
172  const size_t attn_q8_bytes = (size_t)num_heads * (size_t)tokens * q8_row_bytes;
173  const size_t h1_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
174  const size_t ln2_bytes = h1_bytes;
175  const size_t mlp_scratch = fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(
176  aligned_embed_dim, aligned_intermediate_dim);
177 
178  return align_up_size(attn_q8_bytes, 64) +
179  align_up_size(h1_bytes, 64) +
180  align_up_size(ln2_bytes, 64) +
181  align_up_size(mlp_scratch, 64);
182 }
183 
185  float *output,
186  const float *attn_out,
187  const float *residual,
188  const float *ln2_gamma,
189  const void *wo, const float *bo, CKDataType wo_dt,
190  const void *w1, const float *b1, CKDataType w1_dt,
191  const void *w2, const float *b2, CKDataType w2_dt,
192  int tokens,
193  int embed_dim,
194  int aligned_embed_dim,
195  int num_heads,
196  int aligned_head_dim,
197  int intermediate_dim,
198  int aligned_intermediate_dim,
199  float eps,
200  void *scratch)
201 {
202  if (!output || !attn_out || !residual || !ln2_gamma ||
203  !wo || !w1 || !w2 || !scratch) {
204  return;
205  }
206  if (tokens <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
207  num_heads <= 0 || aligned_head_dim <= 0 ||
208  intermediate_dim <= 0 || aligned_intermediate_dim <= 0) {
209  return;
210  }
211  if (aligned_embed_dim < embed_dim || aligned_head_dim <= 0 ||
212  aligned_intermediate_dim < intermediate_dim) {
213  return;
214  }
215  if (aligned_embed_dim != num_heads * aligned_head_dim) {
216  return;
217  }
218  if ((aligned_embed_dim % 32) != 0 || (aligned_head_dim % 32) != 0) {
219  return;
220  }
221  if ((aligned_intermediate_dim % QK_K) != 0) {
222  return;
223  }
224  if (wo_dt != CK_DT_Q5_0 && wo_dt != CK_DT_Q8_0) {
225  return;
226  }
227  if (w1_dt != CK_DT_Q5_0 && w1_dt != CK_DT_Q8_0) {
228  return;
229  }
230  if (w2_dt != CK_DT_Q4_K && w2_dt != CK_DT_Q6_K) {
231  return;
232  }
233 
234  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
235  (size_t)aligned_head_dim);
236  const size_t attn_q8_bytes = (size_t)num_heads * (size_t)tokens * q8_row_bytes;
237  const size_t h1_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
238  const size_t ln2_bytes = h1_bytes;
239 
240  uint8_t *scratch_bytes = (uint8_t *)scratch;
241  uint8_t *attn_q8 = scratch_bytes;
242  scratch_bytes += align_up_size(attn_q8_bytes, 64);
243  float *h1 = (float *)scratch_bytes;
244  scratch_bytes += align_up_size(h1_bytes, 64);
245  float *ln2_out = (float *)scratch_bytes;
246  scratch_bytes += align_up_size(ln2_bytes, 64);
247  void *mlp_scratch = (void *)scratch_bytes;
248 
250  attn_q8,
251  tokens,
252  num_heads,
253  aligned_head_dim);
254 
255  if (wo_dt == CK_DT_Q8_0) {
257  wo,
258  bo,
259  h1,
260  tokens,
261  aligned_embed_dim,
262  num_heads,
263  aligned_head_dim);
264  } else {
266  wo,
267  bo,
268  h1,
269  tokens,
270  aligned_embed_dim,
271  num_heads,
272  aligned_head_dim);
273  }
274 
275  for (int t = 0; t < tokens; ++t) {
276  const float *res_row = residual + (size_t)t * (size_t)aligned_embed_dim;
277  float *h1_row = h1 + (size_t)t * (size_t)aligned_embed_dim;
278  add_inplace_f32(h1_row, res_row, aligned_embed_dim);
279  }
280 
281  rmsnorm_forward(h1,
282  ln2_gamma,
283  ln2_out,
284  NULL,
285  tokens,
286  embed_dim,
287  aligned_embed_dim,
288  eps);
289 
291  w1,
292  b1,
293  w1_dt,
294  w2,
295  b2,
296  w2_dt,
297  output,
298  tokens,
299  embed_dim,
300  aligned_embed_dim,
301  intermediate_dim,
302  aligned_intermediate_dim,
303  mlp_scratch);
304 
305  for (int t = 0; t < tokens; ++t) {
306  const float *h1_row = h1 + (size_t)t * (size_t)aligned_embed_dim;
307  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
308  add_inplace_f32(out_row, h1_row, aligned_embed_dim);
309  }
310 }
CKDataType
Supported data types in C-Kernel-Engine.
Definition: ckernel_dtype.h:27
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_Q5_0
Definition: ckernel_dtype.h:44
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim, int aligned_intermediate_dim)
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.
void add_inplace_f32(float *a, const float *b, size_t n)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)
void fused_mlp_swiglu_prefill_w1w2_quant(const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch)
Quantized fused MLP for prefill (W1=gate+up, W2=down)
Quantization block structures for weight-only quantization.
#define QK5_0
Definition: ckernel_quant.h:67
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.
#define QK8_0
#define QK_K
static void out_proj_head_major_q8_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static size_t align_up_size(size_t value, size_t align)
static void out_proj_head_major_q5_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
void mega_fused_outproj_mlp_prefill(float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, CKDataType wo_dt, const void *w1, const float *b1, CKDataType w1_dt, const void *w2, const float *b2, CKDataType w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch)
Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill.
#define OUTPROJ_TILE_N
size_t mega_fused_outproj_mlp_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
Get scratch buffer size for mega_fused_outproj_mlp_prefill.