35 #if defined(__AVX__) || defined(__AVX512F__)
36 #include <immintrin.h>
#if defined(__AVX__) && !defined(__AVX512F__)
/*
 * Horizontal sum of all eight float lanes of a 256-bit vector.
 * Used by the AVX (non-AVX-512) reduction paths below; AVX-512 builds
 * use _mm512_reduce_add_ps instead, so this helper is compiled out there.
 */
static inline float ck_hsum256_ps(__m256 v)
{
    /* Fold the upper 128-bit lane onto the lower one. */
    __m128 folded = _mm_add_ps(_mm256_castps256_ps128(v),
                               _mm256_extractf128_ps(v, 1));
    /* Pairwise add: duplicate odd lanes and add, leaving (0+1) in lane 0
       and (2+3) in lane 2. */
    __m128 odd = _mm_movehdup_ps(folded);
    __m128 pairs = _mm_add_ps(folded, odd);
    /* Move lane 2 down and add it into lane 0, which holds the total. */
    odd = _mm_movehl_ps(odd, pairs);
    pairs = _mm_add_ss(pairs, odd);
    return _mm_cvtss_f32(pairs);
}
#endif
/*
 * Dense dot product of two float vectors of length `len`.
 *
 * Compile-time dispatch: AVX-512 (16-wide FMA), AVX (8-wide multiply+add,
 * no FMA assumed), or a plain scalar loop. The vector paths handle the
 * tail elements with a scalar remainder loop. Returns 0.0f for len <= 0.
 */
static inline float ck_dot_f32(
    const float *a,
    const float *b,
    int len)
{
#if defined(__AVX512F__)
    __m512 acc = _mm512_setzero_ps();
    int i = 0;
    for (; i <= len - 16; i += 16) {
        __m512 va = _mm512_loadu_ps(a + i);
        __m512 vb = _mm512_loadu_ps(b + i);
        acc = _mm512_fmadd_ps(va, vb, acc);
    }
    float total = _mm512_reduce_add_ps(acc);
    for (; i < len; ++i) {
        total += a[i] * b[i];
    }
    return total;
#elif defined(__AVX__)
    __m256 acc = _mm256_setzero_ps();
    int i = 0;
    for (; i <= len - 8; i += 8) {
        __m256 va = _mm256_loadu_ps(a + i);
        __m256 vb = _mm256_loadu_ps(b + i);
        /* Separate mul+add: plain AVX without assuming FMA support. */
        acc = _mm256_add_ps(acc, _mm256_mul_ps(va, vb));
    }
    float total = ck_hsum256_ps(acc);
    for (; i < len; ++i) {
        total += a[i] * b[i];
    }
    return total;
#else
    /* Portable scalar fallback. */
    float total = 0.0f;
    for (int i = 0; i < len; ++i) {
        total += a[i] * b[i];
    }
    return total;
#endif
}
#if defined(__AVX512F__)

/* Tile geometry for the fused attention kernels (AVX-512 build). */
#define MEGA_Q_TILE 64
#define MEGA_KV_TILE 64
#define MEGA_STACK_MAX 8192

/* Informal register-budget plan for the hand-scheduled kernels.
   These strings are documentation only; nothing parses them. */
#define REG_Q_ACCUM "ZMM0-ZMM11"
#define REG_K_TILE "ZMM12-ZMM17"
#define REG_V_TILE "ZMM18-ZMM23"
#define REG_O_ACCUM "ZMM24-ZMM27"
#define REG_SOFTMAX "ZMM28-ZMM29"
#define REG_TEMP "ZMM30-ZMM31"

#else

/* Tile geometry for the 256-bit (AVX) build: half the lane width,
   so half the tile edge; same stack budget. */
#define MEGA_Q_TILE 32
#define MEGA_KV_TILE 32
#define MEGA_STACK_MAX 8192

/* Register-budget plan for 16 YMM registers; the output accumulator
   spills to stack/L1 because the file runs out of architectural regs. */
#define REG_Q_ACCUM "YMM0-YMM7"
#define REG_K_TILE "YMM8-YMM11"
#define REG_V_TILE "YMM12-YMM15"
#define REG_O_ACCUM "Stack+L1"
#define REG_SOFTMAX "YMM0-YMM1"
#define REG_TEMP "YMM2-YMM3"

