67 #include <immintrin.h>
80 __m256 vsum = _mm256_setzero_ps();
82 for (; i + 7 < n; i += 8) {
83 __m256 vx = _mm256_loadu_ps(x + i);
84 vsum = _mm256_fmadd_ps(vx, vx, vsum);
86 __m128 vlow = _mm256_castps256_ps128(vsum);
87 __m128 vhigh = _mm256_extractf128_ps(vsum, 1);
88 vlow = _mm_add_ps(vlow, vhigh);
89 vlow = _mm_hadd_ps(vlow, vlow);
90 vlow = _mm_hadd_ps(vlow, vlow);
91 sum_sq = _mm_cvtss_f32(vlow);
93 sum_sq += x[i] * x[i];
96 for (
int i = 0; i < n; i++) {
97 sum_sq += x[i] * x[i];
101 float rms = sqrtf(sum_sq / (
float)n + eps);
110 return x / (1.0f + expf(-x));
/*
 * silu_avx2: SiLU activation on 8 packed floats; must match silu_scalar.
 *
 * FIX(review): the previous version approximated exp(-x) with a 4th-order
 * Taylor polynomial (1 - x + x^2/2 - x^3/6 + x^4/24) clamped to [-88, 88].
 * A 4th-order Taylor series is only accurate for |x| <~ 1; for larger
 * inputs it diverges rapidly and can even drive the denominator
 * (1 + exp(-x)) negative, yielding sigmoids outside [0, 1]. It also
 * disagreed with silu_scalar(), which handles the remainder lanes of the
 * same arrays — so vector-processed and tail elements got different
 * values. We now evaluate expf() exactly per lane; slower per element but
 * correct and bit-consistent with the scalar path.
 */
#if defined(__AVX2__)
static inline __m256 silu_avx2(__m256 x) {
    float lanes[8];
    _mm256_storeu_ps(lanes, x);
    for (int i = 0; i < 8; i++) {
        lanes[i] = lanes[i] / (1.0f + expf(-lanes[i]));
    }
    return _mm256_loadu_ps(lanes);
}
#endif /* __AVX2__ */
151 float max_val = x[0];
152 for (
int i = 1; i < n; i++) {
153 if (x[i] > max_val) max_val = x[i];
157 for (
int i = 0; i < n; i++) {
158 x[i] = expf(x[i] - max_val);
162 float inv_sum = 1.0f / sum;
163 for (
int i = 0; i < n; i++) {
/*
 * attention_mlp_fused_fp32: one fused transformer-layer step (single query
 * position) in fp32 — GQA attention over the KV cache, output projection
 * with residual, RMSNorm, SwiGLU MLP, and the second residual.
 *
 * q          : query vector, num_heads * head_dim floats
 * k_cache/v_cache : seq_len rows of num_kv_heads * head_dim floats each
 * wo/w_gate/w_up/w_down : row-major fp32 weight matrices
 * hidden_out : embed_dim floats (may not alias the inputs)
 *
 * Silently returns when embed_dim > 4096, intermediate_dim > 16384, or
 * seq_len > 8192 (fixed stack scratch buffers).
 *
 * NOTE(review): reconstructed from a mangled source dump; missing interior
 * lines (buffer declarations, softmax call, rms_scale computation) were
 * inferred from the fully visible parallel implementation
 * layer_fused_attn_mlp_qkv_q4k in this file — verify against the original.
 */
void attention_mlp_fused_fp32(const float *q, const float *k_cache,
                              const float *v_cache, int seq_len,
                              int num_heads, int num_kv_heads, int head_dim,
                              float attn_scale, const float *wo,
                              const float *residual_1, const float *rms_weight,
                              float eps, const float *w_gate,
                              const float *w_up, const float *w_down,
                              int embed_dim, int intermediate_dim,
                              float *hidden_out) {
    const int heads_per_kv = num_heads / num_kv_heads; /* GQA group size */
    const int q_dim = num_heads * head_dim;
    const int kv_dim = num_kv_heads * head_dim;

    /* Fixed-size stack scratch; bail out when the model exceeds it. */
    float attn_out[4096];
    float hidden_after_attn[4096];
    float normed[4096];
    float gate_out[16384];
    float up_out[16384];
    if (embed_dim > 4096 || intermediate_dim > 16384) {
        return;
    }

    /* ---- attention ---- */
    memset(attn_out, 0, q_dim * sizeof(float));
    for (int h = 0; h < num_heads; h++) {
        int kv_h = h / heads_per_kv; /* several Q heads share one KV head */
        const float *q_head = q + h * head_dim;
        float *out_head = attn_out + h * head_dim;

        float scores[8192];
        if (seq_len > 8192)
            return;
        for (int t = 0; t < seq_len; t++) {
            const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_head[d] * k_t[d];
            }
            scores[t] = score * attn_scale;
        }

        softmax_inplace(scores, seq_len);

        for (int t = 0; t < seq_len; t++) {
            const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
            float w = scores[t];
            for (int d = 0; d < head_dim; d++) {
                out_head[d] += w * v_t[d];
            }
        }
    }

    /* Output projection + first residual connection. */
    for (int i = 0; i < embed_dim; i++) {
        float sum = 0.0f;
        const float *wo_row = wo + i * q_dim;
        for (int j = 0; j < q_dim; j++) {
            sum += wo_row[j] * attn_out[j];
        }
        hidden_after_attn[i] = sum + residual_1[i];
    }

    /* ---- RMSNorm before the MLP ---- */
    float rms_scale = compute_rms_scale_internal(hidden_after_attn, embed_dim, eps);
#if defined(__AVX2__)
    int ni = 0;
    __m256 vscale = _mm256_set1_ps(rms_scale);
    for (; ni + 7 < embed_dim; ni += 8) {
        __m256 vh = _mm256_loadu_ps(hidden_after_attn + ni);
        __m256 vw = _mm256_loadu_ps(rms_weight + ni);
        __m256 vn = _mm256_mul_ps(_mm256_mul_ps(vh, vw), vscale);
        _mm256_storeu_ps(normed + ni, vn);
    }
    for (; ni < embed_dim; ni++) {
        normed[ni] = hidden_after_attn[ni] * rms_weight[ni] * rms_scale;
    }
#else
    for (int ni = 0; ni < embed_dim; ni++) {
        normed[ni] = hidden_after_attn[ni] * rms_weight[ni] * rms_scale;
    }
#endif

    /* ---- SwiGLU MLP: gate and up projections ---- */
    for (int i = 0; i < intermediate_dim; i++) {
        float sum = 0.0f;
        const float *wg_row = w_gate + i * embed_dim;
        for (int j = 0; j < embed_dim; j++) {
            sum += wg_row[j] * normed[j];
        }
        gate_out[i] = sum;
    }
    for (int i = 0; i < intermediate_dim; i++) {
        float sum = 0.0f;
        const float *wu_row = w_up + i * embed_dim;
        for (int j = 0; j < embed_dim; j++) {
            sum += wu_row[j] * normed[j];
        }
        up_out[i] = sum;
    }

    /* gate_out[i] = silu(gate_out[i]) * up_out[i] */
#if defined(__AVX2__)
    int si = 0;
    for (; si + 7 < intermediate_dim; si += 8) {
        __m256 vg = _mm256_loadu_ps(gate_out + si);
        __m256 vu = _mm256_loadu_ps(up_out + si);
        __m256 vswiglu = _mm256_mul_ps(silu_avx2(vg), vu);
        _mm256_storeu_ps(gate_out + si, vswiglu);
    }
    for (; si < intermediate_dim; si++) {
        gate_out[si] = silu_scalar(gate_out[si]) * up_out[si];
    }
#else
    for (int si = 0; si < intermediate_dim; si++) {
        gate_out[si] = silu_scalar(gate_out[si]) * up_out[si];
    }
#endif

    /* Down projection + second residual. */
    for (int i = 0; i < embed_dim; i++) {
        float sum = 0.0f;
        const float *wd_row = w_down + i * intermediate_dim;
        for (int j = 0; j < intermediate_dim; j++) {
            sum += wd_row[j] * gate_out[j];
        }
        hidden_out[i] = sum + hidden_after_attn[i];
    }
}
/*
 * gemv_fp32_row_avx2: dot product of one weight row with x, K elements.
 * AVX2+FMA main loop with movehdup/movehl horizontal sum, scalar tail.
 *
 * FIX(review): wrapped the intrinsics in #if defined(__AVX2__) &&
 * defined(__FMA__) with a scalar fallback, consistent with the dual
 * vector/scalar paths used elsewhere in this file, so the translation
 * unit also builds without AVX2 target flags. Parameter list reconstructed
 * from the call sites: (row, x, K).
 */
static inline float gemv_fp32_row_avx2(const float *row, const float *x, int K) {
    float sum = 0.0f;
    int k = 0;
#if defined(__AVX2__) && defined(__FMA__)
    __m256 acc = _mm256_setzero_ps();
    for (; k + 7 < K; k += 8) {
        __m256 vw = _mm256_loadu_ps(row + k);
        __m256 vx = _mm256_loadu_ps(x + k);
        acc = _mm256_fmadd_ps(vw, vx, acc);
    }
    /* Horizontal sum of the 8 lanes. */
    __m128 vlow = _mm256_castps256_ps128(acc);
    __m128 vhigh = _mm256_extractf128_ps(acc, 1);
    vlow = _mm_add_ps(vlow, vhigh);
    __m128 shuf = _mm_movehdup_ps(vlow);
    vlow = _mm_add_ps(vlow, shuf);
    shuf = _mm_movehl_ps(shuf, vlow);
    vlow = _mm_add_ss(vlow, shuf);
    sum = _mm_cvtss_f32(vlow);
#endif
    /* Scalar tail (entire loop when vector path is compiled out). */
    for (; k < K; k++) {
        sum += row[k] * x[k];
    }
    return sum;
}
/*
 * mlp_fused_fp32_v2: RMSNorm + SwiGLU MLP + residual, fp32 weights.
 * "Fused" variant: the gate and up projections share one pass over
 * `normed` per output element (two FMA accumulators per row pair),
 * improving locality versus two separate GEMVs.
 *
 * hidden_out[j] = w_down[j] . swiglu + hidden_in[j]
 * Silently returns when embed_dim > 4096 or intermediate_dim > 16384
 * (fixed stack scratch).
 *
 * NOTE(review): reconstructed from a mangled source dump; scratch-buffer
 * declarations and the swiglu store were inferred from the visible guard
 * and the `swiglu` use at the down-projection — verify against the
 * original.
 */
void mlp_fused_fp32_v2(const float *hidden_in, const float *rms_weight,
                       float eps, const float *w_gate, const float *w_up,
                       const float *w_down, int embed_dim,
                       int intermediate_dim, float *hidden_out) {
    float normed[4096];
    float swiglu[16384];
    if (embed_dim > 4096 || intermediate_dim > 16384) {
        return;
    }

    /* ---- RMSNorm ---- */
    float rms_scale = compute_rms_scale_internal(hidden_in, embed_dim, eps);
#if defined(__AVX2__)
    int ni = 0;
    __m256 vscale = _mm256_set1_ps(rms_scale);
    for (; ni + 7 < embed_dim; ni += 8) {
        __m256 vh = _mm256_loadu_ps(hidden_in + ni);
        __m256 vw = _mm256_loadu_ps(rms_weight + ni);
        __m256 vn = _mm256_mul_ps(_mm256_mul_ps(vh, vw), vscale);
        _mm256_storeu_ps(normed + ni, vn);
    }
    for (; ni < embed_dim; ni++) {
        normed[ni] = hidden_in[ni] * rms_weight[ni] * rms_scale;
    }
#else
    for (int ni = 0; ni < embed_dim; ni++) {
        normed[ni] = hidden_in[ni] * rms_weight[ni] * rms_scale;
    }
#endif

    /* ---- fused gate + up projections, then SiLU(gate) * up ---- */
#if defined(__AVX2__) && defined(__FMA__)
    for (int j = 0; j < intermediate_dim; j++) {
        const float *wg_row = w_gate + j * embed_dim;
        const float *wu_row = w_up + j * embed_dim;

        __m256 gate_acc = _mm256_setzero_ps();
        __m256 up_acc = _mm256_setzero_ps();
        int k = 0;
        for (; k + 7 < embed_dim; k += 8) {
            __m256 vn = _mm256_loadu_ps(normed + k);
            gate_acc = _mm256_fmadd_ps(_mm256_loadu_ps(wg_row + k), vn, gate_acc);
            up_acc = _mm256_fmadd_ps(_mm256_loadu_ps(wu_row + k), vn, up_acc);
        }

        /* Horizontal sum of gate_acc. */
        __m128 glow = _mm256_castps256_ps128(gate_acc);
        __m128 ghigh = _mm256_extractf128_ps(gate_acc, 1);
        glow = _mm_add_ps(glow, ghigh);
        __m128 gshuf = _mm_movehdup_ps(glow);
        glow = _mm_add_ps(glow, gshuf);
        gshuf = _mm_movehl_ps(gshuf, glow);
        glow = _mm_add_ss(glow, gshuf);
        float gate_val = _mm_cvtss_f32(glow);

        /* Horizontal sum of up_acc. */
        __m128 ulow = _mm256_castps256_ps128(up_acc);
        __m128 uhigh = _mm256_extractf128_ps(up_acc, 1);
        ulow = _mm_add_ps(ulow, uhigh);
        __m128 ushuf = _mm_movehdup_ps(ulow);
        ulow = _mm_add_ps(ulow, ushuf);
        ushuf = _mm_movehl_ps(ushuf, ulow);
        ulow = _mm_add_ss(ulow, ushuf);
        float up_val = _mm_cvtss_f32(ulow);

        /* Scalar tail for embed_dim not divisible by 8. */
        for (; k < embed_dim; k++) {
            gate_val += wg_row[k] * normed[k];
            up_val += wu_row[k] * normed[k];
        }

        swiglu[j] = silu_scalar(gate_val) * up_val;
    }
#else
    for (int j = 0; j < intermediate_dim; j++) {
        const float *wg_row = w_gate + j * embed_dim;
        const float *wu_row = w_up + j * embed_dim;
        float gate_val = 0.0f, up_val = 0.0f;
        for (int k = 0; k < embed_dim; k++) {
            gate_val += wg_row[k] * normed[k];
            up_val += wu_row[k] * normed[k];
        }
        swiglu[j] = silu_scalar(gate_val) * up_val;
    }
#endif

    /* ---- down projection + residual ---- */
#if defined(__AVX2__) && defined(__FMA__)
    for (int j = 0; j < embed_dim; j++) {
        float sum = gemv_fp32_row_avx2(w_down + j * intermediate_dim, swiglu, intermediate_dim);
        hidden_out[j] = sum + hidden_in[j];
    }
#else
    for (int j = 0; j < embed_dim; j++) {
        float sum = 0.0f;
        const float *wd_row = w_down + j * intermediate_dim;
        for (int k = 0; k < intermediate_dim; k++) {
            sum += wd_row[k] * swiglu[k];
        }
        hidden_out[j] = sum + hidden_in[j];
    }
#endif
}
/*
 * mlp_fused_fp32_v3: RMSNorm + SwiGLU MLP + residual, fp32 weights.
 * Variant using three separate GEMV passes (gate, then up fused with the
 * SiLU multiply, then down) built on gemv_fp32_row_avx2.
 *
 * Silently returns when embed_dim > 4096 or intermediate_dim > 16384
 * (fixed stack scratch).
 *
 * NOTE(review): reconstructed from a mangled source dump; the swiglu
 * buffer and the silu application point were inferred from the visible
 * `swiglu` use in the down projection — verify against the original.
 */
void mlp_fused_fp32_v3(const float *hidden_in, const float *rms_weight,
                       float eps, const float *w_gate, const float *w_up,
                       const float *w_down, int embed_dim,
                       int intermediate_dim, float *hidden_out) {
    float normed[4096];
    float gate_out[16384];
    float swiglu[16384];
    if (embed_dim > 4096 || intermediate_dim > 16384) {
        return;
    }

    /* ---- RMSNorm ---- */
    float rms_scale = compute_rms_scale_internal(hidden_in, embed_dim, eps);
#if defined(__AVX2__)
    int ni = 0;
    __m256 vscale = _mm256_set1_ps(rms_scale);
    for (; ni + 7 < embed_dim; ni += 8) {
        __m256 vh = _mm256_loadu_ps(hidden_in + ni);
        __m256 vw = _mm256_loadu_ps(rms_weight + ni);
        __m256 vn = _mm256_mul_ps(_mm256_mul_ps(vh, vw), vscale);
        _mm256_storeu_ps(normed + ni, vn);
    }
    for (; ni < embed_dim; ni++) {
        normed[ni] = hidden_in[ni] * rms_weight[ni] * rms_scale;
    }
#else
    for (int ni = 0; ni < embed_dim; ni++) {
        normed[ni] = hidden_in[ni] * rms_weight[ni] * rms_scale;
    }
#endif

    /* ---- gate projection ---- */
#if defined(__AVX2__) && defined(__FMA__)
    for (int j = 0; j < intermediate_dim; j++) {
        gate_out[j] = gemv_fp32_row_avx2(w_gate + j * embed_dim, normed, embed_dim);
    }
#else
    for (int j = 0; j < intermediate_dim; j++) {
        float sum = 0.0f;
        const float *wg_row = w_gate + j * embed_dim;
        for (int k = 0; k < embed_dim; k++) {
            sum += wg_row[k] * normed[k];
        }
        gate_out[j] = sum;
    }
#endif

    /* ---- up projection fused with SiLU(gate) * up ---- */
#if defined(__AVX2__) && defined(__FMA__)
    for (int j = 0; j < intermediate_dim; j++) {
        float up_val = gemv_fp32_row_avx2(w_up + j * embed_dim, normed, embed_dim);
        swiglu[j] = silu_scalar(gate_out[j]) * up_val;
    }
#else
    for (int j = 0; j < intermediate_dim; j++) {
        float up_val = 0.0f;
        const float *wu_row = w_up + j * embed_dim;
        for (int k = 0; k < embed_dim; k++) {
            up_val += wu_row[k] * normed[k];
        }
        swiglu[j] = silu_scalar(gate_out[j]) * up_val;
    }
#endif

    /* ---- down projection + residual ---- */
#if defined(__AVX2__) && defined(__FMA__)
    for (int j = 0; j < embed_dim; j++) {
        float sum = gemv_fp32_row_avx2(w_down + j * intermediate_dim, swiglu, intermediate_dim);
        hidden_out[j] = sum + hidden_in[j];
    }
#else
    for (int j = 0; j < embed_dim; j++) {
        float sum = 0.0f;
        const float *wd_row = w_down + j * intermediate_dim;
        for (int k = 0; k < intermediate_dim; k++) {
            sum += wd_row[k] * swiglu[k];
        }
        hidden_out[j] = sum + hidden_in[j];
    }
#endif
}
/*
 * mlp_separate_fp32: RMSNorm + SwiGLU MLP + residual, scalar fp32
 * reference path. Unlike the fused variants, all scratch buffers are
 * supplied by the caller, so there are no internal size limits:
 *   normed_buf          : embed_dim floats
 *   gate_buf, up_buf    : intermediate_dim floats each
 * gate_buf is overwritten in place with silu(gate) * up before the down
 * projection.
 *
 * NOTE(review): reconstructed from a mangled source dump (missing sum
 * declarations / stores restored) — verify against the original.
 */
void mlp_separate_fp32(const float *hidden_in, const float *rms_weight,
                       float eps, const float *w_gate, const float *w_up,
                       const float *w_down, float *normed_buf,
                       float *gate_buf, float *up_buf, int embed_dim,
                       int intermediate_dim, float *hidden_out) {
    /* RMSNorm */
    float rms_scale = compute_rms_scale_internal(hidden_in, embed_dim, eps);
    for (int i = 0; i < embed_dim; i++) {
        normed_buf[i] = hidden_in[i] * rms_weight[i] * rms_scale;
    }

    /* Gate projection. */
    for (int j = 0; j < intermediate_dim; j++) {
        float sum = 0.0f;
        const float *wg_row = w_gate + j * embed_dim;
        for (int k = 0; k < embed_dim; k++) {
            sum += wg_row[k] * normed_buf[k];
        }
        gate_buf[j] = sum;
    }

    /* Up projection. */
    for (int j = 0; j < intermediate_dim; j++) {
        float sum = 0.0f;
        const float *wu_row = w_up + j * embed_dim;
        for (int k = 0; k < embed_dim; k++) {
            sum += wu_row[k] * normed_buf[k];
        }
        up_buf[j] = sum;
    }

    /* SwiGLU: gate_buf[j] = silu(gate_buf[j]) * up_buf[j]. */
    for (int j = 0; j < intermediate_dim; j++) {
        gate_buf[j] = silu_scalar(gate_buf[j]) * up_buf[j];
    }

    /* Down projection + residual. */
    for (int j = 0; j < embed_dim; j++) {
        float sum = 0.0f;
        const float *wd_row = w_down + j * intermediate_dim;
        for (int k = 0; k < intermediate_dim; k++) {
            sum += wd_row[k] * gate_buf[k];
        }
        hidden_out[j] = sum + hidden_in[j];
    }
}
/*
 * attention_mlp_fused_q4k: fused transformer-layer step with Q4_K
 * quantized weight matrices (wo, w_gate, w_up, w_down are opaque quantized
 * blobs consumed by gemv_q4_k). Activations stay fp32.
 *
 * Silently returns when embed_dim > 4096, intermediate_dim > 16384, or
 * seq_len > 8192 (fixed stack scratch).
 *
 * NOTE(review): reconstructed from a mangled source dump; buffer
 * declarations and the softmax / rms_scale steps were inferred from the
 * parallel fp32 and qkv variants in this file — verify against the
 * original.
 */
void attention_mlp_fused_q4k(const float *q, const float *k_cache,
                             const float *v_cache, int seq_len,
                             int num_heads, int num_kv_heads, int head_dim,
                             float attn_scale, const void *wo,
                             const float *residual_1, const float *rms_weight,
                             float eps, const void *w_gate, const void *w_up,
                             const void *w_down, int embed_dim,
                             int intermediate_dim, float *hidden_out) {
    /* Quantized GEMV kernel, defined in the quantization module. */
    extern void gemv_q4_k(float *y, const void *W, const float *x, int M, int K);

    const int heads_per_kv = num_heads / num_kv_heads;
    const int q_dim = num_heads * head_dim;
    const int kv_dim = num_kv_heads * head_dim;

    float attn_out[4096];
    float hidden_after_attn[4096];
    if (embed_dim > 4096)
        return;

    /* ---- attention ---- */
    memset(attn_out, 0, q_dim * sizeof(float));
    for (int h = 0; h < num_heads; h++) {
        int kv_h = h / heads_per_kv;
        const float *q_head = q + h * head_dim;
        float *out_head = attn_out + h * head_dim;

        float scores[8192];
        if (seq_len > 8192)
            return;
        for (int t = 0; t < seq_len; t++) {
            const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_head[d] * k_t[d];
            }
            scores[t] = score * attn_scale;
        }

        softmax_inplace(scores, seq_len);

        for (int t = 0; t < seq_len; t++) {
            const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
            float w = scores[t];
            for (int d = 0; d < head_dim; d++) {
                out_head[d] += w * v_t[d];
            }
        }
    }

    /* Quantized output projection + first residual. */
    gemv_q4_k(hidden_after_attn, wo, attn_out, embed_dim, q_dim);
    for (int i = 0; i < embed_dim; i++) {
        hidden_after_attn[i] += residual_1[i];
    }

    /* ---- RMSNorm before the MLP ---- */
    float normed[4096];
    float rms_scale = compute_rms_scale_internal(hidden_after_attn, embed_dim, eps);
    for (int i = 0; i < embed_dim; i++) {
        normed[i] = hidden_after_attn[i] * rms_weight[i] * rms_scale;
    }

    /* ---- SwiGLU MLP with quantized projections ---- */
    float gate_out[16384];
    float up_out[16384];
    if (intermediate_dim > 16384)
        return;
    gemv_q4_k(gate_out, w_gate, normed, intermediate_dim, embed_dim);
    gemv_q4_k(up_out, w_up, normed, intermediate_dim, embed_dim);

    /* gate_out[i] = silu(gate_out[i]) * up_out[i] */
#if defined(__AVX2__)
    int si = 0;
    for (; si + 7 < intermediate_dim; si += 8) {
        __m256 vg = _mm256_loadu_ps(gate_out + si);
        __m256 vu = _mm256_loadu_ps(up_out + si);
        __m256 vswiglu = _mm256_mul_ps(silu_avx2(vg), vu);
        _mm256_storeu_ps(gate_out + si, vswiglu);
    }
    for (; si < intermediate_dim; si++) {
        gate_out[si] = silu_scalar(gate_out[si]) * up_out[si];
    }
#else
    for (int si = 0; si < intermediate_dim; si++) {
        gate_out[si] = silu_scalar(gate_out[si]) * up_out[si];
    }
#endif

    /* Quantized down projection + second residual. */
    float mlp_out[4096];
    gemv_q4_k(mlp_out, w_down, gate_out, embed_dim, intermediate_dim);
    for (int i = 0; i < embed_dim; i++) {
        hidden_out[i] = mlp_out[i] + hidden_after_attn[i];
    }
}
/*
 * layer_fused_attn_mlp_qkv_q4k: one full Q4_K transformer layer step plus
 * the QKV projections for the NEXT layer, all in one call:
 *   attention -> output proj (+residual) -> RMSNorm -> SwiGLU MLP
 *   (+residual) -> RMSNorm -> q/k/v projections into q_next/k_next/v_next.
 * Weight matrices are opaque Q4_K blobs consumed by gemv_q4_k; activations
 * are fp32.
 *
 * Silently returns when embed_dim > 4096, intermediate_dim > 16384, or
 * seq_len > 8192 (fixed stack scratch).
 *
 * NOTE(review): reconstructed from a mangled source dump; only missing
 * declarations (scores, up_out, loop-local score/w, sum_sq reset) were
 * restored — verify against the original.
 */
void layer_fused_attn_mlp_qkv_q4k(const float *q, const float *k_cache,
                                  const float *v_cache, int seq_len,
                                  float attn_scale, const void *wo,
                                  const float *rms_weight_mlp,
                                  const void *w_gate, const void *w_up,
                                  const void *w_down,
                                  const float *rms_weight_attn,
                                  const void *wq_next, const void *wk_next,
                                  const void *wv_next,
                                  const float *residual_in, int embed_dim,
                                  int intermediate_dim, int num_heads,
                                  int num_kv_heads, int head_dim, float eps,
                                  float *q_next, float *k_next,
                                  float *v_next, float *hidden_out) {
    /* Quantized GEMV kernel, defined in the quantization module. */
    extern void gemv_q4_k(float *y, const void *W, const float *x, int M, int K);

    const int heads_per_kv = num_heads / num_kv_heads;
    const int q_dim = num_heads * head_dim;
    const int kv_dim = num_kv_heads * head_dim;

    float attn_out[4096];
    float hidden_after_attn[4096];
    float normed_mlp[4096];
    float gate_out[16384];
    float up_out[16384];
    float normed_attn[4096];
    if (embed_dim > 4096 || intermediate_dim > 16384)
        return;

    /* ---- attention ---- */
    memset(attn_out, 0, q_dim * sizeof(float));
    for (int h = 0; h < num_heads; h++) {
        int kv_h = h / heads_per_kv;
        const float *q_head = q + h * head_dim;
        float *out_head = attn_out + h * head_dim;

        float scores[8192];
        if (seq_len > 8192)
            return;
        for (int t = 0; t < seq_len; t++) {
            const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_head[d] * k_t[d];
            }
            scores[t] = score * attn_scale;
        }

        /* Inline numerically-stable softmax over scores[0..seq_len). */
        float max_score = scores[0];
        for (int t = 1; t < seq_len; t++) {
            if (scores[t] > max_score) max_score = scores[t];
        }
        float sum_exp = 0.0f;
        for (int t = 0; t < seq_len; t++) {
            scores[t] = expf(scores[t] - max_score);
            sum_exp += scores[t];
        }
        float inv_sum = 1.0f / sum_exp;
        for (int t = 0; t < seq_len; t++) {
            scores[t] *= inv_sum;
        }

        for (int t = 0; t < seq_len; t++) {
            const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
            float w = scores[t];
            for (int d = 0; d < head_dim; d++) {
                out_head[d] += w * v_t[d];
            }
        }
    }

    /* Quantized output projection + first residual. */
    gemv_q4_k(hidden_after_attn, wo, attn_out, embed_dim, q_dim);
    for (int i = 0; i < embed_dim; i++) {
        hidden_after_attn[i] += residual_in[i];
    }

    /* ---- RMSNorm before the MLP ---- */
    float sum_sq = 0.0f;
    for (int i = 0; i < embed_dim; i++) {
        sum_sq += hidden_after_attn[i] * hidden_after_attn[i];
    }
    float rms_scale = 1.0f / sqrtf(sum_sq / embed_dim + eps);
    for (int i = 0; i < embed_dim; i++) {
        normed_mlp[i] = hidden_after_attn[i] * rms_weight_mlp[i] * rms_scale;
    }

    /* ---- SwiGLU MLP with quantized projections ---- */
    gemv_q4_k(gate_out, w_gate, normed_mlp, intermediate_dim, embed_dim);
    gemv_q4_k(up_out, w_up, normed_mlp, intermediate_dim, embed_dim);
    for (int i = 0; i < intermediate_dim; i++) {
        float g = gate_out[i];
        float silu_g = g / (1.0f + expf(-g));
        gate_out[i] = silu_g * up_out[i];
    }

    /* Quantized down projection + second residual. */
    gemv_q4_k(hidden_out, w_down, gate_out, embed_dim, intermediate_dim);
    for (int i = 0; i < embed_dim; i++) {
        hidden_out[i] += hidden_after_attn[i];
    }

    /* ---- RMSNorm + next-layer QKV projections ---- */
    sum_sq = 0.0f;
    for (int i = 0; i < embed_dim; i++) {
        sum_sq += hidden_out[i] * hidden_out[i];
    }
    rms_scale = 1.0f / sqrtf(sum_sq / embed_dim + eps);
    for (int i = 0; i < embed_dim; i++) {
        normed_attn[i] = hidden_out[i] * rms_weight_attn[i] * rms_scale;
    }

    gemv_q4_k(q_next, wq_next, normed_attn, q_dim, embed_dim);
    gemv_q4_k(k_next, wk_next, normed_attn, kv_dim, embed_dim);
    gemv_q4_k(v_next, wv_next, normed_attn, kv_dim, embed_dim);
}
/*
 * attention_mlp_separate_fp32: scalar fp32 reference for the full
 * attention + SwiGLU MLP layer step. All scratch buffers are supplied by
 * the caller (no internal size limits on embed/intermediate dims):
 *   attn_out_buf          : num_heads * head_dim floats
 *   hidden_after_attn_buf : embed_dim floats
 *   normed_buf            : embed_dim floats
 *   gate_buf, up_buf      : intermediate_dim floats each
 *   mlp_out_buf           : embed_dim floats
 * gate_buf is overwritten in place with silu(gate) * up.
 * Silently returns when seq_len > 8192 (fixed scores scratch).
 *
 * NOTE(review): reconstructed from a mangled source dump; missing
 * declarations and the softmax / rms_scale steps were restored from the
 * parallel fused implementations in this file — verify against the
 * original.
 */
void attention_mlp_separate_fp32(const float *q, const float *k_cache,
                                 const float *v_cache, int seq_len,
                                 int num_heads, int num_kv_heads,
                                 int head_dim, float attn_scale,
                                 const float *wo, const float *residual_1,
                                 const float *rms_weight, float eps,
                                 const float *w_gate, const float *w_up,
                                 const float *w_down, int embed_dim,
                                 int intermediate_dim, float *attn_out_buf,
                                 float *hidden_after_attn_buf,
                                 float *normed_buf, float *gate_buf,
                                 float *up_buf, float *mlp_out_buf,
                                 float *hidden_out) {
    const int heads_per_kv = num_heads / num_kv_heads;
    const int q_dim = num_heads * head_dim;
    const int kv_dim = num_kv_heads * head_dim;

    /* ---- attention ---- */
    memset(attn_out_buf, 0, q_dim * sizeof(float));
    float scores[8192];
    if (seq_len > 8192)
        return;
    for (int h = 0; h < num_heads; h++) {
        int kv_h = h / heads_per_kv;
        const float *q_head = q + h * head_dim;
        float *out_head = attn_out_buf + h * head_dim;

        for (int t = 0; t < seq_len; t++) {
            const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_head[d] * k_t[d];
            }
            scores[t] = score * attn_scale;
        }

        softmax_inplace(scores, seq_len);

        for (int t = 0; t < seq_len; t++) {
            const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
            float w = scores[t];
            for (int d = 0; d < head_dim; d++) {
                out_head[d] += w * v_t[d];
            }
        }
    }

    /* Output projection + first residual. */
    for (int i = 0; i < embed_dim; i++) {
        float sum = 0.0f;
        const float *wo_row = wo + i * q_dim;
        for (int j = 0; j < q_dim; j++) {
            sum += wo_row[j] * attn_out_buf[j];
        }
        hidden_after_attn_buf[i] = sum + residual_1[i];
    }

    /* ---- RMSNorm before the MLP ---- */
    float rms_scale = compute_rms_scale_internal(hidden_after_attn_buf, embed_dim, eps);
    for (int i = 0; i < embed_dim; i++) {
        normed_buf[i] = hidden_after_attn_buf[i] * rms_weight[i] * rms_scale;
    }

    /* Gate projection. */
    for (int i = 0; i < intermediate_dim; i++) {
        float sum = 0.0f;
        const float *wg_row = w_gate + i * embed_dim;
        for (int j = 0; j < embed_dim; j++) {
            sum += wg_row[j] * normed_buf[j];
        }
        gate_buf[i] = sum;
    }

    /* Up projection. */
    for (int i = 0; i < intermediate_dim; i++) {
        float sum = 0.0f;
        const float *wu_row = w_up + i * embed_dim;
        for (int j = 0; j < embed_dim; j++) {
            sum += wu_row[j] * normed_buf[j];
        }
        up_buf[i] = sum;
    }

    /* SwiGLU: gate_buf[i] = silu(gate_buf[i]) * up_buf[i]. */
    for (int i = 0; i < intermediate_dim; i++) {
        gate_buf[i] = silu_scalar(gate_buf[i]) * up_buf[i];
    }

    /* Down projection. */
    for (int i = 0; i < embed_dim; i++) {
        float sum = 0.0f;
        const float *wd_row = w_down + i * intermediate_dim;
        for (int j = 0; j < intermediate_dim; j++) {
            sum += wd_row[j] * gate_buf[j];
        }
        mlp_out_buf[i] = sum;
    }

    /* Second residual. */
    for (int i = 0; i < embed_dim; i++) {
        hidden_out[i] = mlp_out_buf[i] + hidden_after_attn_buf[i];
    }
}
void attention_mlp_fused_q4k(const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const void *wo, const float *residual_1, const float *rms_weight, float eps, const void *w_gate, const void *w_up, const void *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
void mlp_fused_fp32_v2(const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
static float silu_scalar(float x)
void mlp_separate_fp32(const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, float *normed_buf, float *gate_buf, float *up_buf, int embed_dim, int intermediate_dim, float *hidden_out)
static float compute_rms_scale_internal(const float *x, int n, float eps)
void mlp_fused_fp32_v3(const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
static void softmax_inplace(float *x, int n)
void layer_fused_attn_mlp_qkv_q4k(const float *q, const float *k_cache, const float *v_cache, int seq_len, float attn_scale, const void *wo, const float *rms_weight_mlp, const void *w_gate, const void *w_up, const void *w_down, const float *rms_weight_attn, const void *wq_next, const void *wk_next, const void *wv_next, const float *residual_in, int embed_dim, int intermediate_dim, int num_heads, int num_kv_heads, int head_dim, float eps, float *q_next, float *k_next, float *v_next, float *hidden_out)
void attention_mlp_separate_fp32(const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float *wo, const float *residual_1, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *attn_out_buf, float *hidden_after_attn_buf, float *normed_buf, float *gate_buf, float *up_buf, float *mlp_out_buf, float *hidden_out)
void attention_mlp_fused_fp32(const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float *wo, const float *residual_1, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV based on available SIMD.
Quantization block structures for weight-only quantization.