/*
 * Forward declaration: RMSNorm kernel defined in another translation unit.
 * Positional parameters: input activations, gamma (learned scale weights),
 * output buffer, rstd (per-row scratch; presumably the reciprocal standard
 * deviation — confirm against the definition, which is not visible here),
 * T (row/token count), D (feature dimension), AD (aligned dimension),
 * eps (numerical-stability epsilon added inside the norm).
 * NOTE(review): call sites in this file invoke it as
 * rmsnorm_forward(input, ln_gamma, out, rstd, 1, AE, AD, eps), i.e.
 * T=1, D=aligned_embed_dim, AD=head_dim — verify the intended D/AD
 * semantics against the implementation.
 */
67 extern void rmsnorm_forward(
const float *input,
const float *gamma,
float *output,
68 float *rstd,
int T,
int D,
int AD,
float eps);
70 const float *q_token,
const float *k_cache,
const float *v_cache,
71 float *out_token,
int num_heads,
int num_kv_heads,
int kv_tokens,
72 int cache_capacity,
int head_dim,
int aligned_head_dim);
94 const int blocks_per_row = K /
QK5_0;
100 for (
int row = 0; row < M; row++) {
103 out[row] = dot + (bias ? bias[row] : 0.0f);
118 const int blocks_per_row = K /
QK8_0;
124 for (
int row = 0; row < M; row++) {
127 out[row] = dot + (bias ? bias[row] : 0.0f);
138 const float *rope_cos,
139 const float *rope_sin,
145 const int D = AD / 2;
146 const float *cos_row = &rope_cos[pos * D];
147 const float *sin_row = &rope_sin[pos * D];
150 for (
int h = 0; h < H; h++) {
151 float *q_head = &q[h * AD];
152 for (
int d = 0; d < D; d++) {
153 float q0 = q_head[d];
154 float q1 = q_head[d + D];
155 q_head[d] = q0 * cos_row[d] - q1 * sin_row[d];
156 q_head[d + D] = q0 * sin_row[d] + q1 * cos_row[d];
161 for (
int kv = 0; kv < KV; kv++) {
162 float *k_head = &k[kv * AD];
163 for (
int d = 0; d < D; d++) {
164 float k0 = k_head[d];
165 float k1 = k_head[d + D];
166 k_head[d] = k0 * cos_row[d] - k1 * sin_row[d];
167 k_head[d + D] = k0 * sin_row[d] + k1 * cos_row[d];
185 int max_input_dim = (AE > H * AD) ? AE : H * AD;
186 int q8_blocks = (max_input_dim +
QK8_0 - 1) /
QK8_0;
187 return (
int)(
sizeof(float) * (AE + AE + H * AD + 2 * KV * AD + H * AD)
225 const float *residual,
230 const float *ln_gamma,
237 const float *rope_cos,
238 const float *rope_sin,
241 int aligned_embed_dim,
245 int aligned_head_dim,
250 const int H = num_heads;
251 const int KV = num_kv_heads;
252 const int AD = head_dim;
253 const int AE = aligned_embed_dim;
257 float *scratch_ptr = (
float *)scratch;
259 float *rmsnorm_out = scratch_ptr;
262 float *rstd_scratch = scratch_ptr;
265 float *q = scratch_ptr;
266 scratch_ptr += H * AD;
268 float *k = scratch_ptr;
269 scratch_ptr += KV * AD;
271 float *v = scratch_ptr;
272 scratch_ptr += KV * AD;
274 float *attn_out = scratch_ptr;
275 scratch_ptr += H * AD;
279 const int q_size = H * AD;
280 const int k_size = KV * AD;
281 const int v_size = KV * AD;
288 rmsnorm_forward(input, ln_gamma, rmsnorm_out, rstd_scratch, 1, AE, AD, eps);
307 const size_t kv_stride = (size_t)cache_capacity * AD;
308 for (
int kv = 0; kv < KV; kv++) {
309 float *k_cache = &kv_cache_k[kv * kv_stride];
310 float *v_cache = &kv_cache_v[kv * kv_stride];
311 const float *k_src = &k[kv * AD];
312 const float *v_src = &v[kv * AD];
313 const int offset = pos * AD;
314 for (
int d = 0; d < AD; d++) {
315 k_cache[offset + d] = k_src[d];
316 v_cache[offset + d] = v_src[d];
326 q, kv_cache_k, kv_cache_v,
327 attn_out, H, KV, pos + 1, cache_capacity, AD, aligned_head_dim);
344 const int blocks_per_row = (H * AD) /
QK5_0;
346 for (
int e = 0; e < AE; e++) {
349 output[e] = dot + (bo ? bo[e] : 0.0f) + residual[e];
370 const float *residual,
375 const float *ln_gamma,
382 const float *rope_cos,
383 const float *rope_sin,
386 int aligned_embed_dim,
390 int aligned_head_dim,
397 const int H = num_heads;
398 const int KV = num_kv_heads;
399 const int AD = head_dim;
400 const int AE = aligned_embed_dim;
404 const int heads_per_thread = (H + nth - 1) / nth;
405 const int h_start = ith * heads_per_thread;
406 const int h_end = (h_start + heads_per_thread < H) ? h_start + heads_per_thread : H;
407 const int my_heads = h_end - h_start;
409 if (h_start >= H)
return;
412 float *scratch_ptr = (
float *)scratch;
414 float *rmsnorm_out = scratch_ptr;
417 float *rstd_scratch = scratch_ptr;
420 float *q = scratch_ptr;
421 scratch_ptr += H * AD;
423 float *k = scratch_ptr;
424 scratch_ptr += KV * AD;
426 float *v = scratch_ptr;
427 scratch_ptr += KV * AD;
429 float *attn_out = scratch_ptr;
430 scratch_ptr += H * AD;
440 rmsnorm_forward(input, ln_gamma, rmsnorm_out, rstd_scratch, 1, AE, AD, eps);
449 const size_t kv_stride = (size_t)cache_capacity * AD;
450 for (
int kv_idx = 0; kv_idx < KV; kv_idx++) {
451 float *k_cache = &kv_cache_k[kv_idx * kv_stride];
452 float *v_cache = &kv_cache_v[kv_idx * kv_stride];
453 const int offset = pos * AD;
454 for (
int d = 0; d < AD; d++) {
455 k_cache[offset + d] = k[kv_idx * AD + d];
456 v_cache[offset + d] = v[kv_idx * AD + d];
476 kv_cache_k, kv_cache_v,
477 &attn_out[h_start * AD],
480 pos + 1, cache_capacity, AD, aligned_head_dim);
496 const int blocks_per_row = (H * AD) /
QK5_0;
498 for (
int e = 0; e < AE; e++) {
501 output[e] = dot + (bo ? bo[e] : 0.0f) + residual[e];
Block structures (Q5_0 / Q8_0) used for weight-only quantization.
void mega_fused_attention_decode_q5_0_parallel_simd(float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch, int ith, int nth)
Parallel SIMD mega-fused attention decode kernel (threadpool-aware: query heads are partitioned across the `ith` of `nth` worker threads).
void mega_fused_attention_decode_q5_0(float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch)
Serial mega-fused attention decode kernel.
static void gemv_q5_0_from_fp32(float *out, const void *W_q5_0, const float *x_fp32, const float *bias, int M, int K, block_q8_0 *x_q8_scratch)
static void apply_rope_inline(float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int H, int KV, int AD)
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd, int T, int D, int AD, float eps)
static void gemv_q8_0_from_fp32(float *out, const void *W_q8_0, const float *x_fp32, const float *bias, int M, int K, block_q8_0 *x_q8_scratch)
int mega_fused_attention_decode_scratch_size(int AE, int H, int KV, int AD)
Calculate scratch buffer size needed for the kernel.
void quantize_row_q8_0(const float *x, void *vy, int k)
Quantize FP32 values to Q8_0 block format (scalar reference implementation).
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.