#endif
/*
 * Fused RMSNorm + Q/K/V projection for a single decode token.
 *
 * Normalizes `input` (RMSNorm; optional per-channel `gamma` scale) into a
 * stack-resident row, zero-fills the alignment padding, then projects that
 * row through the per-head weight matrices into q_out / k_out / v_out
 * (head-major layout, each head padded out to aligned_head_dim).
 *
 * Weight layout per head: [head][head_dim rows][aligned_embed_dim].
 * Biases bq/bk/bv and gamma may be NULL (treated as 0 / identity).
 * Invalid pointers or non-positive dimensions make the call a no-op.
 */
void mega_fuse_rmsnorm_qkv_avx(
    float *q_out, float *k_out, float *v_out,
    const float *input, const float *gamma,
    const float *wq, const float *bq,
    const float *wk, const float *bk,
    const float *wv, const float *bv,
    int embed_dim, int aligned_embed_dim,
    int num_heads, int num_kv_heads,
    int head_dim, int aligned_head_dim, float eps)
{
    if (!q_out || !k_out || !v_out || !input || !wq || !wk || !wv) {
        return;
    }
    if (embed_dim <= 0 || aligned_embed_dim <= 0 ||
        head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }

    /* NOTE(review): VLA sized by caller-supplied aligned_embed_dim —
       assumes it stays within the stack budget (cf. MEGA_STACK_MAX);
       confirm callers bound it. */
    float ln1_row[aligned_embed_dim];

    /* ---- RMS statistic: sum of squares over the logical embedding ---- */
    float sum_sq = 0.0f;
#if defined(__AVX512F__)
    {
        __m512 acc = _mm512_setzero_ps();
        int i = 0;
        for (; i + 16 <= embed_dim; i += 16) {
            __m512 xv = _mm512_loadu_ps(input + i);
            acc = _mm512_fmadd_ps(xv, xv, acc);
        }
        sum_sq = _mm512_reduce_add_ps(acc);
        for (; i < embed_dim; ++i) {
            sum_sq += input[i] * input[i];
        }
    }
#elif defined(__AVX__)
    {
        __m256 acc = _mm256_setzero_ps();
        int i = 0;
        for (; i + 8 <= embed_dim; i += 8) {
            __m256 xv = _mm256_loadu_ps(input + i);
            acc = _mm256_add_ps(acc, _mm256_mul_ps(xv, xv));
        }
        sum_sq = ck_hsum256_ps(acc);
        for (; i < embed_dim; ++i) {
            sum_sq += input[i] * input[i];
        }
    }
#else
    for (int i = 0; i < embed_dim; ++i) {
        sum_sq += input[i] * input[i];
    }
#endif

    const float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);

    /* ---- Normalize (scaled by gamma when present) into ln1_row ---- */
#if defined(__AVX512F__)
    if (gamma) {
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        int j = 0;
        for (; j + 16 <= embed_dim; j += 16) {
            __m512 xv = _mm512_loadu_ps(input + j);
            __m512 gv = _mm512_loadu_ps(gamma + j);
            _mm512_storeu_ps(ln1_row + j,
                             _mm512_mul_ps(_mm512_mul_ps(xv, rstd_vec), gv));
        }
        for (; j < embed_dim; ++j) {
            ln1_row[j] = input[j] * rstd * gamma[j];
        }
    } else {
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        int j = 0;
        for (; j + 16 <= embed_dim; j += 16) {
            __m512 xv = _mm512_loadu_ps(input + j);
            _mm512_storeu_ps(ln1_row + j, _mm512_mul_ps(xv, rstd_vec));
        }
        for (; j < embed_dim; ++j) {
            ln1_row[j] = input[j] * rstd;
        }
    }
#elif defined(__AVX__)
    if (gamma) {
        __m256 rstd_vec = _mm256_set1_ps(rstd);
        int j = 0;
        for (; j + 8 <= embed_dim; j += 8) {
            __m256 xv = _mm256_loadu_ps(input + j);
            __m256 gv = _mm256_loadu_ps(gamma + j);
            _mm256_storeu_ps(ln1_row + j,
                             _mm256_mul_ps(_mm256_mul_ps(xv, rstd_vec), gv));
        }
        for (; j < embed_dim; ++j) {
            ln1_row[j] = input[j] * rstd * gamma[j];
        }
    } else {
        __m256 rstd_vec = _mm256_set1_ps(rstd);
        int j = 0;
        for (; j + 8 <= embed_dim; j += 8) {
            __m256 xv = _mm256_loadu_ps(input + j);
            _mm256_storeu_ps(ln1_row + j, _mm256_mul_ps(xv, rstd_vec));
        }
        for (; j < embed_dim; ++j) {
            ln1_row[j] = input[j] * rstd;
        }
    }
