__m256 vscale = _mm256_set1_ps(scale);
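
// Q projection, fused with RMSNorm: q_out[j] = sum_i wq[j][i] * (x[i] * rms_weight[i] * scale)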
for (int j = 0; j < q_dim; j += 8) {
    __m256 acc0 = _mm256_setzero_ps();
    __m256 acc1 = _mm256_setzero_ps();
    __m256 acc2 = _mm256_setzero_ps();
    __m256 acc3 = _mm256_setzero_ps();
    __m256 acc4 = _mm256_setzero_ps();
    __m256 acc5 = _mm256_setzero_ps();
    __m256 acc6 = _mm256_setzero_ps();
    __m256 acc7 = _mm256_setzero_ps();
    int i = 0;
    // Bulk of the work: 8 input elements per iteration
    for (; i + 7 < embed_dim; i += 8) {
        __m256 vx = _mm256_loadu_ps(x + i);
        __m256 vrms = _mm256_loadu_ps(rms_weight + i);
        // Apply RMSNorm on the fly: normed = x * rms_weight * scale
        __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);
        // Accumulate partial dot products for 8 consecutive rows of wq,
        // guarded the same way as the K/V loops in case q_dim is not a multiple of 8
        if (j + 0 < q_dim) acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(wq + (j+0)*embed_dim + i), normed, acc0);
        if (j + 1 < q_dim) acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(wq + (j+1)*embed_dim + i), normed, acc1);
        if (j + 2 < q_dim) acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(wq + (j+2)*embed_dim + i), normed, acc2);
        if (j + 3 < q_dim) acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(wq + (j+3)*embed_dim + i), normed, acc3);
        if (j + 4 < q_dim) acc4 = _mm256_fmadd_ps(_mm256_loadu_ps(wq + (j+4)*embed_dim + i), normed, acc4);
        if (j + 5 < q_dim) acc5 = _mm256_fmadd_ps(_mm256_loadu_ps(wq + (j+5)*embed_dim + i), normed, acc5);
        if (j + 6 < q_dim) acc6 = _mm256_fmadd_ps(_mm256_loadu_ps(wq + (j+6)*embed_dim + i), normed, acc6);
        if (j + 7 < q_dim) acc7 = _mm256_fmadd_ps(_mm256_loadu_ps(wq + (j+7)*embed_dim + i), normed, acc7);
    }
    // Scalar remainder for embed_dim not divisible by 8. Each tail term is added
    // to lane 0 only, so hsum256_ps() below counts it exactly once.
    for (; i < embed_dim; i++) {
        float normed_scalar = x[i] * rms_weight[i] * scale;
        if (j + 0 < q_dim) acc0 = _mm256_add_ps(acc0, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wq[(j+0)*embed_dim + i] * normed_scalar));
        if (j + 1 < q_dim) acc1 = _mm256_add_ps(acc1, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wq[(j+1)*embed_dim + i] * normed_scalar));
        if (j + 2 < q_dim) acc2 = _mm256_add_ps(acc2, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wq[(j+2)*embed_dim + i] * normed_scalar));
        if (j + 3 < q_dim) acc3 = _mm256_add_ps(acc3, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wq[(j+3)*embed_dim + i] * normed_scalar));
        if (j + 4 < q_dim) acc4 = _mm256_add_ps(acc4, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wq[(j+4)*embed_dim + i] * normed_scalar));
        if (j + 5 < q_dim) acc5 = _mm256_add_ps(acc5, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wq[(j+5)*embed_dim + i] * normed_scalar));
        if (j + 6 < q_dim) acc6 = _mm256_add_ps(acc6, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wq[(j+6)*embed_dim + i] * normed_scalar));
        if (j + 7 < q_dim) acc7 = _mm256_add_ps(acc7, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wq[(j+7)*embed_dim + i] * normed_scalar));
    }
    // Horizontal-sum each accumulator into its output element
    if (j + 0 < q_dim) q_out[j+0] = hsum256_ps(acc0);
    if (j + 1 < q_dim) q_out[j+1] = hsum256_ps(acc1);
    if (j + 2 < q_dim) q_out[j+2] = hsum256_ps(acc2);
    if (j + 3 < q_dim) q_out[j+3] = hsum256_ps(acc3);
    if (j + 4 < q_dim) q_out[j+4] = hsum256_ps(acc4);
    if (j + 5 < q_dim) q_out[j+5] = hsum256_ps(acc5);
    if (j + 6 < q_dim) q_out[j+6] = hsum256_ps(acc6);
    if (j + 7 < q_dim) q_out[j+7] = hsum256_ps(acc7);
}
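
// K projection: identical pattern over wk rows, writing k_out[0..kv_dim)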
for (int j = 0; j < kv_dim; j += 8) {
    __m256 acc0 = _mm256_setzero_ps();
    __m256 acc1 = _mm256_setzero_ps();
    __m256 acc2 = _mm256_setzero_ps();
    __m256 acc3 = _mm256_setzero_ps();
    __m256 acc4 = _mm256_setzero_ps();
    __m256 acc5 = _mm256_setzero_ps();
    __m256 acc6 = _mm256_setzero_ps();
    __m256 acc7 = _mm256_setzero_ps();
    int i = 0;
    for (; i + 7 < embed_dim; i += 8) {
        __m256 vx = _mm256_loadu_ps(x + i);
        __m256 vrms = _mm256_loadu_ps(rms_weight + i);
        __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);
        if (j + 0 < kv_dim) acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+0)*embed_dim + i), normed, acc0);
        if (j + 1 < kv_dim) acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+1)*embed_dim + i), normed, acc1);
        if (j + 2 < kv_dim) acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+2)*embed_dim + i), normed, acc2);
        if (j + 3 < kv_dim) acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+3)*embed_dim + i), normed, acc3);
        if (j + 4 < kv_dim) acc4 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+4)*embed_dim + i), normed, acc4);
        if (j + 5 < kv_dim) acc5 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+5)*embed_dim + i), normed, acc5);
        if (j + 6 < kv_dim) acc6 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+6)*embed_dim + i), normed, acc6);
        if (j + 7 < kv_dim) acc7 = _mm256_fmadd_ps(_mm256_loadu_ps(wk + (j+7)*embed_dim + i), normed, acc7);
    }
    // Scalar remainder: tail terms added to lane 0 only (see the Q loop above)
    for (; i < embed_dim; i++) {
        float normed_scalar = x[i] * rms_weight[i] * scale;
        if (j + 0 < kv_dim) acc0 = _mm256_add_ps(acc0, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wk[(j+0)*embed_dim + i] * normed_scalar));
        if (j + 1 < kv_dim) acc1 = _mm256_add_ps(acc1, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wk[(j+1)*embed_dim + i] * normed_scalar));
        if (j + 2 < kv_dim) acc2 = _mm256_add_ps(acc2, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wk[(j+2)*embed_dim + i] * normed_scalar));
        if (j + 3 < kv_dim) acc3 = _mm256_add_ps(acc3, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wk[(j+3)*embed_dim + i] * normed_scalar));
        if (j + 4 < kv_dim) acc4 = _mm256_add_ps(acc4, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wk[(j+4)*embed_dim + i] * normed_scalar));
        if (j + 5 < kv_dim) acc5 = _mm256_add_ps(acc5, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wk[(j+5)*embed_dim + i] * normed_scalar));
        if (j + 6 < kv_dim) acc6 = _mm256_add_ps(acc6, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wk[(j+6)*embed_dim + i] * normed_scalar));
        if (j + 7 < kv_dim) acc7 = _mm256_add_ps(acc7, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wk[(j+7)*embed_dim + i] * normed_scalar));
    }
    if (j + 0 < kv_dim) k_out[j+0] = hsum256_ps(acc0);
    if (j + 1 < kv_dim) k_out[j+1] = hsum256_ps(acc1);
    if (j + 2 < kv_dim) k_out[j+2] = hsum256_ps(acc2);
    if (j + 3 < kv_dim) k_out[j+3] = hsum256_ps(acc3);
    if (j + 4 < kv_dim) k_out[j+4] = hsum256_ps(acc4);
    if (j + 5 < kv_dim) k_out[j+5] = hsum256_ps(acc5);
    if (j + 6 < kv_dim) k_out[j+6] = hsum256_ps(acc6);
    if (j + 7 < kv_dim) k_out[j+7] = hsum256_ps(acc7);
}
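
// V projection: identical pattern over wv rows, writing v_out[0..kv_dim)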
for (int j = 0; j < kv_dim; j += 8) {
    __m256 acc0 = _mm256_setzero_ps();
    __m256 acc1 = _mm256_setzero_ps();
    __m256 acc2 = _mm256_setzero_ps();
    __m256 acc3 = _mm256_setzero_ps();
    __m256 acc4 = _mm256_setzero_ps();
    __m256 acc5 = _mm256_setzero_ps();
    __m256 acc6 = _mm256_setzero_ps();
    __m256 acc7 = _mm256_setzero_ps();
    int i = 0;
    for (; i + 7 < embed_dim; i += 8) {
        __m256 vx = _mm256_loadu_ps(x + i);
        __m256 vrms = _mm256_loadu_ps(rms_weight + i);
        __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);
        if (j + 0 < kv_dim) acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+0)*embed_dim + i), normed, acc0);
        if (j + 1 < kv_dim) acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+1)*embed_dim + i), normed, acc1);
        if (j + 2 < kv_dim) acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+2)*embed_dim + i), normed, acc2);
        if (j + 3 < kv_dim) acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+3)*embed_dim + i), normed, acc3);
        if (j + 4 < kv_dim) acc4 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+4)*embed_dim + i), normed, acc4);
        if (j + 5 < kv_dim) acc5 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+5)*embed_dim + i), normed, acc5);
        if (j + 6 < kv_dim) acc6 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+6)*embed_dim + i), normed, acc6);
        if (j + 7 < kv_dim) acc7 = _mm256_fmadd_ps(_mm256_loadu_ps(wv + (j+7)*embed_dim + i), normed, acc7);
    }
    // Scalar remainder: tail terms added to lane 0 only (see the Q loop above)
    for (; i < embed_dim; i++) {
        float normed_scalar = x[i] * rms_weight[i] * scale;
        if (j + 0 < kv_dim) acc0 = _mm256_add_ps(acc0, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wv[(j+0)*embed_dim + i] * normed_scalar));
        if (j + 1 < kv_dim) acc1 = _mm256_add_ps(acc1, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wv[(j+1)*embed_dim + i] * normed_scalar));
        if (j + 2 < kv_dim) acc2 = _mm256_add_ps(acc2, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wv[(j+2)*embed_dim + i] * normed_scalar));
        if (j + 3 < kv_dim) acc3 = _mm256_add_ps(acc3, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wv[(j+3)*embed_dim + i] * normed_scalar));
        if (j + 4 < kv_dim) acc4 = _mm256_add_ps(acc4, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wv[(j+4)*embed_dim + i] * normed_scalar));
        if (j + 5 < kv_dim) acc5 = _mm256_add_ps(acc5, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wv[(j+5)*embed_dim + i] * normed_scalar));
        if (j + 6 < kv_dim) acc6 = _mm256_add_ps(acc6, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wv[(j+6)*embed_dim + i] * normed_scalar));
        if (j + 7 < kv_dim) acc7 = _mm256_add_ps(acc7, _mm256_set_ps(0, 0, 0, 0, 0, 0, 0, wv[(j+7)*embed_dim + i] * normed_scalar));
    }
    if (j + 0 < kv_dim) v_out[j+0] = hsum256_ps(acc0);
    if (j + 1 < kv_dim) v_out[j+1] = hsum256_ps(acc1);
    if (j + 2 < kv_dim) v_out[j+2] = hsum256_ps(acc2);
    if (j + 3 < kv_dim) v_out[j+3] = hsum256_ps(acc3);
    if (j + 4 < kv_dim) v_out[j+4] = hsum256_ps(acc4);
    if (j + 5 < kv_dim) v_out[j+5] = hsum256_ps(acc5);
    if (j + 6 < kv_dim) v_out[j+6] = hsum256_ps(acc6);
    if (j + 7 < kv_dim) v_out[j+7] = hsum256_ps(acc7);
}
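
// Scalar path: the same fused RMSNorm + Q/K/V projections without intrinsics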
for (int j = 0; j < q_dim; j++) {
    float sum = 0.0f;
    for (int i = 0; i < embed_dim; i++) {
        float normed = x[i] * rms_weight[i] * scale;
        sum += wq[j * embed_dim + i] * normed;
    }
    q_out[j] = sum;
}
for (int j = 0; j < kv_dim; j++) {
    float sum = 0.0f;
    for (int i = 0; i < embed_dim; i++) {
        float normed = x[i] * rms_weight[i] * scale;
        sum += wk[j * embed_dim + i] * normed;
    }
    k_out[j] = sum;
}
for (int j = 0; j < kv_dim; j++) {
    float sum = 0.0f;
    for (int i = 0; i < embed_dim; i++) {
        float normed = x[i] * rms_weight[i] * scale;
        sum += wv[j * embed_dim + i] * normed;