#include <immintrin.h>
#include <math.h>

/* compute_rms_scale: root-mean-square of x[0..n) (the kernels below multiply
 * by its reciprocal, 'scale'). */
static float compute_rms_scale(const float *x, int n, float eps) {
    float sum_sq = 0.0f;
#if defined(__AVX2__) && defined(__FMA__)
    int i = 0;
    __m256 vsum = _mm256_setzero_ps();
    /* Sum of squares, 8 floats per iteration. */
    for (; i + 7 < n; i += 8) {
        __m256 vx = _mm256_loadu_ps(x + i);
        vsum = _mm256_fmadd_ps(vx, vx, vsum);
    }
    /* Horizontal reduction of the 8 partial sums. */
    __m128 vlow = _mm256_castps256_ps128(vsum);
    __m128 vhigh = _mm256_extractf128_ps(vsum, 1);
    vlow = _mm_add_ps(vlow, vhigh);
    vlow = _mm_hadd_ps(vlow, vlow);
    vlow = _mm_hadd_ps(vlow, vlow);
    sum_sq = _mm_cvtss_f32(vlow);
    /* Scalar tail for n not divisible by 8. */
    for (; i < n; i++) {
        sum_sq += x[i] * x[i];
    }
#else
    for (int i = 0; i < n; i++) {
        sum_sq += x[i] * x[i];
    }
#endif
    float rms = sqrtf(sum_sq / (float)n + eps);
    /* ... */
}
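/*
 * For reference, the kernels below consume this value through the standard
 * RMSNorm transform y[i] = x[i] * rms_weight[i] / rms(x). A minimal scalar
 * sketch of that transform (hypothetical helper, not part of this file; it
 * assumes 'scale' is the reciprocal of the RMS computed above):
 */
static void rmsnorm_apply_ref(float *out, const float *x, const float *g,
                              float scale, int n) {
    for (int i = 0; i < n; i++) {
        /* Same per-element expression the fused kernels below inline. */
        out[i] = x[i] * g[i] * scale;
    }
}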
/*
 * rmsnorm_qkv_fp32_fused: one dot product per output row; the normalized
 * input x[i] * rms_weight[i] is recomputed inside every dot product.
 */
void rmsnorm_qkv_fp32_fused(const float *x, const float *rms_weight,
                            const float *wq, const float *wk, const float *wv,
                            float *q_out, float *k_out, float *v_out,
                            int embed_dim, int q_dim, int kv_dim, float eps) {
    /* ... scale = reciprocal RMS of x, computed earlier in the function ... */

    /* Q projection: q_out[j] = scale * dot(row j of wq, x[i] * rms_weight[i]). */
    for (int j = 0; j < q_dim; j++) {
        float sum = 0.0f;
        const float *wq_row = wq + j * embed_dim;
#if defined(__AVX2__) && defined(__FMA__)
        int i = 0;
        __m256 vsum = _mm256_setzero_ps();
        for (; i + 7 < embed_dim; i += 8) {
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 vw = _mm256_loadu_ps(wq_row + i);
            __m256 vnormed = _mm256_mul_ps(vx, vrms);
            vsum = _mm256_fmadd_ps(vw, vnormed, vsum);
        }
        __m128 vlow = _mm256_castps256_ps128(vsum);
        __m128 vhigh = _mm256_extractf128_ps(vsum, 1);
        vlow = _mm_add_ps(vlow, vhigh);
        vlow = _mm_hadd_ps(vlow, vlow);
        vlow = _mm_hadd_ps(vlow, vlow);
        sum = _mm_cvtss_f32(vlow);
        for (; i < embed_dim; i++) {
            sum += wq_row[i] * x[i] * rms_weight[i];
        }
#else
        for (int i = 0; i < embed_dim; i++) {
            sum += wq_row[i] * x[i] * rms_weight[i];
        }
#endif
        q_out[j] = sum * scale;
    }
    /* K projection: same pattern over the kv_dim rows of wk. */
    for (int j = 0; j < kv_dim; j++) {
        float sum = 0.0f;
        const float *wk_row = wk + j * embed_dim;
#if defined(__AVX2__) && defined(__FMA__)
        int i = 0;
        __m256 vsum = _mm256_setzero_ps();
        for (; i + 7 < embed_dim; i += 8) {
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 vw = _mm256_loadu_ps(wk_row + i);
            __m256 vnormed = _mm256_mul_ps(vx, vrms);
            vsum = _mm256_fmadd_ps(vw, vnormed, vsum);
        }
        __m128 vlow = _mm256_castps256_ps128(vsum);
        __m128 vhigh = _mm256_extractf128_ps(vsum, 1);
        vlow = _mm_add_ps(vlow, vhigh);
        vlow = _mm_hadd_ps(vlow, vlow);
        vlow = _mm_hadd_ps(vlow, vlow);
        sum = _mm_cvtss_f32(vlow);
        for (; i < embed_dim; i++) {
            sum += wk_row[i] * x[i] * rms_weight[i];
        }
#else
        for (int i = 0; i < embed_dim; i++) {
            sum += wk_row[i] * x[i] * rms_weight[i];
        }
#endif
        k_out[j] = sum * scale;
    }
    /* V projection: same pattern over the kv_dim rows of wv. */
    for (int j = 0; j < kv_dim; j++) {
        float sum = 0.0f;
        const float *wv_row = wv + j * embed_dim;
#if defined(__AVX2__) && defined(__FMA__)
        int i = 0;
        __m256 vsum = _mm256_setzero_ps();
        for (; i + 7 < embed_dim; i += 8) {
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 vw = _mm256_loadu_ps(wv_row + i);
            __m256 vnormed = _mm256_mul_ps(vx, vrms);
            vsum = _mm256_fmadd_ps(vw, vnormed, vsum);
        }
        __m128 vlow = _mm256_castps256_ps128(vsum);
        __m128 vhigh = _mm256_extractf128_ps(vsum, 1);
        vlow = _mm_add_ps(vlow, vhigh);
        vlow = _mm_hadd_ps(vlow, vlow);
        vlow = _mm_hadd_ps(vlow, vlow);
        sum = _mm_cvtss_f32(vlow);
        for (; i < embed_dim; i++) {
            sum += wv_row[i] * x[i] * rms_weight[i];
        }
#else
        for (int i = 0; i < embed_dim; i++) {
            sum += wv_row[i] * x[i] * rms_weight[i];
        }
#endif
        v_out[j] = sum * scale;
    }
}
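/*
 * Illustrative call of the fused kernel above. The dimensions are assumed,
 * GQA-style values (q_dim a multiple of kv_dim) and 1e-5f is a typical eps;
 * buffer allocation and weight loading are left to the caller.
 */
static void example_qkv_projection(const float *x, const float *rms_weight,
                                   const float *wq, const float *wk, const float *wv,
                                   float *q, float *k, float *v) {
    const int embed_dim = 2048, q_dim = 2048, kv_dim = 512;
    rmsnorm_qkv_fp32_fused(x, rms_weight, wq, wk, wv,
                           q, k, v, embed_dim, q_dim, kv_dim, 1e-5f);
}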
/*
 * rmsnorm_qkv_q4k_fused: normalize once into a scratch buffer, then run three
 * quantized (Q4_K) matrix-vector products over it.
 */
void rmsnorm_qkv_q4k_fused(const float *x, const float *rms_weight,
                           const void *wq, const void *wk, const void *wv,
                           float *q_out, float *k_out, float *v_out,
                           int embed_dim, int q_dim, int kv_dim, float eps) {
    /* ... scale = reciprocal RMS of x; 'normed' points to scratch storage ... */
    if (embed_dim > 4096) {
        /* ... */
    }

    /* normed[i] = x[i] * rms_weight[i] * scale */
#if defined(__AVX2__)
    int i = 0;
    __m256 vscale = _mm256_set1_ps(scale);
    for (; i + 7 < embed_dim; i += 8) {
        __m256 vx = _mm256_loadu_ps(x + i);
        __m256 vrms = _mm256_loadu_ps(rms_weight + i);
        __m256 vn = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);
        _mm256_storeu_ps(normed + i, vn);
    }
    for (; i < embed_dim; i++) {
        normed[i] = x[i] * rms_weight[i] * scale;
    }
#else
    for (int i = 0; i < embed_dim; i++) {
        normed[i] = x[i] * rms_weight[i] * scale;
    }
#endif

    /* Quantized GEMV: y = W * normed, W stored as rows x embed_dim Q4_K blocks. */
    extern void gemv_q4_k(float *y, const void *W, const float *x, int M, int K);

    gemv_q4_k(q_out, wq, normed, q_dim, embed_dim);
    gemv_q4_k(k_out, wk, normed, kv_dim, embed_dim);
    gemv_q4_k(v_out, wv, normed, kv_dim, embed_dim);
    /* ... */
}
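/*
 * gemv_q4_k computes y = W * x for M rows of length K, with W in Q4_K blocks
 * (layout not shown here). For cross-checking that path, one option is to
 * dequantize W to fp32 and compare against a plain reference GEMV such as the
 * sketch below (hypothetical helper, not part of this file).
 */
static void gemv_ref_fp32(float *y, const float *W, const float *x, int M, int K) {
    for (int m = 0; m < M; m++) {
        float acc = 0.0f;
        for (int k = 0; k < K; k++) {
            acc += W[m * K + k] * x[k];   /* W row-major, M x K */
        }
        y[m] = acc;
    }
}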
/* Horizontal sum of all 8 lanes of a __m256. */
static inline float hsum256_ps(__m256 v) {
    __m128 vlow = _mm256_castps256_ps128(v);
    __m128 vhigh = _mm256_extractf128_ps(v, 1);
    vlow = _mm_add_ps(vlow, vhigh);        /* 8 lanes -> 4 partial sums */
    __m128 shuf = _mm_movehdup_ps(vlow);   /* duplicate odd-index lanes */
    vlow = _mm_add_ps(vlow, shuf);         /* 4 -> 2 partial sums */
    shuf = _mm_movehl_ps(shuf, vlow);
    vlow = _mm_add_ss(vlow, shuf);         /* 2 -> 1 in lane 0 */
    return _mm_cvtss_f32(vlow);
}
/*
 * rmsnorm_qkv_fp32_fused_v2: process output rows in blocks of 8, one
 * accumulator per row, so every loaded (x, rms_weight) vector is normalized
 * once and reused for eight dot products.
 */
void rmsnorm_qkv_fp32_fused_v2(const float *x, const float *rms_weight,
                               const float *wq, const float *wk, const float *wv,
                               float *q_out, float *k_out, float *v_out,
                               int embed_dim, int q_dim, int kv_dim, float eps) {
    /* ... scale = reciprocal RMS of x, computed earlier in the function ... */
#if defined(__AVX2__) && defined(__FMA__)
    __m256 vscale = _mm256_set1_ps(scale);

    /* Q projection, 8 rows of wq per iteration. */
    for (int j = 0; j < q_dim; j += 8) {
        __m256 acc0 = _mm256_setzero_ps();
        __m256 acc1 = _mm256_setzero_ps();
        __m256 acc2 = _mm256_setzero_ps();
        __m256 acc3 = _mm256_setzero_ps();
        __m256 acc4 = _mm256_setzero_ps();
        __m256 acc5 = _mm256_setzero_ps();
        __m256 acc6 = _mm256_setzero_ps();
        __m256 acc7 = _mm256_setzero_ps();

        int i = 0;
        for (; i + 7 < embed_dim; i += 8) {
            /* Normalize 8 input elements once, reuse for all 8 rows. */
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);

            if (j + 0 < q_dim) {
                __m256 w0 = _mm256_loadu_ps(wq + (j+0)*embed_dim + i);
                acc0 = _mm256_fmadd_ps(w0, normed, acc0);
            }
            if (j + 1 < q_dim) {
                __m256 w1 = _mm256_loadu_ps(wq + (j+1)*embed_dim + i);
                acc1 = _mm256_fmadd_ps(w1, normed, acc1);
            }
            if (j + 2 < q_dim) {
                __m256 w2 = _mm256_loadu_ps(wq + (j+2)*embed_dim + i);
                acc2 = _mm256_fmadd_ps(w2, normed, acc2);
            }
            if (j + 3 < q_dim) {
                __m256 w3 = _mm256_loadu_ps(wq + (j+3)*embed_dim + i);
                acc3 = _mm256_fmadd_ps(w3, normed, acc3);
            }
            if (j + 4 < q_dim) {
                __m256 w4 = _mm256_loadu_ps(wq + (j+4)*embed_dim + i);
                acc4 = _mm256_fmadd_ps(w4, normed, acc4);
            }
            if (j + 5 < q_dim) {
                __m256 w5 = _mm256_loadu_ps(wq + (j+5)*embed_dim + i);
                acc5 = _mm256_fmadd_ps(w5, normed, acc5);
            }
            if (j + 6 < q_dim) {
                __m256 w6 = _mm256_loadu_ps(wq + (j+6)*embed_dim + i);
                acc6 = _mm256_fmadd_ps(w6, normed, acc6);
            }
            if (j + 7 < q_dim) {
                __m256 w7 = _mm256_loadu_ps(wq + (j+7)*embed_dim + i);
                acc7 = _mm256_fmadd_ps(w7, normed, acc7);
            }
        }

        /* Reduce each accumulator to its dot product. */
        if (j + 0 < q_dim) q_out[j+0] = hsum256_ps(acc0);
        if (j + 1 < q_dim) q_out[j+1] = hsum256_ps(acc1);
        if (j + 2 < q_dim) q_out[j+2] = hsum256_ps(acc2);
        if (j + 3 < q_dim) q_out[j+3] = hsum256_ps(acc3);
        if (j + 4 < q_dim) q_out[j+4] = hsum256_ps(acc4);
        if (j + 5 < q_dim) q_out[j+5] = hsum256_ps(acc5);
        if (j + 6 < q_dim) q_out[j+6] = hsum256_ps(acc6);
        if (j + 7 < q_dim) q_out[j+7] = hsum256_ps(acc7);

        /* Scalar tail for embed_dim not a multiple of 8. */
        for (; i < embed_dim; i++) {
            float normed_scalar = x[i] * rms_weight[i] * scale;
            if (j + 0 < q_dim) q_out[j+0] += wq[(j+0)*embed_dim + i] * normed_scalar;
            if (j + 1 < q_dim) q_out[j+1] += wq[(j+1)*embed_dim + i] * normed_scalar;
            if (j + 2 < q_dim) q_out[j+2] += wq[(j+2)*embed_dim + i] * normed_scalar;
            if (j + 3 < q_dim) q_out[j+3] += wq[(j+3)*embed_dim + i] * normed_scalar;
            if (j + 4 < q_dim) q_out[j+4] += wq[(j+4)*embed_dim + i] * normed_scalar;
            if (j + 5 < q_dim) q_out[j+5] += wq[(j+5)*embed_dim + i] * normed_scalar;
            if (j + 6 < q_dim) q_out[j+6] += wq[(j+6)*embed_dim + i] * normed_scalar;
            if (j + 7 < q_dim) q_out[j+7] += wq[(j+7)*embed_dim + i] * normed_scalar;
        }
    }
    /* K projection, 8 rows of wk per iteration. */
    for (int j = 0; j < kv_dim; j += 8) {
        __m256 acc0 = _mm256_setzero_ps();
        __m256 acc1 = _mm256_setzero_ps();
        __m256 acc2 = _mm256_setzero_ps();
        __m256 acc3 = _mm256_setzero_ps();
        __m256 acc4 = _mm256_setzero_ps();
        __m256 acc5 = _mm256_setzero_ps();
        __m256 acc6 = _mm256_setzero_ps();
        __m256 acc7 = _mm256_setzero_ps();

        int i = 0;
        for (; i + 7 < embed_dim; i += 8) {
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);

            if (j + 0 < kv_dim) acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+0)*embed_dim + i), normed, acc0);
            if (j + 1 < kv_dim) acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+1)*embed_dim + i), normed, acc1);
            if (j + 2 < kv_dim) acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+2)*embed_dim + i), normed, acc2);
            if (j + 3 < kv_dim) acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+3)*embed_dim + i), normed, acc3);
            if (j + 4 < kv_dim) acc4 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+4)*embed_dim + i), normed, acc4);
            if (j + 5 < kv_dim) acc5 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+5)*embed_dim + i), normed, acc5);
            if (j + 6 < kv_dim) acc6 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+6)*embed_dim + i), normed, acc6);
            if (j + 7 < kv_dim) acc7 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+7)*embed_dim + i), normed, acc7);
        }

        if (j + 0 < kv_dim) k_out[j+0] = hsum256_ps(acc0);
        if (j + 1 < kv_dim) k_out[j+1] = hsum256_ps(acc1);
        if (j + 2 < kv_dim) k_out[j+2] = hsum256_ps(acc2);
        if (j + 3 < kv_dim) k_out[j+3] = hsum256_ps(acc3);
        if (j + 4 < kv_dim) k_out[j+4] = hsum256_ps(acc4);
        if (j + 5 < kv_dim) k_out[j+5] = hsum256_ps(acc5);
        if (j + 6 < kv_dim) k_out[j+6] = hsum256_ps(acc6);
        if (j + 7 < kv_dim) k_out[j+7] = hsum256_ps(acc7);

        /* Scalar tail for embed_dim not a multiple of 8. */
        for (; i < embed_dim; i++) {
            float normed_scalar = x[i] * rms_weight[i] * scale;
            if (j + 0 < kv_dim) k_out[j+0] += wk[(j+0)*embed_dim + i] * normed_scalar;
            if (j + 1 < kv_dim) k_out[j+1] += wk[(j+1)*embed_dim + i] * normed_scalar;
            if (j + 2 < kv_dim) k_out[j+2] += wk[(j+2)*embed_dim + i] * normed_scalar;
            if (j + 3 < kv_dim) k_out[j+3] += wk[(j+3)*embed_dim + i] * normed_scalar;
            if (j + 4 < kv_dim) k_out[j+4] += wk[(j+4)*embed_dim + i] * normed_scalar;
            if (j + 5 < kv_dim) k_out[j+5] += wk[(j+5)*embed_dim + i] * normed_scalar;
            if (j + 6 < kv_dim) k_out[j+6] += wk[(j+6)*embed_dim + i] * normed_scalar;
            if (j + 7 < kv_dim) k_out[j+7] += wk[(j+7)*embed_dim + i] * normed_scalar;
        }
    }
    /* V projection, 8 rows of wv per iteration. */
    for (int j = 0; j < kv_dim; j += 8) {
        __m256 acc0 = _mm256_setzero_ps();
        __m256 acc1 = _mm256_setzero_ps();
        __m256 acc2 = _mm256_setzero_ps();
        __m256 acc3 = _mm256_setzero_ps();
        __m256 acc4 = _mm256_setzero_ps();
        __m256 acc5 = _mm256_setzero_ps();
        __m256 acc6 = _mm256_setzero_ps();
        __m256 acc7 = _mm256_setzero_ps();

        int i = 0;
        for (; i + 7 < embed_dim; i += 8) {
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);

            if (j + 0 < kv_dim) acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+0)*embed_dim + i), normed, acc0);
            if (j + 1 < kv_dim) acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+1)*embed_dim + i), normed, acc1);
            if (j + 2 < kv_dim) acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+2)*embed_dim + i), normed, acc2);
            if (j + 3 < kv_dim) acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+3)*embed_dim + i), normed, acc3);
            if (j + 4 < kv_dim) acc4 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+4)*embed_dim + i), normed, acc4);
            if (j + 5 < kv_dim) acc5 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+5)*embed_dim + i), normed, acc5);
            if (j + 6 < kv_dim) acc6 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+6)*embed_dim + i), normed, acc6);
            if (j + 7 < kv_dim) acc7 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+7)*embed_dim + i), normed, acc7);
        }

        if (j + 0 < kv_dim) v_out[j+0] = hsum256_ps(acc0);
        if (j + 1 < kv_dim) v_out[j+1] = hsum256_ps(acc1);
        if (j + 2 < kv_dim) v_out[j+2] = hsum256_ps(acc2);
        if (j + 3 < kv_dim) v_out[j+3] = hsum256_ps(acc3);
        if (j + 4 < kv_dim) v_out[j+4] = hsum256_ps(acc4);
        if (j + 5 < kv_dim) v_out[j+5] = hsum256_ps(acc5);
        if (j + 6 < kv_dim) v_out[j+6] = hsum256_ps(acc6);
        if (j + 7 < kv_dim) v_out[j+7] = hsum256_ps(acc7);

        /* Scalar tail for embed_dim not a multiple of 8. */
        for (; i < embed_dim; i++) {
            float normed_scalar = x[i] * rms_weight[i] * scale;
            if (j + 0 < kv_dim) v_out[j+0] += wv[(j+0)*embed_dim + i] * normed_scalar;
            if (j + 1 < kv_dim) v_out[j+1] += wv[(j+1)*embed_dim + i] * normed_scalar;
            if (j + 2 < kv_dim) v_out[j+2] += wv[(j+2)*embed_dim + i] * normed_scalar;
            if (j + 3 < kv_dim) v_out[j+3] += wv[(j+3)*embed_dim + i] * normed_scalar;
            if (j + 4 < kv_dim) v_out[j+4] += wv[(j+4)*embed_dim + i] * normed_scalar;
            if (j + 5 < kv_dim) v_out[j+5] += wv[(j+5)*embed_dim + i] * normed_scalar;
            if (j + 6 < kv_dim) v_out[j+6] += wv[(j+6)*embed_dim + i] * normed_scalar;
            if (j + 7 < kv_dim) v_out[j+7] += wv[(j+7)*embed_dim + i] * normed_scalar;
        }
    }
#else
    /* Scalar fallback: one dot product per output row. */
    for (int j = 0; j < q_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            sum += wq[j * embed_dim + i] * normed;
        }
        q_out[j] = sum;
    }
    for (int j = 0; j < kv_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            sum += wk[j * embed_dim + i] * normed;
        }
        k_out[j] = sum;
    }
    for (int j = 0; j < kv_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            sum += wv[j * embed_dim + i] * normed;
        }
        v_out[j] = sum;
    }
#endif
}
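/*
 * Rough nominal-load estimate for the 8-row blocking above, under assumed
 * dimensions embed_dim = 4096, q_dim = 4096, kv_dim = 1024 (fp32, counting
 * issued loads and ignoring caches). The weight matrices always stream
 *     (4096 + 1024 + 1024) * 4096 * 4 B  ~= 96 MiB,
 * but x and rms_weight are re-read once per block of 8 rows instead of once
 * per row:
 *     v2: (6144 / 8) passes * 2 * 4096 * 4 B ~= 24 MiB
 *     v1:  6144      passes * 2 * 4096 * 4 B ~= 192 MiB.
 */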
/*
 * rmsnorm_qkv_fp32_fused_v3: for rows j < kv_dim, compute the Q, K and V dot
 * products together so each normalized input vector is built once and reused
 * three times; the remaining Q-only rows (q_dim may exceed kv_dim, as in
 * grouped-query attention) are handled in a second loop.
 */
void rmsnorm_qkv_fp32_fused_v3(const float *x, const float *rms_weight,
                               const float *wq, const float *wk, const float *wv,
                               float *q_out, float *k_out, float *v_out,
                               int embed_dim, int q_dim, int kv_dim, float eps) {
    /* ... scale = reciprocal RMS of x, computed earlier in the function ... */
#if defined(__AVX2__) && defined(__FMA__)
    __m256 vscale = _mm256_set1_ps(scale);

    /* Rows shared by Q, K and V. */
    for (int j = 0; j < kv_dim; j++) {
        __m256 q_acc = _mm256_setzero_ps();
        __m256 k_acc = _mm256_setzero_ps();
        __m256 v_acc = _mm256_setzero_ps();

        int i = 0;
        for (; i + 7 < embed_dim; i += 8) {
            /* Normalize once ... */
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);

            /* ... and feed the same vector into all three projections. */
            __m256 wq_row = _mm256_loadu_ps(wq + j * embed_dim + i);
            __m256 wk_row = _mm256_loadu_ps(wk + j * embed_dim + i);
            __m256 wv_row = _mm256_loadu_ps(wv + j * embed_dim + i);

            q_acc = _mm256_fmadd_ps(wq_row, normed, q_acc);
            k_acc = _mm256_fmadd_ps(wk_row, normed, k_acc);
            v_acc = _mm256_fmadd_ps(wv_row, normed, v_acc);
        }

        float q_sum = hsum256_ps(q_acc);
        float k_sum = hsum256_ps(k_acc);
        float v_sum = hsum256_ps(v_acc);

        for (; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            q_sum += wq[j * embed_dim + i] * normed;
            k_sum += wk[j * embed_dim + i] * normed;
            v_sum += wv[j * embed_dim + i] * normed;
        }

        q_out[j] = q_sum;
        k_out[j] = k_sum;
        v_out[j] = v_sum;
    }

    /* Remaining Q-only rows. */
    for (int j = kv_dim; j < q_dim; j++) {
        __m256 q_acc = _mm256_setzero_ps();

        int i = 0;
        for (; i + 7 < embed_dim; i += 8) {
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);

            __m256 wq_row = _mm256_loadu_ps(wq + j * embed_dim + i);
            q_acc = _mm256_fmadd_ps(wq_row, normed, q_acc);
        }

        float q_sum = hsum256_ps(q_acc);
        for (; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            q_sum += wq[j * embed_dim + i] * normed;
        }
        q_out[j] = q_sum;
    }
#else
    /* Scalar fallback, same row split. */
    for (int j = 0; j < kv_dim; j++) {
        float q_sum = 0.0f, k_sum = 0.0f, v_sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            q_sum += wq[j * embed_dim + i] * normed;
            k_sum += wk[j * embed_dim + i] * normed;
            v_sum += wv[j * embed_dim + i] * normed;
        }
        q_out[j] = q_sum;
        k_out[j] = k_sum;
        v_out[j] = v_sum;
    }
    for (int j = kv_dim; j < q_dim; j++) {
        float q_sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            q_sum += wq[j * embed_dim + i] * normed;
        }
        q_out[j] = q_sum;
    }
#endif
}
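/*
 * The second loop above only makes sense when q_dim >= kv_dim, as in
 * multi-head or grouped-query attention where Q has at least as many rows as
 * K/V. An illustrative caller-side wrapper (not part of this file) that makes
 * the precondition explicit:
 */
#include <assert.h>

static void rmsnorm_qkv_fused_v3_checked(const float *x, const float *rms_weight,
                                         const float *wq, const float *wk, const float *wv,
                                         float *q_out, float *k_out, float *v_out,
                                         int embed_dim, int q_dim, int kv_dim, float eps) {
    assert(q_dim >= kv_dim);
    rmsnorm_qkv_fp32_fused_v3(x, rms_weight, wq, wk, wv,
                              q_out, k_out, v_out, embed_dim, q_dim, kv_dim, eps);
}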
/*
 * rmsnorm_qkv_separate_fp32: normalize once into the caller-provided 'normed'
 * buffer, then run three plain (scalar) matrix-vector products over it.
 */
void rmsnorm_qkv_separate_fp32(const float *x, const float *rms_weight,
                               const float *wq, const float *wk, const float *wv,
                               float *normed, float *q_out, float *k_out, float *v_out,
                               int embed_dim, int q_dim, int kv_dim, float eps) {
    /* ... scale = reciprocal RMS of x, computed earlier in the function ... */
    for (int i = 0; i < embed_dim; i++) {
        normed[i] = x[i] * rms_weight[i] * scale;
    }

    for (int j = 0; j < q_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            sum += wq[j * embed_dim + i] * normed[i];
        }
        q_out[j] = sum;
    }

    for (int j = 0; j < kv_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            sum += wk[j * embed_dim + i] * normed[i];
        }
        k_out[j] = sum;
    }

    for (int j = 0; j < kv_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            sum += wv[j * embed_dim + i] * normed[i];
        }
        v_out[j] = sum;
    }
}
/*
 * Related symbols (declared or referenced above, definitions not excerpted):
 *   void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
 *       Auto-dispatch GEMV based on available SIMD.
 *   Quantization block structures for weight-only quantization.
 *
 * Functions excerpted above:
 *   static float compute_rms_scale(const float *x, int n, float eps)
 *   void rmsnorm_qkv_fp32_fused(const float *x, const float *rms_weight, const float *wq, const float *wk, const float *wv, float *q_out, float *k_out, float *v_out, int embed_dim, int q_dim, int kv_dim, float eps)
 *   void rmsnorm_qkv_q4k_fused(const float *x, const float *rms_weight, const void *wq, const void *wk, const void *wv, float *q_out, float *k_out, float *v_out, int embed_dim, int q_dim, int kv_dim, float eps)
 *   void rmsnorm_qkv_fp32_fused_v2(const float *x, const float *rms_weight, const float *wq, const float *wk, const float *wv, float *q_out, float *k_out, float *v_out, int embed_dim, int q_dim, int kv_dim, float eps)
 *   void rmsnorm_qkv_fp32_fused_v3(const float *x, const float *rms_weight, const float *wq, const float *wk, const float *wv, float *q_out, float *k_out, float *v_out, int embed_dim, int q_dim, int kv_dim, float eps)
 *   void rmsnorm_qkv_separate_fp32(const float *x, const float *rms_weight, const float *wq, const float *wk, const float *wv, float *normed, float *q_out, float *k_out, float *v_out, int embed_dim, int q_dim, int kv_dim, float eps)
 */
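/*
 * Possible self-check (illustrative sketch, not part of the original file):
 * run the plain separate path and the fused v2 path on the same random data
 * and compare the outputs. The dimensions, the tolerance and the frand helper
 * are all assumptions made for the example.
 */
#include <stdio.h>
#include <stdlib.h>

static float frand(void) { return (float)rand() / (float)RAND_MAX - 0.5f; }

static int qkv_selfcheck(void) {
    const int embed_dim = 512, q_dim = 512, kv_dim = 128;
    float *x = malloc(sizeof(float) * embed_dim);
    float *g = malloc(sizeof(float) * embed_dim);
    float *wq = malloc(sizeof(float) * q_dim * embed_dim);
    float *wk = malloc(sizeof(float) * kv_dim * embed_dim);
    float *wv = malloc(sizeof(float) * kv_dim * embed_dim);
    float *normed = malloc(sizeof(float) * embed_dim);
    float *q_ref = malloc(sizeof(float) * q_dim), *q_opt = malloc(sizeof(float) * q_dim);
    float *k_ref = malloc(sizeof(float) * kv_dim), *k_opt = malloc(sizeof(float) * kv_dim);
    float *v_ref = malloc(sizeof(float) * kv_dim), *v_opt = malloc(sizeof(float) * kv_dim);

    for (int i = 0; i < embed_dim; i++) { x[i] = frand(); g[i] = frand(); }
    for (int i = 0; i < q_dim * embed_dim; i++)  wq[i] = frand();
    for (int i = 0; i < kv_dim * embed_dim; i++) { wk[i] = frand(); wv[i] = frand(); }

    rmsnorm_qkv_separate_fp32(x, g, wq, wk, wv, normed,
                              q_ref, k_ref, v_ref, embed_dim, q_dim, kv_dim, 1e-5f);
    rmsnorm_qkv_fp32_fused_v2(x, g, wq, wk, wv,
                              q_opt, k_opt, v_opt, embed_dim, q_dim, kv_dim, 1e-5f);

    float max_err = 0.0f;
    for (int j = 0; j < q_dim; j++)  max_err = fmaxf(max_err, fabsf(q_ref[j] - q_opt[j]));
    for (int j = 0; j < kv_dim; j++) max_err = fmaxf(max_err, fabsf(k_ref[j] - k_opt[j]));
    for (int j = 0; j < kv_dim; j++) max_err = fmaxf(max_err, fabsf(v_ref[j] - v_opt[j]));
    printf("max |fused_v2 - separate| = %g\n", max_err);

    free(x); free(g); free(wq); free(wk); free(wv); free(normed);
    free(q_ref); free(k_ref); free(v_ref); free(q_opt); free(k_opt); free(v_opt);
    return max_err < 1e-3f ? 0 : 1;
}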