27 #include <immintrin.h>
32 return (value + align - 1) & ~(align - 1);
44 (
size_t)aligned_head_dim);
45 const size_t head_stride = (size_t)tokens * (
size_t)aligned_head_dim;
46 for (
int h = 0; h < num_heads; ++h) {
47 const float *head = attn_out + (size_t)h * head_stride;
48 for (
int t = 0; t < tokens; ++t) {
49 const float *row = head + (size_t)t * (
size_t)aligned_head_dim;
50 uint8_t *out = dst + ((size_t)h * (
size_t)tokens + (size_t)t) *
62 int aligned_embed_dim,
66 #define OUTPROJ_TILE_N 8
68 (
size_t)aligned_head_dim);
69 const int blocks_per_head = aligned_head_dim /
QK5_0;
70 const int blocks_per_row = aligned_embed_dim /
QK5_0;
73 for (
int t = 0; t < tokens; ++t) {
74 float *out_row = output + (size_t)t * (
size_t)aligned_embed_dim;
78 : (aligned_embed_dim - n);
80 for (
int i = 0; i < tile; ++i) {
81 sum[i] = bias ? bias[n + i] : 0.0f;
84 for (
int h = 0; h < num_heads; ++h) {
85 const uint8_t *a_row = attn_q8 +
86 ((size_t)h * (
size_t)tokens + (size_t)t) *
89 (size_t)n * (
size_t)blocks_per_row +
90 (size_t)h * (
size_t)blocks_per_head;
91 for (
int i = 0; i < tile; ++i) {
93 (size_t)i * (
size_t)blocks_per_row;
100 for (
int i = 0; i < tile; ++i) {
101 out_row[n + i] = sum[i];
105 #undef OUTPROJ_TILE_N
113 int aligned_embed_dim,
115 int aligned_head_dim)
117 #define OUTPROJ_TILE_N 8
119 (
size_t)aligned_head_dim);
120 const int blocks_per_head = aligned_head_dim /
QK8_0;
121 const int blocks_per_row = aligned_embed_dim /
QK8_0;
124 for (
int t = 0; t < tokens; ++t) {
125 float *out_row = output + (size_t)t * (
size_t)aligned_embed_dim;
129 : (aligned_embed_dim - n);
131 for (
int i = 0; i < tile; ++i) {
132 sum[i] = bias ? bias[n + i] : 0.0f;
135 for (
int h = 0; h < num_heads; ++h) {
136 const uint8_t *a_row = attn_q8 +
137 ((size_t)h * (
size_t)tokens + (size_t)t) *
140 (size_t)n * (
size_t)blocks_per_row +
141 (size_t)h * (
size_t)blocks_per_head;
142 for (
int i = 0; i < tile; ++i) {
144 (size_t)i * (
size_t)blocks_per_row;
145 float partial = 0.0f;
151 for (
int i = 0; i < tile; ++i) {
152 out_row[n + i] = sum[i];
156 #undef OUTPROJ_TILE_N
160 int aligned_embed_dim,
162 int aligned_head_dim,
163 int aligned_intermediate_dim)
165 if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 ||
166 aligned_head_dim <= 0 || aligned_intermediate_dim <= 0) {
171 (
size_t)aligned_head_dim);
172 const size_t attn_q8_bytes = (size_t)num_heads * (
size_t)tokens * q8_row_bytes;
173 const size_t h1_bytes = (size_t)tokens * (
size_t)aligned_embed_dim *
sizeof(float);
174 const size_t ln2_bytes = h1_bytes;
176 aligned_embed_dim, aligned_intermediate_dim);
186 const float *attn_out,
187 const float *residual,
188 const float *ln2_gamma,
189 const void *wo,
const float *bo,
CKDataType wo_dt,
190 const void *w1,
const float *b1,
CKDataType w1_dt,
191 const void *w2,
const float *b2,
CKDataType w2_dt,
194 int aligned_embed_dim,
196 int aligned_head_dim,
197 int intermediate_dim,
198 int aligned_intermediate_dim,
202 if (!output || !attn_out || !residual || !ln2_gamma ||
203 !wo || !w1 || !w2 || !scratch) {
206 if (tokens <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
207 num_heads <= 0 || aligned_head_dim <= 0 ||
208 intermediate_dim <= 0 || aligned_intermediate_dim <= 0) {
211 if (aligned_embed_dim < embed_dim || aligned_head_dim <= 0 ||
212 aligned_intermediate_dim < intermediate_dim) {
215 if (aligned_embed_dim != num_heads * aligned_head_dim) {
218 if ((aligned_embed_dim % 32) != 0 || (aligned_head_dim % 32) != 0) {
221 if ((aligned_intermediate_dim %
QK_K) != 0) {
235 (
size_t)aligned_head_dim);
236 const size_t attn_q8_bytes = (size_t)num_heads * (
size_t)tokens * q8_row_bytes;
237 const size_t h1_bytes = (size_t)tokens * (
size_t)aligned_embed_dim *
sizeof(float);
238 const size_t ln2_bytes = h1_bytes;
240 uint8_t *scratch_bytes = (uint8_t *)scratch;
241 uint8_t *attn_q8 = scratch_bytes;
243 float *h1 = (
float *)scratch_bytes;
245 float *ln2_out = (
float *)scratch_bytes;
247 void *mlp_scratch = (
void *)scratch_bytes;
275 for (
int t = 0; t < tokens; ++t) {
276 const float *res_row = residual + (size_t)t * (
size_t)aligned_embed_dim;
277 float *h1_row = h1 + (size_t)t * (
size_t)aligned_embed_dim;
302 aligned_intermediate_dim,
305 for (
int t = 0; t < tokens; ++t) {
306 const float *h1_row = h1 + (size_t)t * (
size_t)aligned_embed_dim;
307 float *out_row = output + (size_t)t * (
size_t)aligned_embed_dim;
CKDataType
Supported data types in C-Kernel-Engine.
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim, int aligned_intermediate_dim)
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.
void add_inplace_f32(float *a, const float *b, size_t n)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference).
void fused_mlp_swiglu_prefill_w1w2_quant(const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch)
Quantized fused MLP for prefill (W1 = gate+up, W2 = down).
Quantization block structures for weight-only quantization.
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.
static void out_proj_head_major_q8_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static size_t align_up_size(size_t value, size_t align)
static void out_proj_head_major_q5_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
void mega_fused_outproj_mlp_prefill(float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, CKDataType wo_dt, const void *w1, const float *b1, CKDataType w1_dt, const void *w2, const float *b2, CKDataType w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch)
Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill.
size_t mega_fused_outproj_mlp_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
Get scratch buffer size for mega_fused_outproj_mlp_prefill.