← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_outproj_mlp_prefill.c File Reference

Mega-fused post-attention block for prefill. More...

#include "ckernel_engine.h"
#include "ckernel_quant.h"
#include <math.h>
#include <stddef.h>
#include <stdint.h>

Go to the source code of this file.

Macros

#define OUTPROJ_TILE_N   8
 
#define OUTPROJ_TILE_N   8
 

Functions

static size_t align_up_size (size_t value, size_t align)
 
void mega_fused_outproj_mlp_prefill (float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, CKDataType wo_dt, const void *w1, const float *b1, CKDataType w1_dt, const void *w2, const float *b2, CKDataType w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch)
 Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill. More...
 
size_t mega_fused_outproj_mlp_prefill_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
 Get scratch buffer size for mega_fused_outproj_mlp_prefill. More...
 
static void out_proj_head_major_q5_0_q8_0 (const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
static void out_proj_head_major_q8_0_q8_0 (const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
static void quantize_attn_out_head_major_q8_0 (const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
 

Detailed Description

Mega-fused post-attention block for prefill.

OutProj → Residual → RMSNorm2 → MLP → Residual

Plan summary:
1) Quantize head-major attn_out to Q8_0
2) Out-proj with Q5_0/Q8_0 weights → h1 (post-attn) in scratch
3) Add residual (input) into h1
4) RMSNorm2(h1) → ln2_out (scratch)
5) Fused MLP (quant W1/W2) → output
6) Add h1 residual into output

Goal: avoid DRAM writes between attention out-proj and MLP output. All intermediates live in scratch buffers from the bump allocator.

Definition in file mega_fused_outproj_mlp_prefill.c.

Macro Definition Documentation

◆ OUTPROJ_TILE_N [1/2]

#define OUTPROJ_TILE_N   8

◆ OUTPROJ_TILE_N [2/2]

#define OUTPROJ_TILE_N   8

Function Documentation

◆ align_up_size()

static size_t align_up_size ( size_t  value,
size_t  align 
)
static

Definition at line 30 of file mega_fused_outproj_mlp_prefill.c.

31 {
32  return (value + align - 1) & ~(align - 1);
33 }

Referenced by mega_fused_outproj_mlp_prefill(), and mega_fused_outproj_mlp_prefill_scratch_size().

◆ mega_fused_outproj_mlp_prefill()

void mega_fused_outproj_mlp_prefill ( float *  output,
const float *  attn_out,
const float *  residual,
const float *  ln2_gamma,
const void *  wo,
const float *  bo,
CKDataType  wo_dt,
const void *  w1,
const float *  b1,
CKDataType  w1_dt,
const void *  w2,
const float *  b2,
CKDataType  w2_dt,
int  tokens,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim,
int  intermediate_dim,
int  aligned_intermediate_dim,
float  eps,
void *  scratch 
)

Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill.

Uses head-major attention output and quantized out-proj (Q5_0/Q8_0 weights).

Definition at line 184 of file mega_fused_outproj_mlp_prefill.c.

