35 #if QWEN2_0_5B_DECODE_DTYPE_BYTES != 4
36 #error "qwen2_0.5b_decode: v6 codegen currently supports fp32 only. Use --dtype=fp32."
50 if (!a || !b || !out) {
/* Token-wise elementwise add: out[t][d] = a[t][d] + b[t][d], iterating all
   aligned_embed_dim lanes per token. Rows are strided by aligned_embed_dim
   (the padded embedding width), so any padding lanes are summed as well —
   harmless only if both inputs keep padding zeroed; TODO confirm. */
53 for (
int t = 0; t < tokens; ++t) {
54 const float *pa = a + (size_t)t * (
size_t)aligned_embed_dim;
55 const float *pb = b + (size_t)t * (
size_t)aligned_embed_dim;
56 float *pc = out + (size_t)t * (
size_t)aligned_embed_dim;
57 for (
int d = 0; d < aligned_embed_dim; ++d) {
58 pc[d] = pa[d] + pb[d];
71 uint64_t weight_bytes;
72 uint64_t activation_bytes;
78 uint32_t canary_count;
92 model->
base = mmap(NULL, total,
93 PROT_READ | PROT_WRITE,
94 MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB,
96 if (model->
base == MAP_FAILED) {
97 model->
base = mmap(NULL, total,
98 PROT_READ | PROT_WRITE,
99 MAP_PRIVATE | MAP_ANONYMOUS,
102 if (model->
base == MAP_FAILED) {
103 perror(
"mmap failed");
107 model->
base = aligned_alloc(64, total);
109 perror(
"aligned_alloc failed");
142 if (!model || !model->
base)
return;
158 for (
int j = 0; j < 4; j++) {
160 fprintf(stderr,
"CANARY CORRUPTION: %s at offset 0x%lX\n",
/* Round an element count up so its byte footprint is a multiple of
   align_bytes, then convert back to elements.
   NOTE(review): `elems * elem_bytes` and the rounding are done in `int` —
   overflows for large tensors; size_t would be safer. Also assumes
   align_bytes is itself a multiple of elem_bytes so the final division is
   exact — TODO confirm both at the call sites. */
177 int bytes = elems * elem_bytes;
178 int aligned = (bytes + align_bytes - 1) / align_bytes * align_bytes;
179 return aligned / elem_bytes;
/* RoPE base frequency. 1e6 matches Qwen2's configured rope_theta —
   NOTE(review): confirm against the model config if retargeting. */
189 const float theta = 1000000.0f;
/* Precompute cos/sin rotation tables for T positions x D frequency lanes.
   The exponent (2*i)/(D*2) reduces to i/D, i.e. freq = theta^(-i/D) — the
   standard rotary schedule when D is half the head dimension; TODO confirm
   what D denotes here (elided lines 190-193 likely define it). */
194 for (
int pos = 0; pos < T; pos++) {
195 for (
int i = 0; i < D; i++) {
196 float freq = 1.0f / powf(theta, (
float)(2 * i) / (
float)(D * 2));
197 float angle = (float)pos * freq;
198 cos_ptr[pos * D + i] = cosf(angle);
199 sin_ptr[pos * D + i] = sinf(angle);
214 int aligned_embed_dim,
215 int aligned_head_dim,
216 int aligned_intermediate_dim,
217 int aligned_context_window
244 const float *BQ = NULL;
245 const float *BK = NULL;
246 const float *BV = NULL;
247 const float *BO = NULL;
248 const float *B1 = NULL;
249 const float *B2 = NULL;
257 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
258 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
259 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
273 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
274 for (
int h = 0; h < H; ++h) {
275 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
276 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
277 float *q_h = q + (size_t)h * q_head_stride;
278 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
283 const uint8_t *WK_bytes = (
const uint8_t *)WK;
284 for (
int h = 0; h < H_kv; ++h) {
285 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
286 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
287 float *k_h = k + (size_t)h * kv_head_stride;
288 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
293 const uint8_t *WV_bytes = (
const uint8_t *)WV;
294 for (
int h = 0; h < H_kv; ++h) {
295 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
296 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
297 float *v_h = v + (size_t)h * kv_head_stride;
298 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
325 const int K = H * aligned_head_dim;
326 if (K != aligned_embed_dim) {
329 const float *proj_in = attn_out;
334 for (
int t = 0; t < num_tokens; ++t) {
335 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
336 for (
int h = 0; h < H; ++h) {
337 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
338 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
340 (
size_t)aligned_head_dim *
sizeof(
float));
343 proj_in = proj_scratch;
345 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
361 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
362 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
363 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
375 int aligned_embed_dim,
376 int aligned_head_dim,
377 int aligned_intermediate_dim,
378 int aligned_context_window
405 const float *BQ = NULL;
406 const float *BK = NULL;
407 const float *BV = NULL;
408 const float *BO = NULL;
409 const float *B1 = NULL;
410 const float *B2 = NULL;
418 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
419 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
420 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
434 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
435 for (
int h = 0; h < H; ++h) {
436 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
437 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
438 float *q_h = q + (size_t)h * q_head_stride;
439 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
444 const uint8_t *WK_bytes = (
const uint8_t *)WK;
445 for (
int h = 0; h < H_kv; ++h) {
446 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
447 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
448 float *k_h = k + (size_t)h * kv_head_stride;
449 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
454 const uint8_t *WV_bytes = (
const uint8_t *)WV;
455 for (
int h = 0; h < H_kv; ++h) {
456 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
457 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
458 float *v_h = v + (size_t)h * kv_head_stride;
459 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
486 const int K = H * aligned_head_dim;
487 if (K != aligned_embed_dim) {
490 const float *proj_in = attn_out;
495 for (
int t = 0; t < num_tokens; ++t) {
496 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
497 for (
int h = 0; h < H; ++h) {
498 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
499 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
501 (
size_t)aligned_head_dim *
sizeof(
float));
504 proj_in = proj_scratch;
506 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
522 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
523 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
524 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
536 int aligned_embed_dim,
537 int aligned_head_dim,
538 int aligned_intermediate_dim,
539 int aligned_context_window
566 const float *BQ = NULL;
567 const float *BK = NULL;
568 const float *BV = NULL;
569 const float *BO = NULL;
570 const float *B1 = NULL;
571 const float *B2 = NULL;
579 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
580 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
581 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
595 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
596 for (
int h = 0; h < H; ++h) {
597 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
598 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
599 float *q_h = q + (size_t)h * q_head_stride;
600 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
605 const uint8_t *WK_bytes = (
const uint8_t *)WK;
606 for (
int h = 0; h < H_kv; ++h) {
607 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
608 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
609 float *k_h = k + (size_t)h * kv_head_stride;
610 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
615 const uint8_t *WV_bytes = (
const uint8_t *)WV;
616 for (
int h = 0; h < H_kv; ++h) {
617 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
618 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
619 float *v_h = v + (size_t)h * kv_head_stride;
620 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
647 const int K = H * aligned_head_dim;
648 if (K != aligned_embed_dim) {
651 const float *proj_in = attn_out;
656 for (
int t = 0; t < num_tokens; ++t) {
657 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
658 for (
int h = 0; h < H; ++h) {
659 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
660 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
662 (
size_t)aligned_head_dim *
sizeof(
float));
665 proj_in = proj_scratch;
667 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
683 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
684 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
685 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
697 int aligned_embed_dim,
698 int aligned_head_dim,
699 int aligned_intermediate_dim,
700 int aligned_context_window
727 const float *BQ = NULL;
728 const float *BK = NULL;
729 const float *BV = NULL;
730 const float *BO = NULL;
731 const float *B1 = NULL;
732 const float *B2 = NULL;
740 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
741 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
742 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
756 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
757 for (
int h = 0; h < H; ++h) {
758 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
759 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
760 float *q_h = q + (size_t)h * q_head_stride;
761 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
766 const uint8_t *WK_bytes = (
const uint8_t *)WK;
767 for (
int h = 0; h < H_kv; ++h) {
768 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
769 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
770 float *k_h = k + (size_t)h * kv_head_stride;
771 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
776 const uint8_t *WV_bytes = (
const uint8_t *)WV;
777 for (
int h = 0; h < H_kv; ++h) {
778 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
779 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
780 float *v_h = v + (size_t)h * kv_head_stride;
781 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
808 const int K = H * aligned_head_dim;
809 if (K != aligned_embed_dim) {
812 const float *proj_in = attn_out;
817 for (
int t = 0; t < num_tokens; ++t) {
818 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
819 for (
int h = 0; h < H; ++h) {
820 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
821 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
823 (
size_t)aligned_head_dim *
sizeof(
float));
826 proj_in = proj_scratch;
828 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
844 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
845 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
846 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
858 int aligned_embed_dim,
859 int aligned_head_dim,
860 int aligned_intermediate_dim,
861 int aligned_context_window
888 const float *BQ = NULL;
889 const float *BK = NULL;
890 const float *BV = NULL;
891 const float *BO = NULL;
892 const float *B1 = NULL;
893 const float *B2 = NULL;
901 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
902 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
903 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
917 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
918 for (
int h = 0; h < H; ++h) {
919 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
920 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
921 float *q_h = q + (size_t)h * q_head_stride;
922 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
927 const uint8_t *WK_bytes = (
const uint8_t *)WK;
928 for (
int h = 0; h < H_kv; ++h) {
929 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
930 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
931 float *k_h = k + (size_t)h * kv_head_stride;
932 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
937 const uint8_t *WV_bytes = (
const uint8_t *)WV;
938 for (
int h = 0; h < H_kv; ++h) {
939 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
940 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
941 float *v_h = v + (size_t)h * kv_head_stride;
942 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
969 const int K = H * aligned_head_dim;
970 if (K != aligned_embed_dim) {
973 const float *proj_in = attn_out;
978 for (
int t = 0; t < num_tokens; ++t) {
979 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
980 for (
int h = 0; h < H; ++h) {
981 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
982 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
984 (
size_t)aligned_head_dim *
sizeof(
float));
987 proj_in = proj_scratch;
989 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1005 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1006 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1007 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1019 int aligned_embed_dim,
1020 int aligned_head_dim,
1021 int aligned_intermediate_dim,
1022 int aligned_context_window
1049 const float *BQ = NULL;
1050 const float *BK = NULL;
1051 const float *BV = NULL;
1052 const float *BO = NULL;
1053 const float *B1 = NULL;
1054 const float *B2 = NULL;
1062 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1063 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1064 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1078 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1079 for (
int h = 0; h < H; ++h) {
1080 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1081 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1082 float *q_h = q + (size_t)h * q_head_stride;
1083 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1088 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1089 for (
int h = 0; h < H_kv; ++h) {
1090 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1091 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1092 float *k_h = k + (size_t)h * kv_head_stride;
1093 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1098 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1099 for (
int h = 0; h < H_kv; ++h) {
1100 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1101 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1102 float *v_h = v + (size_t)h * kv_head_stride;
1103 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1130 const int K = H * aligned_head_dim;
1131 if (K != aligned_embed_dim) {
1134 const float *proj_in = attn_out;
1136 if (!proj_scratch) {
1139 for (
int t = 0; t < num_tokens; ++t) {
1140 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1141 for (
int h = 0; h < H; ++h) {
1142 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1143 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1145 (
size_t)aligned_head_dim *
sizeof(
float));
1148 proj_in = proj_scratch;
1150 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1166 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1167 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1168 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1180 int aligned_embed_dim,
1181 int aligned_head_dim,
1182 int aligned_intermediate_dim,
1183 int aligned_context_window
1210 const float *BQ = NULL;
1211 const float *BK = NULL;
1212 const float *BV = NULL;
1213 const float *BO = NULL;
1214 const float *B1 = NULL;
1215 const float *B2 = NULL;
1223 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1224 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1225 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1239 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1240 for (
int h = 0; h < H; ++h) {
1241 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1242 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1243 float *q_h = q + (size_t)h * q_head_stride;
1244 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1249 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1250 for (
int h = 0; h < H_kv; ++h) {
1251 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1252 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1253 float *k_h = k + (size_t)h * kv_head_stride;
1254 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1259 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1260 for (
int h = 0; h < H_kv; ++h) {
1261 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1262 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1263 float *v_h = v + (size_t)h * kv_head_stride;
1264 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1291 const int K = H * aligned_head_dim;
1292 if (K != aligned_embed_dim) {
1295 const float *proj_in = attn_out;
1297 if (!proj_scratch) {
1300 for (
int t = 0; t < num_tokens; ++t) {
1301 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1302 for (
int h = 0; h < H; ++h) {
1303 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1304 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1306 (
size_t)aligned_head_dim *
sizeof(
float));
1309 proj_in = proj_scratch;
1311 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1327 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1328 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1329 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1341 int aligned_embed_dim,
1342 int aligned_head_dim,
1343 int aligned_intermediate_dim,
1344 int aligned_context_window
1371 const float *BQ = NULL;
1372 const float *BK = NULL;
1373 const float *BV = NULL;
1374 const float *BO = NULL;
1375 const float *B1 = NULL;
1376 const float *B2 = NULL;
1384 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1385 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1386 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1400 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1401 for (
int h = 0; h < H; ++h) {
1402 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1403 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1404 float *q_h = q + (size_t)h * q_head_stride;
1405 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1410 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1411 for (
int h = 0; h < H_kv; ++h) {
1412 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1413 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1414 float *k_h = k + (size_t)h * kv_head_stride;
1415 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1420 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1421 for (
int h = 0; h < H_kv; ++h) {
1422 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1423 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1424 float *v_h = v + (size_t)h * kv_head_stride;
1425 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1452 const int K = H * aligned_head_dim;
1453 if (K != aligned_embed_dim) {
1456 const float *proj_in = attn_out;
1458 if (!proj_scratch) {
1461 for (
int t = 0; t < num_tokens; ++t) {
1462 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1463 for (
int h = 0; h < H; ++h) {
1464 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1465 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1467 (
size_t)aligned_head_dim *
sizeof(
float));
1470 proj_in = proj_scratch;
1472 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1488 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1489 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1490 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1502 int aligned_embed_dim,
1503 int aligned_head_dim,
1504 int aligned_intermediate_dim,
1505 int aligned_context_window
1532 const float *BQ = NULL;
1533 const float *BK = NULL;
1534 const float *BV = NULL;
1535 const float *BO = NULL;
1536 const float *B1 = NULL;
1537 const float *B2 = NULL;
1545 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1546 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1547 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1561 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1562 for (
int h = 0; h < H; ++h) {
1563 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1564 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1565 float *q_h = q + (size_t)h * q_head_stride;
1566 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1571 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1572 for (
int h = 0; h < H_kv; ++h) {
1573 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1574 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1575 float *k_h = k + (size_t)h * kv_head_stride;
1576 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1581 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1582 for (
int h = 0; h < H_kv; ++h) {
1583 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1584 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1585 float *v_h = v + (size_t)h * kv_head_stride;
1586 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1613 const int K = H * aligned_head_dim;
1614 if (K != aligned_embed_dim) {
1617 const float *proj_in = attn_out;
1619 if (!proj_scratch) {
1622 for (
int t = 0; t < num_tokens; ++t) {
1623 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1624 for (
int h = 0; h < H; ++h) {
1625 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1626 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1628 (
size_t)aligned_head_dim *
sizeof(
float));
1631 proj_in = proj_scratch;
1633 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1649 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1650 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1651 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1663 int aligned_embed_dim,
1664 int aligned_head_dim,
1665 int aligned_intermediate_dim,
1666 int aligned_context_window
1693 const float *BQ = NULL;
1694 const float *BK = NULL;
1695 const float *BV = NULL;
1696 const float *BO = NULL;
1697 const float *B1 = NULL;
1698 const float *B2 = NULL;
1706 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1707 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1708 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1722 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1723 for (
int h = 0; h < H; ++h) {
1724 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1725 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1726 float *q_h = q + (size_t)h * q_head_stride;
1727 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1732 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1733 for (
int h = 0; h < H_kv; ++h) {
1734 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1735 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1736 float *k_h = k + (size_t)h * kv_head_stride;
1737 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1742 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1743 for (
int h = 0; h < H_kv; ++h) {
1744 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1745 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1746 float *v_h = v + (size_t)h * kv_head_stride;
1747 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1774 const int K = H * aligned_head_dim;
1775 if (K != aligned_embed_dim) {
1778 const float *proj_in = attn_out;
1780 if (!proj_scratch) {
1783 for (
int t = 0; t < num_tokens; ++t) {
1784 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1785 for (
int h = 0; h < H; ++h) {
1786 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1787 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1789 (
size_t)aligned_head_dim *
sizeof(
float));
1792 proj_in = proj_scratch;
1794 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1810 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1811 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1812 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1824 int aligned_embed_dim,
1825 int aligned_head_dim,
1826 int aligned_intermediate_dim,
1827 int aligned_context_window
1854 const float *BQ = NULL;
1855 const float *BK = NULL;
1856 const float *BV = NULL;
1857 const float *BO = NULL;
1858 const float *B1 = NULL;
1859 const float *B2 = NULL;
1867 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1868 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1869 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1883 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1884 for (
int h = 0; h < H; ++h) {
1885 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1886 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1887 float *q_h = q + (size_t)h * q_head_stride;
1888 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1893 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1894 for (
int h = 0; h < H_kv; ++h) {
1895 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1896 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1897 float *k_h = k + (size_t)h * kv_head_stride;
1898 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1903 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1904 for (
int h = 0; h < H_kv; ++h) {
1905 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1906 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1907 float *v_h = v + (size_t)h * kv_head_stride;
1908 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1935 const int K = H * aligned_head_dim;
1936 if (K != aligned_embed_dim) {
1939 const float *proj_in = attn_out;
1941 if (!proj_scratch) {
1944 for (
int t = 0; t < num_tokens; ++t) {
1945 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1946 for (
int h = 0; h < H; ++h) {
1947 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1948 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1950 (
size_t)aligned_head_dim *
sizeof(
float));
1953 proj_in = proj_scratch;
1955 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1971 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1972 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1973 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1985 int aligned_embed_dim,
1986 int aligned_head_dim,
1987 int aligned_intermediate_dim,
1988 int aligned_context_window
2015 const float *BQ = NULL;
2016 const float *BK = NULL;
2017 const float *BV = NULL;
2018 const float *BO = NULL;
2019 const float *B1 = NULL;
2020 const float *B2 = NULL;
2028 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2029 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2030 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2044 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2045 for (
int h = 0; h < H; ++h) {
2046 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2047 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2048 float *q_h = q + (size_t)h * q_head_stride;
2049 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2054 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2055 for (
int h = 0; h < H_kv; ++h) {
2056 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2057 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2058 float *k_h = k + (size_t)h * kv_head_stride;
2059 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2064 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2065 for (
int h = 0; h < H_kv; ++h) {
2066 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2067 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2068 float *v_h = v + (size_t)h * kv_head_stride;
2069 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2096 const int K = H * aligned_head_dim;
2097 if (K != aligned_embed_dim) {
2100 const float *proj_in = attn_out;
2102 if (!proj_scratch) {
2105 for (
int t = 0; t < num_tokens; ++t) {
2106 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2107 for (
int h = 0; h < H; ++h) {
2108 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2109 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2111 (
size_t)aligned_head_dim *
sizeof(
float));
2114 proj_in = proj_scratch;
2116 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2132 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2133 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2134 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2146 int aligned_embed_dim,
2147 int aligned_head_dim,
2148 int aligned_intermediate_dim,
2149 int aligned_context_window
2176 const float *BQ = NULL;
2177 const float *BK = NULL;
2178 const float *BV = NULL;
2179 const float *BO = NULL;
2180 const float *B1 = NULL;
2181 const float *B2 = NULL;
2189 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2190 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2191 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2205 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2206 for (
int h = 0; h < H; ++h) {
2207 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2208 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2209 float *q_h = q + (size_t)h * q_head_stride;
2210 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2215 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2216 for (
int h = 0; h < H_kv; ++h) {
2217 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2218 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2219 float *k_h = k + (size_t)h * kv_head_stride;
2220 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2225 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2226 for (
int h = 0; h < H_kv; ++h) {
2227 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2228 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2229 float *v_h = v + (size_t)h * kv_head_stride;
2230 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2257 const int K = H * aligned_head_dim;
2258 if (K != aligned_embed_dim) {
2261 const float *proj_in = attn_out;
2263 if (!proj_scratch) {
2266 for (
int t = 0; t < num_tokens; ++t) {
2267 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2268 for (
int h = 0; h < H; ++h) {
2269 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2270 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2272 (
size_t)aligned_head_dim *
sizeof(
float));
2275 proj_in = proj_scratch;
2277 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2293 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2294 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2295 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2307 int aligned_embed_dim,
2308 int aligned_head_dim,
2309 int aligned_intermediate_dim,
2310 int aligned_context_window
2337 const float *BQ = NULL;
2338 const float *BK = NULL;
2339 const float *BV = NULL;
2340 const float *BO = NULL;
2341 const float *B1 = NULL;
2342 const float *B2 = NULL;
2350 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2351 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2352 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2366 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2367 for (
int h = 0; h < H; ++h) {
2368 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2369 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2370 float *q_h = q + (size_t)h * q_head_stride;
2371 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2376 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2377 for (
int h = 0; h < H_kv; ++h) {
2378 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2379 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2380 float *k_h = k + (size_t)h * kv_head_stride;
2381 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2386 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2387 for (
int h = 0; h < H_kv; ++h) {
2388 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2389 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2390 float *v_h = v + (size_t)h * kv_head_stride;
2391 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2418 const int K = H * aligned_head_dim;
2419 if (K != aligned_embed_dim) {
2422 const float *proj_in = attn_out;
2424 if (!proj_scratch) {
2427 for (
int t = 0; t < num_tokens; ++t) {
2428 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2429 for (
int h = 0; h < H; ++h) {
2430 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2431 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2433 (
size_t)aligned_head_dim *
sizeof(
float));
2436 proj_in = proj_scratch;
2438 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2454 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2455 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2456 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2468 int aligned_embed_dim,
2469 int aligned_head_dim,
2470 int aligned_intermediate_dim,
2471 int aligned_context_window
2498 const float *BQ = NULL;
2499 const float *BK = NULL;
2500 const float *BV = NULL;
2501 const float *BO = NULL;
2502 const float *B1 = NULL;
2503 const float *B2 = NULL;
2511 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2512 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2513 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2527 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2528 for (
int h = 0; h < H; ++h) {
2529 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2530 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2531 float *q_h = q + (size_t)h * q_head_stride;
2532 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2537 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2538 for (
int h = 0; h < H_kv; ++h) {
2539 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2540 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2541 float *k_h = k + (size_t)h * kv_head_stride;
2542 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2547 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2548 for (
int h = 0; h < H_kv; ++h) {
2549 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2550 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2551 float *v_h = v + (size_t)h * kv_head_stride;
2552 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2579 const int K = H * aligned_head_dim;
2580 if (K != aligned_embed_dim) {
2583 const float *proj_in = attn_out;
2585 if (!proj_scratch) {
2588 for (
int t = 0; t < num_tokens; ++t) {
2589 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2590 for (
int h = 0; h < H; ++h) {
2591 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2592 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2594 (
size_t)aligned_head_dim *
sizeof(
float));
2597 proj_in = proj_scratch;
2599 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2615 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2616 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2617 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2629 int aligned_embed_dim,
2630 int aligned_head_dim,
2631 int aligned_intermediate_dim,
2632 int aligned_context_window
2659 const float *BQ = NULL;
2660 const float *BK = NULL;
2661 const float *BV = NULL;
2662 const float *BO = NULL;
2663 const float *B1 = NULL;
2664 const float *B2 = NULL;
2672 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2673 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2674 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2688 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2689 for (
int h = 0; h < H; ++h) {
2690 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2691 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2692 float *q_h = q + (size_t)h * q_head_stride;
2693 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2698 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2699 for (
int h = 0; h < H_kv; ++h) {
2700 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2701 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2702 float *k_h = k + (size_t)h * kv_head_stride;
2703 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2708 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2709 for (
int h = 0; h < H_kv; ++h) {
2710 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2711 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2712 float *v_h = v + (size_t)h * kv_head_stride;
2713 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2740 const int K = H * aligned_head_dim;
2741 if (K != aligned_embed_dim) {
2744 const float *proj_in = attn_out;
2746 if (!proj_scratch) {
2749 for (
int t = 0; t < num_tokens; ++t) {
2750 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2751 for (
int h = 0; h < H; ++h) {
2752 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2753 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2755 (
size_t)aligned_head_dim *
sizeof(
float));
2758 proj_in = proj_scratch;
2760 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2776 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2777 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2778 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2790 int aligned_embed_dim,
2791 int aligned_head_dim,
2792 int aligned_intermediate_dim,
2793 int aligned_context_window
2820 const float *BQ = NULL;
2821 const float *BK = NULL;
2822 const float *BV = NULL;
2823 const float *BO = NULL;
2824 const float *B1 = NULL;
2825 const float *B2 = NULL;
2833 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2834 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2835 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2849 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2850 for (
int h = 0; h < H; ++h) {
2851 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2852 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2853 float *q_h = q + (size_t)h * q_head_stride;
2854 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2859 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2860 for (
int h = 0; h < H_kv; ++h) {
2861 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2862 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2863 float *k_h = k + (size_t)h * kv_head_stride;
2864 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2869 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2870 for (
int h = 0; h < H_kv; ++h) {
2871 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2872 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2873 float *v_h = v + (size_t)h * kv_head_stride;
2874 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2901 const int K = H * aligned_head_dim;
2902 if (K != aligned_embed_dim) {
2905 const float *proj_in = attn_out;
2907 if (!proj_scratch) {
2910 for (
int t = 0; t < num_tokens; ++t) {
2911 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2912 for (
int h = 0; h < H; ++h) {
2913 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2914 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2916 (
size_t)aligned_head_dim *
sizeof(
float));
2919 proj_in = proj_scratch;
2921 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2937 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2938 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2939 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2951 int aligned_embed_dim,
2952 int aligned_head_dim,
2953 int aligned_intermediate_dim,
2954 int aligned_context_window
2981 const float *BQ = NULL;
2982 const float *BK = NULL;
2983 const float *BV = NULL;
2984 const float *BO = NULL;
2985 const float *B1 = NULL;
2986 const float *B2 = NULL;
2994 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2995 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2996 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3010 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3011 for (
int h = 0; h < H; ++h) {
3012 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3013 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3014 float *q_h = q + (size_t)h * q_head_stride;
3015 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3020 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3021 for (
int h = 0; h < H_kv; ++h) {
3022 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3023 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3024 float *k_h = k + (size_t)h * kv_head_stride;
3025 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3030 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3031 for (
int h = 0; h < H_kv; ++h) {
3032 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3033 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3034 float *v_h = v + (size_t)h * kv_head_stride;
3035 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3062 const int K = H * aligned_head_dim;
3063 if (K != aligned_embed_dim) {
3066 const float *proj_in = attn_out;
3068 if (!proj_scratch) {
3071 for (
int t = 0; t < num_tokens; ++t) {
3072 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3073 for (
int h = 0; h < H; ++h) {
3074 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3075 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3077 (
size_t)aligned_head_dim *
sizeof(
float));
3080 proj_in = proj_scratch;
3082 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3098 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3099 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3100 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3112 int aligned_embed_dim,
3113 int aligned_head_dim,
3114 int aligned_intermediate_dim,
3115 int aligned_context_window
3142 const float *BQ = NULL;
3143 const float *BK = NULL;
3144 const float *BV = NULL;
3145 const float *BO = NULL;
3146 const float *B1 = NULL;
3147 const float *B2 = NULL;
3155 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3156 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3157 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3171 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3172 for (
int h = 0; h < H; ++h) {
3173 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3174 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3175 float *q_h = q + (size_t)h * q_head_stride;
3176 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3181 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3182 for (
int h = 0; h < H_kv; ++h) {
3183 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3184 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3185 float *k_h = k + (size_t)h * kv_head_stride;
3186 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3191 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3192 for (
int h = 0; h < H_kv; ++h) {
3193 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3194 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3195 float *v_h = v + (size_t)h * kv_head_stride;
3196 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3223 const int K = H * aligned_head_dim;
3224 if (K != aligned_embed_dim) {
3227 const float *proj_in = attn_out;
3229 if (!proj_scratch) {
3232 for (
int t = 0; t < num_tokens; ++t) {
3233 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3234 for (
int h = 0; h < H; ++h) {
3235 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3236 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3238 (
size_t)aligned_head_dim *
sizeof(
float));
3241 proj_in = proj_scratch;
3243 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3259 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3260 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3261 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3273 int aligned_embed_dim,
3274 int aligned_head_dim,
3275 int aligned_intermediate_dim,
3276 int aligned_context_window
3303 const float *BQ = NULL;
3304 const float *BK = NULL;
3305 const float *BV = NULL;
3306 const float *BO = NULL;
3307 const float *B1 = NULL;
3308 const float *B2 = NULL;
3316 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3317 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3318 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3332 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3333 for (
int h = 0; h < H; ++h) {
3334 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3335 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3336 float *q_h = q + (size_t)h * q_head_stride;
3337 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3342 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3343 for (
int h = 0; h < H_kv; ++h) {
3344 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3345 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3346 float *k_h = k + (size_t)h * kv_head_stride;
3347 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3352 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3353 for (
int h = 0; h < H_kv; ++h) {
3354 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3355 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3356 float *v_h = v + (size_t)h * kv_head_stride;
3357 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3384 const int K = H * aligned_head_dim;
3385 if (K != aligned_embed_dim) {
3388 const float *proj_in = attn_out;
3390 if (!proj_scratch) {
3393 for (
int t = 0; t < num_tokens; ++t) {
3394 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3395 for (
int h = 0; h < H; ++h) {
3396 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3397 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3399 (
size_t)aligned_head_dim *
sizeof(
float));
3402 proj_in = proj_scratch;
3404 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3420 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3421 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3422 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3434 int aligned_embed_dim,
3435 int aligned_head_dim,
3436 int aligned_intermediate_dim,
3437 int aligned_context_window
3464 const float *BQ = NULL;
3465 const float *BK = NULL;
3466 const float *BV = NULL;
3467 const float *BO = NULL;
3468 const float *B1 = NULL;
3469 const float *B2 = NULL;
3477 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3478 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3479 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3493 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3494 for (
int h = 0; h < H; ++h) {
3495 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3496 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3497 float *q_h = q + (size_t)h * q_head_stride;
3498 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3503 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3504 for (
int h = 0; h < H_kv; ++h) {
3505 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3506 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3507 float *k_h = k + (size_t)h * kv_head_stride;
3508 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3513 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3514 for (
int h = 0; h < H_kv; ++h) {
3515 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3516 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3517 float *v_h = v + (size_t)h * kv_head_stride;
3518 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3545 const int K = H * aligned_head_dim;
3546 if (K != aligned_embed_dim) {
3549 const float *proj_in = attn_out;
3551 if (!proj_scratch) {
3554 for (
int t = 0; t < num_tokens; ++t) {
3555 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3556 for (
int h = 0; h < H; ++h) {
3557 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3558 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3560 (
size_t)aligned_head_dim *
sizeof(
float));
3563 proj_in = proj_scratch;
3565 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3581 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3582 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3583 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3595 int aligned_embed_dim,
3596 int aligned_head_dim,
3597 int aligned_intermediate_dim,
3598 int aligned_context_window
3625 const float *BQ = NULL;
3626 const float *BK = NULL;
3627 const float *BV = NULL;
3628 const float *BO = NULL;
3629 const float *B1 = NULL;
3630 const float *B2 = NULL;
3638 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3639 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3640 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3654 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3655 for (
int h = 0; h < H; ++h) {
3656 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3657 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3658 float *q_h = q + (size_t)h * q_head_stride;
3659 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3664 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3665 for (
int h = 0; h < H_kv; ++h) {
3666 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3667 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3668 float *k_h = k + (size_t)h * kv_head_stride;
3669 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3674 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3675 for (
int h = 0; h < H_kv; ++h) {
3676 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3677 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3678 float *v_h = v + (size_t)h * kv_head_stride;
3679 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3706 const int K = H * aligned_head_dim;
3707 if (K != aligned_embed_dim) {
3710 const float *proj_in = attn_out;
3712 if (!proj_scratch) {
3715 for (
int t = 0; t < num_tokens; ++t) {
3716 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3717 for (
int h = 0; h < H; ++h) {
3718 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3719 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3721 (
size_t)aligned_head_dim *
sizeof(
float));
3724 proj_in = proj_scratch;
3726 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3742 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3743 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3744 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3756 int aligned_embed_dim,
3757 int aligned_head_dim,
3758 int aligned_intermediate_dim,
3759 int aligned_context_window
3786 const float *BQ = NULL;
3787 const float *BK = NULL;
3788 const float *BV = NULL;
3789 const float *BO = NULL;
3790 const float *B1 = NULL;
3791 const float *B2 = NULL;
3799 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3800 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3801 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3815 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3816 for (
int h = 0; h < H; ++h) {
3817 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3818 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3819 float *q_h = q + (size_t)h * q_head_stride;
3820 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3825 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3826 for (
int h = 0; h < H_kv; ++h) {
3827 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3828 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3829 float *k_h = k + (size_t)h * kv_head_stride;
3830 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3835 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3836 for (
int h = 0; h < H_kv; ++h) {
3837 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3838 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3839 float *v_h = v + (size_t)h * kv_head_stride;
3840 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3867 const int K = H * aligned_head_dim;
3868 if (K != aligned_embed_dim) {
3871 const float *proj_in = attn_out;
3873 if (!proj_scratch) {
3876 for (
int t = 0; t < num_tokens; ++t) {
3877 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3878 for (
int h = 0; h < H; ++h) {
3879 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3880 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3882 (
size_t)aligned_head_dim *
sizeof(
float));
3885 proj_in = proj_scratch;
3887 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3903 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3904 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3905 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3917 int aligned_embed_dim,
3918 int aligned_head_dim,
3919 int aligned_intermediate_dim,
3920 int aligned_context_window
3947 const float *BQ = NULL;
3948 const float *BK = NULL;
3949 const float *BV = NULL;
3950 const float *BO = NULL;
3951 const float *B1 = NULL;
3952 const float *B2 = NULL;
3960 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3961 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3962 const size_t kv_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3976 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3977 for (
int h = 0; h < H; ++h) {
3978 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3979 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3980 float *q_h = q + (size_t)h * q_head_stride;
3981 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3986 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3987 for (
int h = 0; h < H_kv; ++h) {
3988 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3989 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3990 float *k_h = k + (size_t)h * kv_head_stride;
3991 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3996 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3997 for (
int h = 0; h < H_kv; ++h) {
3998 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3999 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4000 float *v_h = v + (size_t)h * kv_head_stride;
4001 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4028 const int K = H * aligned_head_dim;
4029 if (K != aligned_embed_dim) {
4032 const float *proj_in = attn_out;
4034 if (!proj_scratch) {
4037 for (
int t = 0; t < num_tokens; ++t) {
4038 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
4039 for (
int h = 0; h < H; ++h) {
4040 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
4041 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
4043 (
size_t)aligned_head_dim *
sizeof(
float));
4046 proj_in = proj_scratch;
4048 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
4064 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
4065 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
4066 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
4081 if (!model || !tokens || num_tokens <= 0) {
4086 const int aligned_embed_dim = 1024;
4087 const int aligned_head_dim = 64;
4088 const int aligned_intermediate_dim = 4864;
4089 const int aligned_context_window = 131072;
4109 aligned_intermediate_dim,
4110 aligned_context_window);
4116 aligned_context_window,
4122 aligned_context_window,
4130 aligned_intermediate_dim,
4131 aligned_context_window);
4137 aligned_context_window,
4143 aligned_context_window,
4151 aligned_intermediate_dim,
4152 aligned_context_window);
4158 aligned_context_window,
4164 aligned_context_window,
4172 aligned_intermediate_dim,
4173 aligned_context_window);
4179 aligned_context_window,
4185 aligned_context_window,
4193 aligned_intermediate_dim,
4194 aligned_context_window);
4200 aligned_context_window,
4206 aligned_context_window,
4214 aligned_intermediate_dim,
4215 aligned_context_window);
4221 aligned_context_window,
4227 aligned_context_window,
4235 aligned_intermediate_dim,
4236 aligned_context_window);
4242 aligned_context_window,
4248 aligned_context_window,
4256 aligned_intermediate_dim,
4257 aligned_context_window);
4263 aligned_context_window,
4269 aligned_context_window,
4277 aligned_intermediate_dim,
4278 aligned_context_window);
4284 aligned_context_window,
4290 aligned_context_window,
4298 aligned_intermediate_dim,
4299 aligned_context_window);
4305 aligned_context_window,
4311 aligned_context_window,
4319 aligned_intermediate_dim,
4320 aligned_context_window);
4326 aligned_context_window,
4332 aligned_context_window,
4340 aligned_intermediate_dim,
4341 aligned_context_window);
4347 aligned_context_window,
4353 aligned_context_window,
4361 aligned_intermediate_dim,
4362 aligned_context_window);
4368 aligned_context_window,
4374 aligned_context_window,
4382 aligned_intermediate_dim,
4383 aligned_context_window);
4389 aligned_context_window,
4395 aligned_context_window,
4403 aligned_intermediate_dim,
4404 aligned_context_window);
4410 aligned_context_window,
4416 aligned_context_window,
4424 aligned_intermediate_dim,
4425 aligned_context_window);
4431 aligned_context_window,
4437 aligned_context_window,
4445 aligned_intermediate_dim,
4446 aligned_context_window);
4452 aligned_context_window,
4458 aligned_context_window,
4466 aligned_intermediate_dim,
4467 aligned_context_window);
4473 aligned_context_window,
4479 aligned_context_window,
4487 aligned_intermediate_dim,
4488 aligned_context_window);
4494 aligned_context_window,
4500 aligned_context_window,
4508 aligned_intermediate_dim,
4509 aligned_context_window);
4515 aligned_context_window,
4521 aligned_context_window,
4529 aligned_intermediate_dim,
4530 aligned_context_window);
4536 aligned_context_window,
4542 aligned_context_window,
4550 aligned_intermediate_dim,
4551 aligned_context_window);
4557 aligned_context_window,
4563 aligned_context_window,
4571 aligned_intermediate_dim,
4572 aligned_context_window);
4578 aligned_context_window,
4584 aligned_context_window,
4592 aligned_intermediate_dim,
4593 aligned_context_window);
4599 aligned_context_window,
4605 aligned_context_window,
4623 for (
int t = 0; t < num_tokens; ++t) {
4624 uint8_t q8_buf[q8_bytes];
4625 const float *row = final_out + (size_t)t * (
size_t)aligned_embed_dim;
4648 int aligned_embed_dim,
4649 int aligned_head_dim,
4650 int aligned_intermediate_dim,
4651 int aligned_context_window
4684 float q_token[H * aligned_head_dim];
4685 float k_token[H_kv * aligned_head_dim];
4686 float v_token[H_kv * aligned_head_dim];
4687 float attn_token[H * aligned_head_dim];
4690 float fc1_out[2 * aligned_intermediate_dim];
4691 float swiglu_out[aligned_intermediate_dim];
4705 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
4708 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
4711 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
4732 aligned_context_window,
4744 aligned_context_window,
4750 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
4767 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
4770 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4773 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
4785 int aligned_embed_dim,
4786 int aligned_head_dim,
4787 int aligned_intermediate_dim,
4788 int aligned_context_window
4821 float q_token[H * aligned_head_dim];
4822 float k_token[H_kv * aligned_head_dim];
4823 float v_token[H_kv * aligned_head_dim];
4824 float attn_token[H * aligned_head_dim];
4827 float fc1_out[2 * aligned_intermediate_dim];
4828 float swiglu_out[aligned_intermediate_dim];
4842 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
4845 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
4848 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
4869 aligned_context_window,
4881 aligned_context_window,
4887 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
4904 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
4907 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4910 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
4922 int aligned_embed_dim,
4923 int aligned_head_dim,
4924 int aligned_intermediate_dim,
4925 int aligned_context_window
4958 float q_token[H * aligned_head_dim];
4959 float k_token[H_kv * aligned_head_dim];
4960 float v_token[H_kv * aligned_head_dim];
4961 float attn_token[H * aligned_head_dim];
4964 float fc1_out[2 * aligned_intermediate_dim];
4965 float swiglu_out[aligned_intermediate_dim];
4979 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
4982 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
4985 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5006 aligned_context_window,
5018 aligned_context_window,
5024 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5041 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5044 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5047 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5059 int aligned_embed_dim,
5060 int aligned_head_dim,
5061 int aligned_intermediate_dim,
5062 int aligned_context_window
5095 float q_token[H * aligned_head_dim];
5096 float k_token[H_kv * aligned_head_dim];
5097 float v_token[H_kv * aligned_head_dim];
5098 float attn_token[H * aligned_head_dim];
5101 float fc1_out[2 * aligned_intermediate_dim];
5102 float swiglu_out[aligned_intermediate_dim];
5116 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5119 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5122 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5143 aligned_context_window,
5155 aligned_context_window,
5161 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5178 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5181 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5184 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5196 int aligned_embed_dim,
5197 int aligned_head_dim,
5198 int aligned_intermediate_dim,
5199 int aligned_context_window
5232 float q_token[H * aligned_head_dim];
5233 float k_token[H_kv * aligned_head_dim];
5234 float v_token[H_kv * aligned_head_dim];
5235 float attn_token[H * aligned_head_dim];
5238 float fc1_out[2 * aligned_intermediate_dim];
5239 float swiglu_out[aligned_intermediate_dim];
5253 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5256 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5259 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5280 aligned_context_window,
5292 aligned_context_window,
5298 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5315 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5318 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5321 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5333 int aligned_embed_dim,
5334 int aligned_head_dim,
5335 int aligned_intermediate_dim,
5336 int aligned_context_window
5369 float q_token[H * aligned_head_dim];
5370 float k_token[H_kv * aligned_head_dim];
5371 float v_token[H_kv * aligned_head_dim];
5372 float attn_token[H * aligned_head_dim];
5375 float fc1_out[2 * aligned_intermediate_dim];
5376 float swiglu_out[aligned_intermediate_dim];
5390 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5393 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5396 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5417 aligned_context_window,
5429 aligned_context_window,
5435 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5452 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5455 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5458 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5470 int aligned_embed_dim,
5471 int aligned_head_dim,
5472 int aligned_intermediate_dim,
5473 int aligned_context_window
5506 float q_token[H * aligned_head_dim];
5507 float k_token[H_kv * aligned_head_dim];
5508 float v_token[H_kv * aligned_head_dim];
5509 float attn_token[H * aligned_head_dim];
5512 float fc1_out[2 * aligned_intermediate_dim];
5513 float swiglu_out[aligned_intermediate_dim];
5527 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5530 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5533 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5554 aligned_context_window,
5566 aligned_context_window,
5572 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5589 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5592 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5595 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5607 int aligned_embed_dim,
5608 int aligned_head_dim,
5609 int aligned_intermediate_dim,
5610 int aligned_context_window
5643 float q_token[H * aligned_head_dim];
5644 float k_token[H_kv * aligned_head_dim];
5645 float v_token[H_kv * aligned_head_dim];
5646 float attn_token[H * aligned_head_dim];
5649 float fc1_out[2 * aligned_intermediate_dim];
5650 float swiglu_out[aligned_intermediate_dim];
5664 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5667 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5670 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5691 aligned_context_window,
5703 aligned_context_window,
5709 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5726 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5729 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5732 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5744 int aligned_embed_dim,
5745 int aligned_head_dim,
5746 int aligned_intermediate_dim,
5747 int aligned_context_window
5780 float q_token[H * aligned_head_dim];
5781 float k_token[H_kv * aligned_head_dim];
5782 float v_token[H_kv * aligned_head_dim];
5783 float attn_token[H * aligned_head_dim];
5786 float fc1_out[2 * aligned_intermediate_dim];
5787 float swiglu_out[aligned_intermediate_dim];
5801 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5804 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5807 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5828 aligned_context_window,
5840 aligned_context_window,
5846 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5863 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5866 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5869 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5881 int aligned_embed_dim,
5882 int aligned_head_dim,
5883 int aligned_intermediate_dim,
5884 int aligned_context_window
5917 float q_token[H * aligned_head_dim];
5918 float k_token[H_kv * aligned_head_dim];
5919 float v_token[H_kv * aligned_head_dim];
5920 float attn_token[H * aligned_head_dim];
5923 float fc1_out[2 * aligned_intermediate_dim];
5924 float swiglu_out[aligned_intermediate_dim];
5938 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5941 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
5944 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
5965 aligned_context_window,
5977 aligned_context_window,
5983 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6000 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6003 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6006 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6018 int aligned_embed_dim,
6019 int aligned_head_dim,
6020 int aligned_intermediate_dim,
6021 int aligned_context_window
6054 float q_token[H * aligned_head_dim];
6055 float k_token[H_kv * aligned_head_dim];
6056 float v_token[H_kv * aligned_head_dim];
6057 float attn_token[H * aligned_head_dim];
6060 float fc1_out[2 * aligned_intermediate_dim];
6061 float swiglu_out[aligned_intermediate_dim];
6075 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6078 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6081 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6102 aligned_context_window,
6114 aligned_context_window,
6120 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6137 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6140 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6143 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6155 int aligned_embed_dim,
6156 int aligned_head_dim,
6157 int aligned_intermediate_dim,
6158 int aligned_context_window
6191 float q_token[H * aligned_head_dim];
6192 float k_token[H_kv * aligned_head_dim];
6193 float v_token[H_kv * aligned_head_dim];
6194 float attn_token[H * aligned_head_dim];
6197 float fc1_out[2 * aligned_intermediate_dim];
6198 float swiglu_out[aligned_intermediate_dim];
6212 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6215 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6218 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6239 aligned_context_window,
6251 aligned_context_window,
6257 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6274 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6277 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6280 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6292 int aligned_embed_dim,
6293 int aligned_head_dim,
6294 int aligned_intermediate_dim,
6295 int aligned_context_window
6328 float q_token[H * aligned_head_dim];
6329 float k_token[H_kv * aligned_head_dim];
6330 float v_token[H_kv * aligned_head_dim];
6331 float attn_token[H * aligned_head_dim];
6334 float fc1_out[2 * aligned_intermediate_dim];
6335 float swiglu_out[aligned_intermediate_dim];
6349 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6352 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6355 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6376 aligned_context_window,
6388 aligned_context_window,
6394 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6411 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6414 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6417 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6429 int aligned_embed_dim,
6430 int aligned_head_dim,
6431 int aligned_intermediate_dim,
6432 int aligned_context_window
6465 float q_token[H * aligned_head_dim];
6466 float k_token[H_kv * aligned_head_dim];
6467 float v_token[H_kv * aligned_head_dim];
6468 float attn_token[H * aligned_head_dim];
6471 float fc1_out[2 * aligned_intermediate_dim];
6472 float swiglu_out[aligned_intermediate_dim];
6486 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6489 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6492 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6513 aligned_context_window,
6525 aligned_context_window,
6531 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6548 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6551 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6554 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6566 int aligned_embed_dim,
6567 int aligned_head_dim,
6568 int aligned_intermediate_dim,
6569 int aligned_context_window
6602 float q_token[H * aligned_head_dim];
6603 float k_token[H_kv * aligned_head_dim];
6604 float v_token[H_kv * aligned_head_dim];
6605 float attn_token[H * aligned_head_dim];
6608 float fc1_out[2 * aligned_intermediate_dim];
6609 float swiglu_out[aligned_intermediate_dim];
6623 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6626 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6629 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6650 aligned_context_window,
6662 aligned_context_window,
6668 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6685 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6688 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6691 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6703 int aligned_embed_dim,
6704 int aligned_head_dim,
6705 int aligned_intermediate_dim,
6706 int aligned_context_window
6739 float q_token[H * aligned_head_dim];
6740 float k_token[H_kv * aligned_head_dim];
6741 float v_token[H_kv * aligned_head_dim];
6742 float attn_token[H * aligned_head_dim];
6745 float fc1_out[2 * aligned_intermediate_dim];
6746 float swiglu_out[aligned_intermediate_dim];
6760 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6763 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6766 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6787 aligned_context_window,
6799 aligned_context_window,
6805 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6822 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6825 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6828 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6840 int aligned_embed_dim,
6841 int aligned_head_dim,
6842 int aligned_intermediate_dim,
6843 int aligned_context_window
6876 float q_token[H * aligned_head_dim];
6877 float k_token[H_kv * aligned_head_dim];
6878 float v_token[H_kv * aligned_head_dim];
6879 float attn_token[H * aligned_head_dim];
6882 float fc1_out[2 * aligned_intermediate_dim];
6883 float swiglu_out[aligned_intermediate_dim];
6897 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6900 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
6903 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
6924 aligned_context_window,
6936 aligned_context_window,
6942 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6959 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6962 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6965 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6977 int aligned_embed_dim,
6978 int aligned_head_dim,
6979 int aligned_intermediate_dim,
6980 int aligned_context_window
7013 float q_token[H * aligned_head_dim];
7014 float k_token[H_kv * aligned_head_dim];
7015 float v_token[H_kv * aligned_head_dim];
7016 float attn_token[H * aligned_head_dim];
7019 float fc1_out[2 * aligned_intermediate_dim];
7020 float swiglu_out[aligned_intermediate_dim];
7034 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7037 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7040 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7061 aligned_context_window,
7073 aligned_context_window,
7079 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7096 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7099 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7102 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7114 int aligned_embed_dim,
7115 int aligned_head_dim,
7116 int aligned_intermediate_dim,
7117 int aligned_context_window
7150 float q_token[H * aligned_head_dim];
7151 float k_token[H_kv * aligned_head_dim];
7152 float v_token[H_kv * aligned_head_dim];
7153 float attn_token[H * aligned_head_dim];
7156 float fc1_out[2 * aligned_intermediate_dim];
7157 float swiglu_out[aligned_intermediate_dim];
7171 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7174 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7177 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7198 aligned_context_window,
7210 aligned_context_window,
7216 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7233 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7236 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7239 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7251 int aligned_embed_dim,
7252 int aligned_head_dim,
7253 int aligned_intermediate_dim,
7254 int aligned_context_window
7287 float q_token[H * aligned_head_dim];
7288 float k_token[H_kv * aligned_head_dim];
7289 float v_token[H_kv * aligned_head_dim];
7290 float attn_token[H * aligned_head_dim];
7293 float fc1_out[2 * aligned_intermediate_dim];
7294 float swiglu_out[aligned_intermediate_dim];
7308 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7311 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7314 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7335 aligned_context_window,
7347 aligned_context_window,
7353 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7370 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7373 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7376 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7388 int aligned_embed_dim,
7389 int aligned_head_dim,
7390 int aligned_intermediate_dim,
7391 int aligned_context_window
7424 float q_token[H * aligned_head_dim];
7425 float k_token[H_kv * aligned_head_dim];
7426 float v_token[H_kv * aligned_head_dim];
7427 float attn_token[H * aligned_head_dim];
7430 float fc1_out[2 * aligned_intermediate_dim];
7431 float swiglu_out[aligned_intermediate_dim];
7445 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7448 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7451 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7472 aligned_context_window,
7484 aligned_context_window,
7490 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7507 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7510 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7513 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7525 int aligned_embed_dim,
7526 int aligned_head_dim,
7527 int aligned_intermediate_dim,
7528 int aligned_context_window
7561 float q_token[H * aligned_head_dim];
7562 float k_token[H_kv * aligned_head_dim];
7563 float v_token[H_kv * aligned_head_dim];
7564 float attn_token[H * aligned_head_dim];
7567 float fc1_out[2 * aligned_intermediate_dim];
7568 float swiglu_out[aligned_intermediate_dim];
7582 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7585 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7588 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7609 aligned_context_window,
7621 aligned_context_window,
7627 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7644 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7647 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7650 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7662 int aligned_embed_dim,
7663 int aligned_head_dim,
7664 int aligned_intermediate_dim,
7665 int aligned_context_window
7698 float q_token[H * aligned_head_dim];
7699 float k_token[H_kv * aligned_head_dim];
7700 float v_token[H_kv * aligned_head_dim];
7701 float attn_token[H * aligned_head_dim];
7704 float fc1_out[2 * aligned_intermediate_dim];
7705 float swiglu_out[aligned_intermediate_dim];
7719 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7722 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7725 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7746 aligned_context_window,
7758 aligned_context_window,
7764 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7781 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7784 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7787 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7799 int aligned_embed_dim,
7800 int aligned_head_dim,
7801 int aligned_intermediate_dim,
7802 int aligned_context_window
7835 float q_token[H * aligned_head_dim];
7836 float k_token[H_kv * aligned_head_dim];
7837 float v_token[H_kv * aligned_head_dim];
7838 float attn_token[H * aligned_head_dim];
7841 float fc1_out[2 * aligned_intermediate_dim];
7842 float swiglu_out[aligned_intermediate_dim];
7856 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7859 gemm_nt_q4_k(ln1_out, WK, NULL, k_token, 1, H_kv * head_dim, aligned_embed_dim);
7862 gemm_nt_q4_k(ln1_out, WV, NULL, v_token, 1, H_kv * head_dim, aligned_embed_dim);
7883 aligned_context_window,
7895 aligned_context_window,
7901 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
7918 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7921 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7924 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
7939 if (!model || !
token)
return;
7941 const int aligned_embed_dim = 1024;
7942 const int aligned_head_dim = 64;
7943 const int aligned_intermediate_dim = 4864;
7944 const int aligned_context_window = 131072;
7946 if (token_index < 0 || token_index >= aligned_context_window)
return;
8018 if (!model || !tokens || num_tokens <= 0)
return;
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void kv_cache_repack_head_major_inplace(float *buf, int num_heads, int tokens, int cache_capacity, int aligned_head_dim)
void embedding_forward_q4_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
void quantize_row_q8_k(const float *x, void *y, int k)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void attention_forward_causal_head_major_gqa_flash(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
#define QWEN2_0_5B_DECODE_TOTAL_BYTES
#define QWEN2_0_5B_DECODE_HEAD_DIM
#define QWEN2_0_5B_DECODE_PTR(model, offset)
#define QWEN2_0_5B_DECODE_ACTIVATION_BYTES
static const QWEN2_0_5B_DECODEFooterOffsets QWEN2_0_5B_DECODE_FOOTER
static const QWEN2_0_5B_DECODELayerOffsets QWEN2_0_5B_DECODE_LAYERS[24]
#define QWEN2_0_5B_DECODE_DTYPE_BYTES
#define QWEN2_0_5B_DECODE_EMBED_DIM
#define QWEN2_0_5B_DECODE_MAGIC
#define QWEN2_0_5B_DECODE_CANARY_COUNT
#define QWEN2_0_5B_DECODE_MAX_SEQ_LEN
#define QWEN2_0_5B_DECODE_NUM_LAYERS
#define QWEN2_0_5B_DECODE_WEIGHT_BYTES
#define QWEN2_0_5B_DECODE_CANARY_VALUE
static const QWEN2_0_5B_DECODEGlobalOffsets QWEN2_0_5B_DECODE_GLOBALS
#define QWEN2_0_5B_DECODE_NUM_KV_HEADS
#define QWEN2_0_5B_DECODE_NUM_HEADS
#define QWEN2_0_5B_DECODE_CANARY_SIZE
#define QWEN2_0_5B_DECODE_VOCAB_SIZE
static const QWEN2_0_5B_DECODEHeaderOffsets QWEN2_0_5B_DECODE_HEADER
static const QWEN2_0_5B_DECODECanary QWEN2_0_5B_DECODE_CANARIES[]
static void qwen2_0_5b_decode_layer_18_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_2_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_1_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_8_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_3_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_18_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_16_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_20_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void qwen2_0_5b_decode_precompute_rope(QWEN2_0_5B_DECODEModel *model)
static void qwen2_0_5b_decode_layer_0_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
int qwen2_0_5b_decode_verify_canaries(QWEN2_0_5B_DECODEModel *model)
static void qwen2_0_5b_decode_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
static void qwen2_0_5b_decode_layer_4_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_7_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static int qwen2_0_5b_decode_align_elems(int elems, int elem_bytes, int align_bytes)
static void qwen2_0_5b_decode_layer_6_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_8_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_5_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_13_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_19_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
int qwen2_0_5b_decode_model_allocate(QWEN2_0_5B_DECODEModel *model)
static void qwen2_0_5b_decode_layer_0_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_19_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_6_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_7_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void qwen2_0_5b_decode_decode(QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)
static void qwen2_0_5b_decode_layer_15_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_2_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_17_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_22_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_3_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_forward_prefill_impl(QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)
static void qwen2_0_5b_decode_layer_13_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_21_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_4_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_9_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_10_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_decode_token(QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)
static void qwen2_0_5b_decode_layer_21_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_22_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_11_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_12_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_23_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_16_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
struct __attribute__((packed))
static void qwen2_0_5b_decode_layer_1_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void qwen2_0_5b_decode_forward(QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)
static void qwen2_0_5b_decode_layer_5_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_14_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_15_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_11_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_14_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_10_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_12_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_9_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_23_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void qwen2_0_5b_decode_model_free(QWEN2_0_5B_DECODEModel *model)
static void qwen2_0_5b_decode_layer_20_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
_Static_assert(sizeof(MagicHeader)==64, "MagicHeader must be 64 bytes")
static void qwen2_0_5b_decode_layer_17_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
AUTO-GENERATED: qwen2_0.5b_decode Memory Layout.