#else
    for (int j = 0; j < embed_dim; ++j) {
        ln1_row[j] = input[j] * rstd * (gamma ? gamma[j] : 1.0f);
    }
#endif

    /* Zero the padding so full-width dot products over aligned_embed_dim
       stay exact. */
    for (int j = embed_dim; j < aligned_embed_dim; ++j) {
        ln1_row[j] = 0.0f;
    }

    const size_t head_w_stride =
        (size_t)aligned_head_dim * (size_t)aligned_embed_dim;

    /* ---- Q projection: one padded row per attention head ---- */
    for (int h = 0; h < num_heads; ++h) {
        const float *wq_h = wq + (size_t)h * head_w_stride;
        const float *bq_h =
            bq ? bq + (size_t)h * (size_t)aligned_head_dim : NULL;
        float *q_h = q_out + (size_t)h * (size_t)aligned_head_dim;

        for (int d = 0; d < head_dim; ++d) {
            const float *w_row = wq_h + (size_t)d * (size_t)aligned_embed_dim;
            q_h[d] = ck_dot_f32(ln1_row, w_row, aligned_embed_dim)
                   + (bq_h ? bq_h[d] : 0.0f);
        }
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            q_h[d] = 0.0f;   /* clear head padding */
        }
    }

    /* ---- K/V projections: one padded row pair per KV head (GQA) ---- */
    for (int h = 0; h < num_kv_heads; ++h) {
        const float *wk_h = wk + (size_t)h * head_w_stride;
        const float *wv_h = wv + (size_t)h * head_w_stride;
        const float *bk_h =
            bk ? bk + (size_t)h * (size_t)aligned_head_dim : NULL;
        const float *bv_h =
            bv ? bv + (size_t)h * (size_t)aligned_head_dim : NULL;
        float *k_h = k_out + (size_t)h * (size_t)aligned_head_dim;
        float *v_h = v_out + (size_t)h * (size_t)aligned_head_dim;

        for (int d = 0; d < head_dim; ++d) {
            const float *wk_row = wk_h + (size_t)d * (size_t)aligned_embed_dim;
            const float *wv_row = wv_h + (size_t)d * (size_t)aligned_embed_dim;
            k_h[d] = ck_dot_f32(ln1_row, wk_row, aligned_embed_dim)
                   + (bk_h ? bk_h[d] : 0.0f);
            v_h[d] = ck_dot_f32(ln1_row, wv_row, aligned_embed_dim)
                   + (bv_h ? bv_h[d] : 0.0f);
        }
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            k_h[d] = 0.0f;   /* clear head padding */
            v_h[d] = 0.0f;
        }
    }
}
/*
 * Applies rotary position embedding (RoPE) in place to Q and K for one
 * token at position `pos`.
 *
 * Uses the "split halves" pairing: element i rotates with element
 * i + half, where half = head_dim / 2, using cos/sin tables laid out as
 * [pos][half]. Q covers num_heads heads, K covers num_kv_heads heads;
 * both are head-major with aligned_head_dim stride. The vector paths
 * process full lanes and the scalar loop finishes the remainder, so the
 * result is identical with or without AVX. An odd head_dim cannot be
 * paired, so the call becomes a no-op (as do NULL/invalid arguments).
 */