201 {
202  if (!output || !attn_out || !residual || !ln2_gamma ||
203  !wo || !w1 || !w2 || !scratch) {
204  return;
205  }
206  if (tokens <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
207  num_heads <= 0 || aligned_head_dim <= 0 ||
208  intermediate_dim <= 0 || aligned_intermediate_dim <= 0) {
209  return;
210  }
211  if (aligned_embed_dim < embed_dim || aligned_head_dim <= 0 ||
212  aligned_intermediate_dim < intermediate_dim) {
213  return;
214  }
215  if (aligned_embed_dim != num_heads * aligned_head_dim) {
216  return;
217  }
218  if ((aligned_embed_dim % 32) != 0 || (aligned_head_dim % 32) != 0) {
219  return;
220  }
221  if ((aligned_intermediate_dim % QK_K) != 0) {
222  return;
223  }
224  if (wo_dt != CK_DT_Q5_0 && wo_dt != CK_DT_Q8_0) {
225  return;
226  }
227  if (w1_dt != CK_DT_Q5_0 && w1_dt != CK_DT_Q8_0) {
228  return;
229  }
230  if (w2_dt != CK_DT_Q4_K && w2_dt != CK_DT_Q6_K) {
231  return;
232  }
233 
234  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
235  (size_t)aligned_head_dim);
236  const size_t attn_q8_bytes = (size_t)num_heads * (size_t)tokens * q8_row_bytes;
237  const size_t h1_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
238  const size_t ln2_bytes = h1_bytes;
239 
240  uint8_t *scratch_bytes = (uint8_t *)scratch;
241  uint8_t *attn_q8 = scratch_bytes;
242  scratch_bytes += align_up_size(attn_q8_bytes, 64);
243  float *h1 = (float *)scratch_bytes;
244  scratch_bytes += align_up_size(h1_bytes, 64);
245  float *ln2_out = (float *)scratch_bytes;
246  scratch_bytes += align_up_size(ln2_bytes, 64);
247  void *mlp_scratch = (void *)scratch_bytes;
248 
249  quantize_attn_out_head_major_q8_0(attn_out,
250  attn_q8,
251  tokens,
252  num_heads,
253  aligned_head_dim);
254 
255  if (wo_dt == CK_DT_Q8_0) {
256  out_proj_head_major_q8_0_q8_0(attn_q8,
257  wo,
258  bo,
259  h1,
260  tokens,
261  aligned_embed_dim,
262  num_heads,
263  aligned_head_dim);
264  } else {
265  out_proj_head_major_q5_0_q8_0(attn_q8,
266  wo,
267  bo,
268  h1,
269  tokens,
270  aligned_embed_dim,
271  num_heads,
272  aligned_head_dim);
273  }
274 
275  for (int t = 0; t < tokens; ++t) {
276  const float *res_row = residual + (size_t)t * (size_t)aligned_embed_dim;
277  float *h1_row = h1 + (size_t)t * (size_t)aligned_embed_dim;
278  add_inplace_f32(h1_row, res_row, aligned_embed_dim);
279  }
280 
281  rmsnorm_forward(h1,
282  ln2_gamma,
283  ln2_out,
284  NULL,
285  tokens,
286  embed_dim,
287  aligned_embed_dim,
288  eps);
289 
290  fused_mlp_swiglu_prefill_w1w2_quant(ln2_out,
291  w1,
292  b1,
293  w1_dt,
294  w2,
295  b2,
296  w2_dt,
297  output,
298  tokens,
299  embed_dim,
300  aligned_embed_dim,
301  intermediate_dim,
302  aligned_intermediate_dim,
303  mlp_scratch);
304 
305  for (int t = 0; t < tokens; ++t) {
306  const float *h1_row = h1 + (size_t)t * (size_t)aligned_embed_dim;
307  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
308  add_inplace_f32(out_row, h1_row, aligned_embed_dim);
309  }
310 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_Q5_0
Definition: ckernel_dtype.h:44
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void add_inplace_f32(float *a, const float *b, size_t n)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void fused_mlp_swiglu_prefill_w1w2_quant(const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch)
Quantized fused MLP for prefill (W1=gate+up, W2=down)
#define QK_K
static void out_proj_head_major_q8_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static size_t align_up_size(size_t value, size_t align)
static void out_proj_head_major_q5_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)

References add_inplace_f32(), align_up_size(), CK_DT_Q4_K, CK_DT_Q5_0, CK_DT_Q6_K, CK_DT_Q8_0, ck_dtype_row_bytes(), fused_mlp_swiglu_prefill_w1w2_quant(), out_proj_head_major_q5_0_q8_0(), out_proj_head_major_q8_0_q8_0(), QK_K, quantize_attn_out_head_major_q8_0(), and rmsnorm_forward().

◆ mega_fused_outproj_mlp_prefill_scratch_size()

size_t mega_fused_outproj_mlp_prefill_scratch_size ( int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim,
int  aligned_intermediate_dim 
)

Get scratch buffer size for mega_fused_outproj_mlp_prefill.

Definition at line 159 of file mega_fused_outproj_mlp_prefill.c.

164 {
165  if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 ||
166  aligned_head_dim <= 0 || aligned_intermediate_dim <= 0) {
167  return 0;
168  }
169 
170  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
171  (size_t)aligned_head_dim);
172  const size_t attn_q8_bytes = (size_t)num_heads * (size_t)tokens * q8_row_bytes;
173  const size_t h1_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
174  const size_t ln2_bytes = h1_bytes;
175  const size_t mlp_scratch = fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(
176  aligned_embed_dim, aligned_intermediate_dim);
177 
178  return align_up_size(attn_q8_bytes, 64) +
179  align_up_size(h1_bytes, 64) +
180  align_up_size(ln2_bytes, 64) +
181  align_up_size(mlp_scratch, 64);
182 }
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim, int aligned_intermediate_dim)
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.

References align_up_size(), CK_DT_Q8_0, ck_dtype_row_bytes(), and fused_mlp_swiglu_prefill_w1w2_quant_scratch_size().

Referenced by ck_test_outproj_mlp_fused_q5_0().

◆ out_proj_head_major_q5_0_q8_0()