void mega_fuse_rope_inplace_avx(
    float *q, float *k,
    const float *rope_cos,
    const float *rope_sin,
    int pos, int num_heads, int num_kv_heads,
    int head_dim, int aligned_head_dim)
{
    if (!q || !k || !rope_cos || !rope_sin ||
        head_dim <= 0 || aligned_head_dim <= 0) {
        return;
    }
    if ((head_dim & 1) != 0) {
        return;   /* RoPE needs an even head_dim to form (i, i+half) pairs */
    }

    int half = head_dim / 2;
    const float *cos_ptr = rope_cos + (size_t)pos * (size_t)half;
    const float *sin_ptr = rope_sin + (size_t)pos * (size_t)half;

    /* ---- Rotate the query heads ---- */
    for (int h = 0; h < num_heads; ++h) {
        float *q_h = q + (size_t)h * (size_t)aligned_head_dim;
        int i = 0;
#if defined(__AVX512F__)
        for (; i + 16 <= half; i += 16) {
            __m512 q0 = _mm512_loadu_ps(q_h + i);
            __m512 q1 = _mm512_loadu_ps(q_h + i + half);
            __m512 cv = _mm512_loadu_ps(cos_ptr + i);
            __m512 sv = _mm512_loadu_ps(sin_ptr + i);
            __m512 r0 = _mm512_sub_ps(_mm512_mul_ps(q0, cv),
                                      _mm512_mul_ps(q1, sv));
            __m512 r1 = _mm512_add_ps(_mm512_mul_ps(q0, sv),
                                      _mm512_mul_ps(q1, cv));
            _mm512_storeu_ps(q_h + i, r0);
            _mm512_storeu_ps(q_h + i + half, r1);
        }
#elif defined(__AVX__)
        for (; i + 8 <= half; i += 8) {
            __m256 q0 = _mm256_loadu_ps(q_h + i);
            __m256 q1 = _mm256_loadu_ps(q_h + i + half);
            __m256 cv = _mm256_loadu_ps(cos_ptr + i);
            __m256 sv = _mm256_loadu_ps(sin_ptr + i);
            __m256 r0 = _mm256_sub_ps(_mm256_mul_ps(q0, cv),
                                      _mm256_mul_ps(q1, sv));
            __m256 r1 = _mm256_add_ps(_mm256_mul_ps(q0, sv),
                                      _mm256_mul_ps(q1, cv));
            _mm256_storeu_ps(q_h + i, r0);
            _mm256_storeu_ps(q_h + i + half, r1);
        }
#endif
        /* Scalar remainder (entire head when no AVX is compiled in). */
        for (; i < half; ++i) {
            float q0 = q_h[i];
            float q1 = q_h[i + half];
            float c = cos_ptr[i];
            float s = sin_ptr[i];
            q_h[i] = q0 * c - q1 * s;
            q_h[i + half] = q0 * s + q1 * c;
        }
        /* Keep alignment padding zeroed. */
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            q_h[d] = 0.0f;
        }
    }

    /* ---- Rotate the key heads (same recipe, num_kv_heads of them) ---- */
    for (int h = 0; h < num_kv_heads; ++h) {
        float *k_h = k + (size_t)h * (size_t)aligned_head_dim;
        int i = 0;
#if defined(__AVX512F__)
        for (; i + 16 <= half; i += 16) {
            __m512 k0 = _mm512_loadu_ps(k_h + i);
            __m512 k1 = _mm512_loadu_ps(k_h + i + half);
            __m512 cv = _mm512_loadu_ps(cos_ptr + i);
            __m512 sv = _mm512_loadu_ps(sin_ptr + i);
            __m512 r0 = _mm512_sub_ps(_mm512_mul_ps(k0, cv),
                                      _mm512_mul_ps(k1, sv));
            __m512 r1 = _mm512_add_ps(_mm512_mul_ps(k0, sv),
                                      _mm512_mul_ps(k1, cv));
            _mm512_storeu_ps(k_h + i, r0);
            _mm512_storeu_ps(k_h + i + half, r1);
        }
#elif defined(__AVX__)
        for (; i + 8 <= half; i += 8) {
            __m256 k0 = _mm256_loadu_ps(k_h + i);
            __m256 k1 = _mm256_loadu_ps(k_h + i + half);
            __m256 cv = _mm256_loadu_ps(cos_ptr + i);
            __m256 sv = _mm256_loadu_ps(sin_ptr + i);
            __m256 r0 = _mm256_sub_ps(_mm256_mul_ps(k0, cv),
                                      _mm256_mul_ps(k1, sv));
            __m256 r1 = _mm256_add_ps(_mm256_mul_ps(k0, sv),
                                      _mm256_mul_ps(k1, cv));
            _mm256_storeu_ps(k_h + i, r0);
            _mm256_storeu_ps(k_h + i + half, r1);
        }
#endif
        for (; i < half; ++i) {
            float k0 = k_h[i];
            float k1 = k_h[i + half];
            float c = cos_ptr[i];
            float s = sin_ptr[i];
            k_h[i] = k0 * c - k1 * s;
            k_h[i + half] = k0 * s + k1 * c;
        }
        for (int d = head_dim; d < aligned_head_dim; ++d) {
            k_h[d] = 0.0f;
        }
    }
}
447 const float *kv_cache_k,
448 const float *kv_cache_v,
454 int aligned_head_dim)
456 const int hd = head_dim;
457 const float scale = 1.0f / sqrtf((
float)hd);
458 const size_t head_stride = (size_t)cache_capacity * (
size_t)aligned_head_dim;
460 for (
int h = 0; h < num_heads; h++) {
461 const float *q_h = q + (size_t)h * (
size_t)aligned_head_dim;
462 const int kv_idx = h % num_kv_heads;
463 const float *k_cache = kv_cache_k + (size_t)kv_idx * head_stride;
464 const float *v_cache = kv_cache_v + (size_t)kv_idx * head_stride;
467 float o_h[aligned_head_dim];
472 memset(o_h, 0, (
size_t)aligned_head_dim *
sizeof(
float));
477 if (tile_end > seq_len) tile_end = seq_len;
478 int tile_size = tile_end - t;
482 for (
int i = 0; i < tile_size; i++) {
483 memcpy(k_tile + (
size_t)i * (
size_t)hd,
484 k_cache + (
size_t)(t + i) * (
size_t)aligned_head_dim,
485 (
size_t)hd *
sizeof(
float));
490 for (
int j = 0; j < tile_size; j++) {
492 for (
int i = 0; i < hd; i++) {
493 s_row[j] += q_h[i] * k_tile[j * hd + i];
500 for (
int j = 0; j < tile_size; j++) {
501 if (s_row[j] > m_new) m_new = s_row[j];
505 for (
int j = 0; j < tile_size; j++) {
506 float p = expf(s_row[j] - m_new);
512 float exp_m_diff = expf(m - m_new);
513 for (
int i = 0; i < hd; i++) {
514 o_h[i] *= exp_m_diff;
518 for (
int j = 0; j < tile_size; j++) {
520 for (
int i = 0; i < hd; i++) {
521 o_h[i] += p * v_cache[(size_t)(t + j) * (size_t)aligned_head_dim + (
size_t)i];
525 l = l * exp_m_diff + l_new;
530 for (
int i = 0; i < hd; i++) {
533 for (
int i = hd; i < aligned_head_dim; ++i) {
538 memcpy(o_out + (
size_t)h * (
size_t)aligned_head_dim,
540 (
size_t)aligned_head_dim *
sizeof(
float));
/*
 * Attention output projection across heads plus residual add:
 *   output[j] = (bo ? bo[j] : 0) + sum_h dot(attn_h, wo_h[j]) + residual[j]
 * for j in [0, embed_dim).
 *
 * wo layout: [head][embed_dim rows][aligned_head_dim], i.e. the weight
 * column for output channel j of head h starts at
 * wo + h*aligned_embed_dim*aligned_head_dim + j*aligned_head_dim.
 * attn_token is head-major with aligned_head_dim stride; bo and residual
 * are optional (NULL == zero). Padding channels [embed_dim,
 * aligned_embed_dim) of the output row are zeroed. No-op on NULL
 * attn_token/wo/output.
 */
static void mega_fuse_output_proj_residual(
    const float *attn_token, const float *wo, const float *bo,
    const float *residual, float *output,
    int embed_dim, int aligned_embed_dim,
    int num_heads, int head_dim, int aligned_head_dim)
{
    if (!attn_token || !wo || !output) {
        return;
    }

    const size_t head_weight_stride =
        (size_t)aligned_embed_dim * (size_t)aligned_head_dim;

    for (int j = 0; j < embed_dim; ++j) {
        float acc = bo ? bo[j] : 0.0f;
        for (int h = 0; h < num_heads; ++h) {
            const float *head_vec =
                attn_token + (size_t)h * (size_t)aligned_head_dim;
            const float *w_col = wo + (size_t)h * head_weight_stride
                                    + (size_t)j * (size_t)aligned_head_dim;
            for (int d = 0; d < head_dim; ++d) {
                acc += head_vec[d] * w_col[d];
            }
        }
        output[j] = acc + (residual ? residual[j] : 0.0f);
    }

    /* Zero the alignment padding of the output row. */
    for (int j = embed_dim; j < aligned_embed_dim; ++j) {
        output[j] = 0.0f;
    }
}
/*
 * Full fused attention decode step for one token at position `pos`:
 *   RMSNorm -> Q/K/V projection -> optional RoPE -> KV-cache write ->
 *   flash attention over positions [0, pos] -> output projection +
 *   residual into `output`.
 *
 * Scratch Q/K/V/O buffers are heap-allocated per call and always freed
 * via a single cleanup ladder; on any allocation failure the function
 * returns without touching `output` or the caches. Biases (bq/bk/bv/bo),
 * `residual`, and the RoPE tables are optional (NULL). Invalid
 * pointers/dimensions or an out-of-range `pos` make the call a no-op.
 */
void mega_fused_attention_decode(
    float *output, const float *input,
    const float *residual,
    const float *ln1_gamma,
    const float *wq, const float *bq,
    const float *wk, const float *bk,
    const float *wv, const float *bv,
    const float *wo, const float *bo,
    float *kv_cache_k, float *kv_cache_v,
    const float *rope_cos,
    const float *rope_sin,
    int pos,
    int embed_dim, int aligned_embed_dim,
    int num_heads, int num_kv_heads,
    int head_dim, int aligned_head_dim,
    int cache_capacity, float eps)
{
    if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
        !kv_cache_k || !kv_cache_v) {
        return;
    }
    if (embed_dim <= 0 || aligned_embed_dim <= 0 || head_dim <= 0 ||
        aligned_head_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0 ||
        cache_capacity <= 0) {
        return;
    }
    if (pos < 0 || pos >= cache_capacity) {
        return;   /* token must land inside the cache */
    }
    if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
        return;   /* padding may never shrink the logical dims */
    }

    const size_t q_elems = (size_t)num_heads * (size_t)aligned_head_dim;
    const size_t kv_elems = (size_t)num_kv_heads * (size_t)aligned_head_dim;

    /* Scratch buffers; single cleanup path frees whatever was obtained
       (free(NULL) is a no-op). */
    float *q = NULL, *k = NULL, *v = NULL, *o = NULL;

    q = malloc(q_elems * sizeof *q);
    if (!q) goto cleanup;
    k = malloc(kv_elems * sizeof *k);
    if (!k) goto cleanup;
    v = malloc(kv_elems * sizeof *v);
    if (!v) goto cleanup;
    o = malloc(q_elems * sizeof *o);
    if (!o) goto cleanup;

    mega_fuse_rmsnorm_qkv_avx(q, k, v, input, ln1_gamma,
                              wq, bq, wk, bk, wv, bv,
                              embed_dim, aligned_embed_dim,
                              num_heads, num_kv_heads,
                              head_dim, aligned_head_dim, eps);

    if (rope_cos && rope_sin) {
        mega_fuse_rope_inplace_avx(q, k, rope_cos, rope_sin, pos,
                                   num_heads, num_kv_heads,
                                   head_dim, aligned_head_dim);
    }

    kv_cache_write_head_major(k, v,
                              kv_cache_k, kv_cache_v,
                              num_kv_heads, pos, cache_capacity,
                              head_dim, aligned_head_dim);

    /* Attend over every cached position up to and including this one. */
    mega_fuse_flash_attention_avx(o, q,
                                  kv_cache_k, kv_cache_v,
                                  num_heads, num_kv_heads,
                                  pos + 1, cache_capacity,
                                  head_dim, aligned_head_dim);

    mega_fuse_output_proj_residual(o, wo, bo, residual, output,
                                   embed_dim, aligned_embed_dim,
                                   num_heads, head_dim, aligned_head_dim);

cleanup:
    free(q);
    free(k);
    free(v);
    free(o);
}
CPU feature detection and dispatch macros.
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
Quantization block structures for weight-only quantization.
static void mega_fuse_output_proj_residual(const float *attn_token, const float *wo, const float *bo, const float *residual, float *output, int embed_dim, int aligned_embed_dim, int num_heads, int head_dim, int aligned_head_dim)
void mega_fuse_flash_attention_avx(float *o_out, const float *q, const float *kv_cache_k, const float *kv_cache_v, int num_heads, int num_kv_heads, int seq_len, int cache_capacity, int head_dim, int aligned_head_dim)
Flash attention with online softmax (AVX version)
static float ck_dot_f32(const float *a, const float *b, int len)
void mega_fuse_rmsnorm_qkv_avx(float *q_out, float *k_out, float *v_out, const float *input, const float *gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps)
Fused RMSNorm + QKV for decode (single token)
void mega_fuse_rope_inplace_avx(float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim)
Apply RoPE to Q and K (in-place, from L1)
void mega_fused_attention_decode(float *output, const float *input, const float *residual, const float *ln1_gamma, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, const float *wo, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps)
Full mega-fused attention for decode.