static void out_proj_head_major_q5_0_q8_0 ( const uint8_t *  attn_q8,
const void *  wo,
const float *  bias,
float *  output,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 57 of file mega_fused_outproj_mlp_prefill.c.

65 {
66 #define OUTPROJ_TILE_N 8
67  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
68  (size_t)aligned_head_dim);
69  const int blocks_per_head = aligned_head_dim / QK5_0;
70  const int blocks_per_row = aligned_embed_dim / QK5_0;
71  const block_q5_0 *weights = (const block_q5_0 *)wo;
72 
73  for (int t = 0; t < tokens; ++t) {
74  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
75  for (int n = 0; n < aligned_embed_dim; n += OUTPROJ_TILE_N) {
76  const int tile = (n + OUTPROJ_TILE_N <= aligned_embed_dim)
77  ? OUTPROJ_TILE_N
78  : (aligned_embed_dim - n);
79  float sum[OUTPROJ_TILE_N];
80  for (int i = 0; i < tile; ++i) {
81  sum[i] = bias ? bias[n + i] : 0.0f;
82  }
83 
84  for (int h = 0; h < num_heads; ++h) {
85  const uint8_t *a_row = attn_q8 +
86  ((size_t)h * (size_t)tokens + (size_t)t) *
87  q8_row_bytes;
88  const block_q5_0 *w_row_base = weights +
89  (size_t)n * (size_t)blocks_per_row +
90  (size_t)h * (size_t)blocks_per_head;
91  for (int i = 0; i < tile; ++i) {
92  const block_q5_0 *w_head = w_row_base +
93  (size_t)i * (size_t)blocks_per_row;
94  float partial = 0.0f;
95  vec_dot_q5_0_q8_0(aligned_head_dim, &partial, w_head, a_row);
96  sum[i] += partial;
97  }
98  }
99 
100  for (int i = 0; i < tile; ++i) {
101  out_row[n + i] = sum[i];
102  }
103  }
104  }
105 #undef OUTPROJ_TILE_N
106 }
#define QK5_0
Definition: ckernel_quant.h:67
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
#define OUTPROJ_TILE_N

References CK_DT_Q8_0, ck_dtype_row_bytes(), OUTPROJ_TILE_N, QK5_0, and vec_dot_q5_0_q8_0().

Referenced by mega_fused_outproj_mlp_prefill().

◆ out_proj_head_major_q8_0_q8_0()

static void out_proj_head_major_q8_0_q8_0 ( const uint8_t *  attn_q8,
const void *  wo,
const float *  bias,
float *  output,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 108 of file mega_fused_outproj_mlp_prefill.c.

116 {
117 #define OUTPROJ_TILE_N 8
118  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
119  (size_t)aligned_head_dim);
120  const int blocks_per_head = aligned_head_dim / QK8_0;
121  const int blocks_per_row = aligned_embed_dim / QK8_0;
122  const block_q8_0 *weights = (const block_q8_0 *)wo;
123 
124  for (int t = 0; t < tokens; ++t) {
125  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
126  for (int n = 0; n < aligned_embed_dim; n += OUTPROJ_TILE_N) {
127  const int tile = (n + OUTPROJ_TILE_N <= aligned_embed_dim)
128  ? OUTPROJ_TILE_N
129  : (aligned_embed_dim - n);
130  float sum[OUTPROJ_TILE_N];
131  for (int i = 0; i < tile; ++i) {
132  sum[i] = bias ? bias[n + i] : 0.0f;
133  }
134 
135  for (int h = 0; h < num_heads; ++h) {
136  const uint8_t *a_row = attn_q8 +
137  ((size_t)h * (size_t)tokens + (size_t)t) *
138  q8_row_bytes;
139  const block_q8_0 *w_row_base = weights +
140  (size_t)n * (size_t)blocks_per_row +
141  (size_t)h * (size_t)blocks_per_head;
142  for (int i = 0; i < tile; ++i) {
143  const block_q8_0 *w_head = w_row_base +
144  (size_t)i * (size_t)blocks_per_row;
145  float partial = 0.0f;
146  vec_dot_q8_0_q8_0(aligned_head_dim, &partial, w_head, a_row);
147  sum[i] += partial;
148  }
149  }
150 
151  for (int i = 0; i < tile; ++i) {
152  out_row[n + i] = sum[i];
153  }
154  }
155  }
156 #undef OUTPROJ_TILE_N
157 }
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.
#define QK8_0

References CK_DT_Q8_0, ck_dtype_row_bytes(), OUTPROJ_TILE_N, QK8_0, and vec_dot_q8_0_q8_0().

Referenced by mega_fused_outproj_mlp_prefill().

◆ quantize_attn_out_head_major_q8_0()

static void quantize_attn_out_head_major_q8_0 ( const float *  attn_out,
uint8_t *  dst,
int  tokens,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 37 of file mega_fused_outproj_mlp_prefill.c.

42 {
43  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
44  (size_t)aligned_head_dim);
45  const size_t head_stride = (size_t)tokens * (size_t)aligned_head_dim;
46  for (int h = 0; h < num_heads; ++h) {
47  const float *head = attn_out + (size_t)h * head_stride;
48  for (int t = 0; t < tokens; ++t) {
49  const float *row = head + (size_t)t * (size_t)aligned_head_dim;
50  uint8_t *out = dst + ((size_t)h * (size_t)tokens + (size_t)t) *
51  q8_row_bytes;
52  quantize_row_q8_0(row, out, aligned_head_dim);
53  }
54  }
55 }
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)

References CK_DT_Q8_0, ck_dtype_row_bytes(), and quantize_row_q8_0().

Referenced by mega_fused_outproj_mlp_prefill().