35 #if QWEN2_0_5B_DECODE_DTYPE_BYTES != 4
36 #error "qwen2_0.5b_decode: v6.5 codegen currently supports fp32 only. Use --dtype=fp32."
50 if (!a || !b || !out) {
53 for (
int t = 0; t < tokens; ++t) {
54 const float *pa = a + (size_t)t * (
size_t)aligned_embed_dim;
55 const float *pb = b + (size_t)t * (
size_t)aligned_embed_dim;
56 float *pc = out + (size_t)t * (
size_t)aligned_embed_dim;
57 for (
int d = 0; d < aligned_embed_dim; ++d) {
58 pc[d] = pa[d] + pb[d];
71 uint64_t weight_bytes;
72 uint64_t activation_bytes;
78 uint32_t canary_count;
92 model->
base = mmap(NULL, total,
93 PROT_READ | PROT_WRITE,
94 MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB,
96 if (model->
base == MAP_FAILED) {
97 model->
base = mmap(NULL, total,
98 PROT_READ | PROT_WRITE,
99 MAP_PRIVATE | MAP_ANONYMOUS,
102 if (model->
base == MAP_FAILED) {
103 perror(
"mmap failed");
107 model->
base = aligned_alloc(64, total);
109 perror(
"aligned_alloc failed");
142 if (!model || !model->
base)
return;
158 for (
int j = 0; j < 4; j++) {
160 fprintf(stderr,
"CANARY CORRUPTION: %s at offset 0x%lX\n",
177 int bytes = elems * elem_bytes;
178 int aligned = (bytes + align_bytes - 1) / align_bytes * align_bytes;
179 return aligned / elem_bytes;
189 const float theta = 1000000.0f;
194 for (
int pos = 0; pos < T; pos++) {
195 for (
int i = 0; i < D; i++) {
196 float freq = 1.0f / powf(theta, (
float)(2 * i) / (
float)(D * 2));
197 float angle = (float)pos * freq;
198 cos_ptr[pos * D + i] = cosf(angle);
199 sin_ptr[pos * D + i] = sinf(angle);
214 int aligned_embed_dim,
215 int aligned_head_dim,
216 int aligned_intermediate_dim,
217 int aligned_context_window
244 const float *BQ = NULL;
245 const float *BK = NULL;
246 const float *BV = NULL;
247 const float *BO = NULL;
248 const float *B1 = NULL;
249 const float *B2 = NULL;
257 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
258 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
259 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
273 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
274 for (
int h = 0; h < H; ++h) {
275 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
276 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
277 float *q_h = q + (size_t)h * q_head_stride;
278 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
283 const uint8_t *WK_bytes = (
const uint8_t *)WK;
284 for (
int h = 0; h < H_kv; ++h) {
285 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
286 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
287 float *k_h = k + (size_t)h * kv_head_stride;
288 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
293 const uint8_t *WV_bytes = (
const uint8_t *)WV;
294 for (
int h = 0; h < H_kv; ++h) {
295 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
296 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
297 float *v_h = v + (size_t)h * kv_head_stride;
298 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
313 aligned_context_window);
325 aligned_context_window);
328 const int K = H * aligned_head_dim;
329 if (K != aligned_embed_dim) {
332 const float *proj_in = attn_out;
337 for (
int t = 0; t < num_tokens; ++t) {
338 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
339 for (
int h = 0; h < H; ++h) {
340 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
341 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
343 (
size_t)aligned_head_dim *
sizeof(
float));
346 proj_in = proj_scratch;
348 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
364 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
365 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
366 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
378 int aligned_embed_dim,
379 int aligned_head_dim,
380 int aligned_intermediate_dim,
381 int aligned_context_window
408 const float *BQ = NULL;
409 const float *BK = NULL;
410 const float *BV = NULL;
411 const float *BO = NULL;
412 const float *B1 = NULL;
413 const float *B2 = NULL;
421 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
422 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
423 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
437 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
438 for (
int h = 0; h < H; ++h) {
439 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
440 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
441 float *q_h = q + (size_t)h * q_head_stride;
442 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
447 const uint8_t *WK_bytes = (
const uint8_t *)WK;
448 for (
int h = 0; h < H_kv; ++h) {
449 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
450 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
451 float *k_h = k + (size_t)h * kv_head_stride;
452 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
457 const uint8_t *WV_bytes = (
const uint8_t *)WV;
458 for (
int h = 0; h < H_kv; ++h) {
459 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
460 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
461 float *v_h = v + (size_t)h * kv_head_stride;
462 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
477 aligned_context_window);
489 aligned_context_window);
492 const int K = H * aligned_head_dim;
493 if (K != aligned_embed_dim) {
496 const float *proj_in = attn_out;
501 for (
int t = 0; t < num_tokens; ++t) {
502 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
503 for (
int h = 0; h < H; ++h) {
504 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
505 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
507 (
size_t)aligned_head_dim *
sizeof(
float));
510 proj_in = proj_scratch;
512 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
528 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
529 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
530 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
542 int aligned_embed_dim,
543 int aligned_head_dim,
544 int aligned_intermediate_dim,
545 int aligned_context_window
572 const float *BQ = NULL;
573 const float *BK = NULL;
574 const float *BV = NULL;
575 const float *BO = NULL;
576 const float *B1 = NULL;
577 const float *B2 = NULL;
585 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
586 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
587 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
601 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
602 for (
int h = 0; h < H; ++h) {
603 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
604 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
605 float *q_h = q + (size_t)h * q_head_stride;
606 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
611 const uint8_t *WK_bytes = (
const uint8_t *)WK;
612 for (
int h = 0; h < H_kv; ++h) {
613 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
614 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
615 float *k_h = k + (size_t)h * kv_head_stride;
616 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
621 const uint8_t *WV_bytes = (
const uint8_t *)WV;
622 for (
int h = 0; h < H_kv; ++h) {
623 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
624 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
625 float *v_h = v + (size_t)h * kv_head_stride;
626 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
641 aligned_context_window);
653 aligned_context_window);
656 const int K = H * aligned_head_dim;
657 if (K != aligned_embed_dim) {
660 const float *proj_in = attn_out;
665 for (
int t = 0; t < num_tokens; ++t) {
666 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
667 for (
int h = 0; h < H; ++h) {
668 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
669 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
671 (
size_t)aligned_head_dim *
sizeof(
float));
674 proj_in = proj_scratch;
676 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
692 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
693 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
694 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
706 int aligned_embed_dim,
707 int aligned_head_dim,
708 int aligned_intermediate_dim,
709 int aligned_context_window
736 const float *BQ = NULL;
737 const float *BK = NULL;
738 const float *BV = NULL;
739 const float *BO = NULL;
740 const float *B1 = NULL;
741 const float *B2 = NULL;
749 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
750 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
751 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
765 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
766 for (
int h = 0; h < H; ++h) {
767 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
768 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
769 float *q_h = q + (size_t)h * q_head_stride;
770 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
775 const uint8_t *WK_bytes = (
const uint8_t *)WK;
776 for (
int h = 0; h < H_kv; ++h) {
777 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
778 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
779 float *k_h = k + (size_t)h * kv_head_stride;
780 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
785 const uint8_t *WV_bytes = (
const uint8_t *)WV;
786 for (
int h = 0; h < H_kv; ++h) {
787 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
788 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
789 float *v_h = v + (size_t)h * kv_head_stride;
790 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
805 aligned_context_window);
817 aligned_context_window);
820 const int K = H * aligned_head_dim;
821 if (K != aligned_embed_dim) {
824 const float *proj_in = attn_out;
829 for (
int t = 0; t < num_tokens; ++t) {
830 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
831 for (
int h = 0; h < H; ++h) {
832 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
833 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
835 (
size_t)aligned_head_dim *
sizeof(
float));
838 proj_in = proj_scratch;
840 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
856 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
857 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
858 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
870 int aligned_embed_dim,
871 int aligned_head_dim,
872 int aligned_intermediate_dim,
873 int aligned_context_window
900 const float *BQ = NULL;
901 const float *BK = NULL;
902 const float *BV = NULL;
903 const float *BO = NULL;
904 const float *B1 = NULL;
905 const float *B2 = NULL;
913 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
914 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
915 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
929 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
930 for (
int h = 0; h < H; ++h) {
931 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
932 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
933 float *q_h = q + (size_t)h * q_head_stride;
934 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
939 const uint8_t *WK_bytes = (
const uint8_t *)WK;
940 for (
int h = 0; h < H_kv; ++h) {
941 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
942 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
943 float *k_h = k + (size_t)h * kv_head_stride;
944 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
949 const uint8_t *WV_bytes = (
const uint8_t *)WV;
950 for (
int h = 0; h < H_kv; ++h) {
951 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
952 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
953 float *v_h = v + (size_t)h * kv_head_stride;
954 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
969 aligned_context_window);
981 aligned_context_window);
984 const int K = H * aligned_head_dim;
985 if (K != aligned_embed_dim) {
988 const float *proj_in = attn_out;
993 for (
int t = 0; t < num_tokens; ++t) {
994 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
995 for (
int h = 0; h < H; ++h) {
996 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
997 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
999 (
size_t)aligned_head_dim *
sizeof(
float));
1002 proj_in = proj_scratch;
1004 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1020 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1021 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1022 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1034 int aligned_embed_dim,
1035 int aligned_head_dim,
1036 int aligned_intermediate_dim,
1037 int aligned_context_window
1064 const float *BQ = NULL;
1065 const float *BK = NULL;
1066 const float *BV = NULL;
1067 const float *BO = NULL;
1068 const float *B1 = NULL;
1069 const float *B2 = NULL;
1077 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1078 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1079 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
1093 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1094 for (
int h = 0; h < H; ++h) {
1095 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1096 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1097 float *q_h = q + (size_t)h * q_head_stride;
1098 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1103 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1104 for (
int h = 0; h < H_kv; ++h) {
1105 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1106 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1107 float *k_h = k + (size_t)h * kv_head_stride;
1108 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1113 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1114 for (
int h = 0; h < H_kv; ++h) {
1115 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1116 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1117 float *v_h = v + (size_t)h * kv_head_stride;
1118 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1133 aligned_context_window);
1145 aligned_context_window);
1148 const int K = H * aligned_head_dim;
1149 if (K != aligned_embed_dim) {
1152 const float *proj_in = attn_out;
1154 if (!proj_scratch) {
1157 for (
int t = 0; t < num_tokens; ++t) {
1158 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1159 for (
int h = 0; h < H; ++h) {
1160 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1161 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1163 (
size_t)aligned_head_dim *
sizeof(
float));
1166 proj_in = proj_scratch;
1168 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1184 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1185 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1186 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1198 int aligned_embed_dim,
1199 int aligned_head_dim,
1200 int aligned_intermediate_dim,
1201 int aligned_context_window
1228 const float *BQ = NULL;
1229 const float *BK = NULL;
1230 const float *BV = NULL;
1231 const float *BO = NULL;
1232 const float *B1 = NULL;
1233 const float *B2 = NULL;
1241 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1242 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1243 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
1257 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1258 for (
int h = 0; h < H; ++h) {
1259 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1260 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1261 float *q_h = q + (size_t)h * q_head_stride;
1262 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1267 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1268 for (
int h = 0; h < H_kv; ++h) {
1269 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1270 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1271 float *k_h = k + (size_t)h * kv_head_stride;
1272 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1277 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1278 for (
int h = 0; h < H_kv; ++h) {
1279 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1280 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1281 float *v_h = v + (size_t)h * kv_head_stride;
1282 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1297 aligned_context_window);
1309 aligned_context_window);
1312 const int K = H * aligned_head_dim;
1313 if (K != aligned_embed_dim) {
1316 const float *proj_in = attn_out;
1318 if (!proj_scratch) {
1321 for (
int t = 0; t < num_tokens; ++t) {
1322 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1323 for (
int h = 0; h < H; ++h) {
1324 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1325 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1327 (
size_t)aligned_head_dim *
sizeof(
float));
1330 proj_in = proj_scratch;
1332 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1348 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1349 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1350 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1362 int aligned_embed_dim,
1363 int aligned_head_dim,
1364 int aligned_intermediate_dim,
1365 int aligned_context_window
1392 const float *BQ = NULL;
1393 const float *BK = NULL;
1394 const float *BV = NULL;
1395 const float *BO = NULL;
1396 const float *B1 = NULL;
1397 const float *B2 = NULL;
1405 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1406 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1407 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
1421 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1422 for (
int h = 0; h < H; ++h) {
1423 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1424 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1425 float *q_h = q + (size_t)h * q_head_stride;
1426 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1431 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1432 for (
int h = 0; h < H_kv; ++h) {
1433 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1434 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1435 float *k_h = k + (size_t)h * kv_head_stride;
1436 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1441 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1442 for (
int h = 0; h < H_kv; ++h) {
1443 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1444 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1445 float *v_h = v + (size_t)h * kv_head_stride;
1446 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1461 aligned_context_window);
1473 aligned_context_window);
1476 const int K = H * aligned_head_dim;
1477 if (K != aligned_embed_dim) {
1480 const float *proj_in = attn_out;
1482 if (!proj_scratch) {
1485 for (
int t = 0; t < num_tokens; ++t) {
1486 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1487 for (
int h = 0; h < H; ++h) {
1488 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1489 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1491 (
size_t)aligned_head_dim *
sizeof(
float));
1494 proj_in = proj_scratch;
1496 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1512 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1513 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1514 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1526 int aligned_embed_dim,
1527 int aligned_head_dim,
1528 int aligned_intermediate_dim,
1529 int aligned_context_window
1556 const float *BQ = NULL;
1557 const float *BK = NULL;
1558 const float *BV = NULL;
1559 const float *BO = NULL;
1560 const float *B1 = NULL;
1561 const float *B2 = NULL;
1569 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1570 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1571 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
1585 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1586 for (
int h = 0; h < H; ++h) {
1587 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1588 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1589 float *q_h = q + (size_t)h * q_head_stride;
1590 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1595 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1596 for (
int h = 0; h < H_kv; ++h) {
1597 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1598 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1599 float *k_h = k + (size_t)h * kv_head_stride;
1600 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1605 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1606 for (
int h = 0; h < H_kv; ++h) {
1607 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1608 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1609 float *v_h = v + (size_t)h * kv_head_stride;
1610 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1625 aligned_context_window);
1637 aligned_context_window);
1640 const int K = H * aligned_head_dim;
1641 if (K != aligned_embed_dim) {
1644 const float *proj_in = attn_out;
1646 if (!proj_scratch) {
1649 for (
int t = 0; t < num_tokens; ++t) {
1650 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1651 for (
int h = 0; h < H; ++h) {
1652 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1653 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1655 (
size_t)aligned_head_dim *
sizeof(
float));
1658 proj_in = proj_scratch;
1660 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1676 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1677 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1678 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1690 int aligned_embed_dim,
1691 int aligned_head_dim,
1692 int aligned_intermediate_dim,
1693 int aligned_context_window
1720 const float *BQ = NULL;
1721 const float *BK = NULL;
1722 const float *BV = NULL;
1723 const float *BO = NULL;
1724 const float *B1 = NULL;
1725 const float *B2 = NULL;
1733 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1734 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1735 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
1749 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1750 for (
int h = 0; h < H; ++h) {
1751 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1752 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1753 float *q_h = q + (size_t)h * q_head_stride;
1754 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1759 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1760 for (
int h = 0; h < H_kv; ++h) {
1761 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1762 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1763 float *k_h = k + (size_t)h * kv_head_stride;
1764 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1769 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1770 for (
int h = 0; h < H_kv; ++h) {
1771 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1772 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1773 float *v_h = v + (size_t)h * kv_head_stride;
1774 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1789 aligned_context_window);
1801 aligned_context_window);
1804 const int K = H * aligned_head_dim;
1805 if (K != aligned_embed_dim) {
1808 const float *proj_in = attn_out;
1810 if (!proj_scratch) {
1813 for (
int t = 0; t < num_tokens; ++t) {
1814 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1815 for (
int h = 0; h < H; ++h) {
1816 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1817 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1819 (
size_t)aligned_head_dim *
sizeof(
float));
1822 proj_in = proj_scratch;
1824 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1840 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1841 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1842 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1854 int aligned_embed_dim,
1855 int aligned_head_dim,
1856 int aligned_intermediate_dim,
1857 int aligned_context_window
1884 const float *BQ = NULL;
1885 const float *BK = NULL;
1886 const float *BV = NULL;
1887 const float *BO = NULL;
1888 const float *B1 = NULL;
1889 const float *B2 = NULL;
1897 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1898 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1899 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
1913 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1914 for (
int h = 0; h < H; ++h) {
1915 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1916 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1917 float *q_h = q + (size_t)h * q_head_stride;
1918 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1923 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1924 for (
int h = 0; h < H_kv; ++h) {
1925 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1926 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1927 float *k_h = k + (size_t)h * kv_head_stride;
1928 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1933 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1934 for (
int h = 0; h < H_kv; ++h) {
1935 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1936 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1937 float *v_h = v + (size_t)h * kv_head_stride;
1938 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1953 aligned_context_window);
1965 aligned_context_window);
1968 const int K = H * aligned_head_dim;
1969 if (K != aligned_embed_dim) {
1972 const float *proj_in = attn_out;
1974 if (!proj_scratch) {
1977 for (
int t = 0; t < num_tokens; ++t) {
1978 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1979 for (
int h = 0; h < H; ++h) {
1980 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1981 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1983 (
size_t)aligned_head_dim *
sizeof(
float));
1986 proj_in = proj_scratch;
1988 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2004 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2005 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2006 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2018 int aligned_embed_dim,
2019 int aligned_head_dim,
2020 int aligned_intermediate_dim,
2021 int aligned_context_window
2048 const float *BQ = NULL;
2049 const float *BK = NULL;
2050 const float *BV = NULL;
2051 const float *BO = NULL;
2052 const float *B1 = NULL;
2053 const float *B2 = NULL;
2061 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2062 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2063 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
2077 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2078 for (
int h = 0; h < H; ++h) {
2079 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2080 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2081 float *q_h = q + (size_t)h * q_head_stride;
2082 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2087 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2088 for (
int h = 0; h < H_kv; ++h) {
2089 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2090 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2091 float *k_h = k + (size_t)h * kv_head_stride;
2092 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2097 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2098 for (
int h = 0; h < H_kv; ++h) {
2099 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2100 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2101 float *v_h = v + (size_t)h * kv_head_stride;
2102 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2117 aligned_context_window);
2129 aligned_context_window);
2132 const int K = H * aligned_head_dim;
2133 if (K != aligned_embed_dim) {
2136 const float *proj_in = attn_out;
2138 if (!proj_scratch) {
2141 for (
int t = 0; t < num_tokens; ++t) {
2142 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2143 for (
int h = 0; h < H; ++h) {
2144 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2145 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2147 (
size_t)aligned_head_dim *
sizeof(
float));
2150 proj_in = proj_scratch;
2152 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2168 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2169 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2170 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2182 int aligned_embed_dim,
2183 int aligned_head_dim,
2184 int aligned_intermediate_dim,
2185 int aligned_context_window
2212 const float *BQ = NULL;
2213 const float *BK = NULL;
2214 const float *BV = NULL;
2215 const float *BO = NULL;
2216 const float *B1 = NULL;
2217 const float *B2 = NULL;
2225 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2226 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2227 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
2241 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2242 for (
int h = 0; h < H; ++h) {
2243 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2244 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2245 float *q_h = q + (size_t)h * q_head_stride;
2246 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2251 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2252 for (
int h = 0; h < H_kv; ++h) {
2253 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2254 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2255 float *k_h = k + (size_t)h * kv_head_stride;
2256 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2261 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2262 for (
int h = 0; h < H_kv; ++h) {
2263 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2264 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2265 float *v_h = v + (size_t)h * kv_head_stride;
2266 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2281 aligned_context_window);
2293 aligned_context_window);
2296 const int K = H * aligned_head_dim;
2297 if (K != aligned_embed_dim) {
2300 const float *proj_in = attn_out;
2302 if (!proj_scratch) {
2305 for (
int t = 0; t < num_tokens; ++t) {
2306 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2307 for (
int h = 0; h < H; ++h) {
2308 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2309 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2311 (
size_t)aligned_head_dim *
sizeof(
float));
2314 proj_in = proj_scratch;
2316 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2332 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2333 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2334 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2346 int aligned_embed_dim,
2347 int aligned_head_dim,
2348 int aligned_intermediate_dim,
2349 int aligned_context_window
2376 const float *BQ = NULL;
2377 const float *BK = NULL;
2378 const float *BV = NULL;
2379 const float *BO = NULL;
2380 const float *B1 = NULL;
2381 const float *B2 = NULL;
2389 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2390 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2391 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
2405 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2406 for (
int h = 0; h < H; ++h) {
2407 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2408 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2409 float *q_h = q + (size_t)h * q_head_stride;
2410 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2415 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2416 for (
int h = 0; h < H_kv; ++h) {
2417 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2418 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2419 float *k_h = k + (size_t)h * kv_head_stride;
2420 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2425 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2426 for (
int h = 0; h < H_kv; ++h) {
2427 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2428 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2429 float *v_h = v + (size_t)h * kv_head_stride;
2430 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2445 aligned_context_window);
2457 aligned_context_window);
2460 const int K = H * aligned_head_dim;
2461 if (K != aligned_embed_dim) {
2464 const float *proj_in = attn_out;
2466 if (!proj_scratch) {
2469 for (
int t = 0; t < num_tokens; ++t) {
2470 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2471 for (
int h = 0; h < H; ++h) {
2472 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2473 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2475 (
size_t)aligned_head_dim *
sizeof(
float));
2478 proj_in = proj_scratch;
2480 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2496 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2497 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2498 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2510 int aligned_embed_dim,
2511 int aligned_head_dim,
2512 int aligned_intermediate_dim,
2513 int aligned_context_window
2540 const float *BQ = NULL;
2541 const float *BK = NULL;
2542 const float *BV = NULL;
2543 const float *BO = NULL;
2544 const float *B1 = NULL;
2545 const float *B2 = NULL;
2553 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2554 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2555 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
2569 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2570 for (
int h = 0; h < H; ++h) {
2571 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2572 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2573 float *q_h = q + (size_t)h * q_head_stride;
2574 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2579 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2580 for (
int h = 0; h < H_kv; ++h) {
2581 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2582 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2583 float *k_h = k + (size_t)h * kv_head_stride;
2584 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2589 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2590 for (
int h = 0; h < H_kv; ++h) {
2591 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2592 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2593 float *v_h = v + (size_t)h * kv_head_stride;
2594 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2609 aligned_context_window);
2621 aligned_context_window);
2624 const int K = H * aligned_head_dim;
2625 if (K != aligned_embed_dim) {
2628 const float *proj_in = attn_out;
2630 if (!proj_scratch) {
2633 for (
int t = 0; t < num_tokens; ++t) {
2634 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2635 for (
int h = 0; h < H; ++h) {
2636 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2637 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2639 (
size_t)aligned_head_dim *
sizeof(
float));
2642 proj_in = proj_scratch;
2644 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2660 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2661 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2662 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2674 int aligned_embed_dim,
2675 int aligned_head_dim,
2676 int aligned_intermediate_dim,
2677 int aligned_context_window
2704 const float *BQ = NULL;
2705 const float *BK = NULL;
2706 const float *BV = NULL;
2707 const float *BO = NULL;
2708 const float *B1 = NULL;
2709 const float *B2 = NULL;
2717 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2718 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2719 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
2733 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2734 for (
int h = 0; h < H; ++h) {
2735 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2736 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2737 float *q_h = q + (size_t)h * q_head_stride;
2738 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2743 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2744 for (
int h = 0; h < H_kv; ++h) {
2745 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2746 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2747 float *k_h = k + (size_t)h * kv_head_stride;
2748 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2753 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2754 for (
int h = 0; h < H_kv; ++h) {
2755 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2756 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2757 float *v_h = v + (size_t)h * kv_head_stride;
2758 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2773 aligned_context_window);
2785 aligned_context_window);
2788 const int K = H * aligned_head_dim;
2789 if (K != aligned_embed_dim) {
2792 const float *proj_in = attn_out;
2794 if (!proj_scratch) {
2797 for (
int t = 0; t < num_tokens; ++t) {
2798 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2799 for (
int h = 0; h < H; ++h) {
2800 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2801 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2803 (
size_t)aligned_head_dim *
sizeof(
float));
2806 proj_in = proj_scratch;
2808 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2824 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2825 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2826 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
2838 int aligned_embed_dim,
2839 int aligned_head_dim,
2840 int aligned_intermediate_dim,
2841 int aligned_context_window
2868 const float *BQ = NULL;
2869 const float *BK = NULL;
2870 const float *BV = NULL;
2871 const float *BO = NULL;
2872 const float *B1 = NULL;
2873 const float *B2 = NULL;
2881 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2882 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2883 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
2897 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2898 for (
int h = 0; h < H; ++h) {
2899 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2900 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2901 float *q_h = q + (size_t)h * q_head_stride;
2902 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2907 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2908 for (
int h = 0; h < H_kv; ++h) {
2909 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2910 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2911 float *k_h = k + (size_t)h * kv_head_stride;
2912 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2917 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2918 for (
int h = 0; h < H_kv; ++h) {
2919 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2920 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2921 float *v_h = v + (size_t)h * kv_head_stride;
2922 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2937 aligned_context_window);
2949 aligned_context_window);
2952 const int K = H * aligned_head_dim;
2953 if (K != aligned_embed_dim) {
2956 const float *proj_in = attn_out;
2958 if (!proj_scratch) {
2961 for (
int t = 0; t < num_tokens; ++t) {
2962 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2963 for (
int h = 0; h < H; ++h) {
2964 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2965 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2967 (
size_t)aligned_head_dim *
sizeof(
float));
2970 proj_in = proj_scratch;
2972 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
2988 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2989 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2990 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3002 int aligned_embed_dim,
3003 int aligned_head_dim,
3004 int aligned_intermediate_dim,
3005 int aligned_context_window
3032 const float *BQ = NULL;
3033 const float *BK = NULL;
3034 const float *BV = NULL;
3035 const float *BO = NULL;
3036 const float *B1 = NULL;
3037 const float *B2 = NULL;
3045 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3046 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3047 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
3061 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3062 for (
int h = 0; h < H; ++h) {
3063 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3064 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3065 float *q_h = q + (size_t)h * q_head_stride;
3066 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3071 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3072 for (
int h = 0; h < H_kv; ++h) {
3073 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3074 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3075 float *k_h = k + (size_t)h * kv_head_stride;
3076 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3081 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3082 for (
int h = 0; h < H_kv; ++h) {
3083 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3084 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3085 float *v_h = v + (size_t)h * kv_head_stride;
3086 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3101 aligned_context_window);
3113 aligned_context_window);
3116 const int K = H * aligned_head_dim;
3117 if (K != aligned_embed_dim) {
3120 const float *proj_in = attn_out;
3122 if (!proj_scratch) {
3125 for (
int t = 0; t < num_tokens; ++t) {
3126 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3127 for (
int h = 0; h < H; ++h) {
3128 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3129 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3131 (
size_t)aligned_head_dim *
sizeof(
float));
3134 proj_in = proj_scratch;
3136 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3152 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3153 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3154 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3166 int aligned_embed_dim,
3167 int aligned_head_dim,
3168 int aligned_intermediate_dim,
3169 int aligned_context_window
3196 const float *BQ = NULL;
3197 const float *BK = NULL;
3198 const float *BV = NULL;
3199 const float *BO = NULL;
3200 const float *B1 = NULL;
3201 const float *B2 = NULL;
3209 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3210 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3211 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
3225 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3226 for (
int h = 0; h < H; ++h) {
3227 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3228 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3229 float *q_h = q + (size_t)h * q_head_stride;
3230 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3235 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3236 for (
int h = 0; h < H_kv; ++h) {
3237 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3238 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3239 float *k_h = k + (size_t)h * kv_head_stride;
3240 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3245 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3246 for (
int h = 0; h < H_kv; ++h) {
3247 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3248 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3249 float *v_h = v + (size_t)h * kv_head_stride;
3250 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3265 aligned_context_window);
3277 aligned_context_window);
3280 const int K = H * aligned_head_dim;
3281 if (K != aligned_embed_dim) {
3284 const float *proj_in = attn_out;
3286 if (!proj_scratch) {
3289 for (
int t = 0; t < num_tokens; ++t) {
3290 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3291 for (
int h = 0; h < H; ++h) {
3292 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3293 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3295 (
size_t)aligned_head_dim *
sizeof(
float));
3298 proj_in = proj_scratch;
3300 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3316 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3317 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3318 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3330 int aligned_embed_dim,
3331 int aligned_head_dim,
3332 int aligned_intermediate_dim,
3333 int aligned_context_window
3360 const float *BQ = NULL;
3361 const float *BK = NULL;
3362 const float *BV = NULL;
3363 const float *BO = NULL;
3364 const float *B1 = NULL;
3365 const float *B2 = NULL;
3373 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3374 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3375 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
3389 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3390 for (
int h = 0; h < H; ++h) {
3391 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3392 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3393 float *q_h = q + (size_t)h * q_head_stride;
3394 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3399 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3400 for (
int h = 0; h < H_kv; ++h) {
3401 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3402 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3403 float *k_h = k + (size_t)h * kv_head_stride;
3404 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3409 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3410 for (
int h = 0; h < H_kv; ++h) {
3411 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3412 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3413 float *v_h = v + (size_t)h * kv_head_stride;
3414 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3429 aligned_context_window);
3441 aligned_context_window);
3444 const int K = H * aligned_head_dim;
3445 if (K != aligned_embed_dim) {
3448 const float *proj_in = attn_out;
3450 if (!proj_scratch) {
3453 for (
int t = 0; t < num_tokens; ++t) {
3454 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3455 for (
int h = 0; h < H; ++h) {
3456 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3457 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3459 (
size_t)aligned_head_dim *
sizeof(
float));
3462 proj_in = proj_scratch;
3464 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3480 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3481 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3482 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3494 int aligned_embed_dim,
3495 int aligned_head_dim,
3496 int aligned_intermediate_dim,
3497 int aligned_context_window
3524 const float *BQ = NULL;
3525 const float *BK = NULL;
3526 const float *BV = NULL;
3527 const float *BO = NULL;
3528 const float *B1 = NULL;
3529 const float *B2 = NULL;
3537 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3538 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3539 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
3553 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3554 for (
int h = 0; h < H; ++h) {
3555 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3556 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3557 float *q_h = q + (size_t)h * q_head_stride;
3558 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3563 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3564 for (
int h = 0; h < H_kv; ++h) {
3565 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3566 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3567 float *k_h = k + (size_t)h * kv_head_stride;
3568 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3573 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3574 for (
int h = 0; h < H_kv; ++h) {
3575 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3576 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3577 float *v_h = v + (size_t)h * kv_head_stride;
3578 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3593 aligned_context_window);
3605 aligned_context_window);
3608 const int K = H * aligned_head_dim;
3609 if (K != aligned_embed_dim) {
3612 const float *proj_in = attn_out;
3614 if (!proj_scratch) {
3617 for (
int t = 0; t < num_tokens; ++t) {
3618 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3619 for (
int h = 0; h < H; ++h) {
3620 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3621 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3623 (
size_t)aligned_head_dim *
sizeof(
float));
3626 proj_in = proj_scratch;
3628 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3644 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3645 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3646 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3658 int aligned_embed_dim,
3659 int aligned_head_dim,
3660 int aligned_intermediate_dim,
3661 int aligned_context_window
3688 const float *BQ = NULL;
3689 const float *BK = NULL;
3690 const float *BV = NULL;
3691 const float *BO = NULL;
3692 const float *B1 = NULL;
3693 const float *B2 = NULL;
3701 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3702 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3703 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
3717 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3718 for (
int h = 0; h < H; ++h) {
3719 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3720 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3721 float *q_h = q + (size_t)h * q_head_stride;
3722 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3727 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3728 for (
int h = 0; h < H_kv; ++h) {
3729 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3730 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3731 float *k_h = k + (size_t)h * kv_head_stride;
3732 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3737 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3738 for (
int h = 0; h < H_kv; ++h) {
3739 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3740 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3741 float *v_h = v + (size_t)h * kv_head_stride;
3742 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3757 aligned_context_window);
3769 aligned_context_window);
3772 const int K = H * aligned_head_dim;
3773 if (K != aligned_embed_dim) {
3776 const float *proj_in = attn_out;
3778 if (!proj_scratch) {
3781 for (
int t = 0; t < num_tokens; ++t) {
3782 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3783 for (
int h = 0; h < H; ++h) {
3784 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3785 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3787 (
size_t)aligned_head_dim *
sizeof(
float));
3790 proj_in = proj_scratch;
3792 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3808 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3809 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3810 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3822 int aligned_embed_dim,
3823 int aligned_head_dim,
3824 int aligned_intermediate_dim,
3825 int aligned_context_window
3852 const float *BQ = NULL;
3853 const float *BK = NULL;
3854 const float *BV = NULL;
3855 const float *BO = NULL;
3856 const float *B1 = NULL;
3857 const float *B2 = NULL;
3865 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3866 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3867 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
3881 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3882 for (
int h = 0; h < H; ++h) {
3883 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3884 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3885 float *q_h = q + (size_t)h * q_head_stride;
3886 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3891 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3892 for (
int h = 0; h < H_kv; ++h) {
3893 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3894 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3895 float *k_h = k + (size_t)h * kv_head_stride;
3896 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3901 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3902 for (
int h = 0; h < H_kv; ++h) {
3903 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3904 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3905 float *v_h = v + (size_t)h * kv_head_stride;
3906 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3921 aligned_context_window);
3933 aligned_context_window);
3936 const int K = H * aligned_head_dim;
3937 if (K != aligned_embed_dim) {
3940 const float *proj_in = attn_out;
3942 if (!proj_scratch) {
3945 for (
int t = 0; t < num_tokens; ++t) {
3946 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3947 for (
int h = 0; h < H; ++h) {
3948 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3949 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3951 (
size_t)aligned_head_dim *
sizeof(
float));
3954 proj_in = proj_scratch;
3956 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
3972 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3973 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3974 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3986 int aligned_embed_dim,
3987 int aligned_head_dim,
3988 int aligned_intermediate_dim,
3989 int aligned_context_window
4016 const float *BQ = NULL;
4017 const float *BK = NULL;
4018 const float *BV = NULL;
4019 const float *BO = NULL;
4020 const float *B1 = NULL;
4021 const float *B2 = NULL;
4029 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4030 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
4031 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4045 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
4046 for (
int h = 0; h < H; ++h) {
4047 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
4048 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4049 float *q_h = q + (size_t)h * q_head_stride;
4050 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4055 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4056 for (
int h = 0; h < H_kv; ++h) {
4057 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4058 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4059 float *k_h = k + (size_t)h * kv_head_stride;
4060 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4065 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4066 for (
int h = 0; h < H_kv; ++h) {
4067 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4068 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4069 float *v_h = v + (size_t)h * kv_head_stride;
4070 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4085 aligned_context_window);
4097 aligned_context_window);
4100 const int K = H * aligned_head_dim;
4101 if (K != aligned_embed_dim) {
4104 const float *proj_in = attn_out;
4106 if (!proj_scratch) {
4109 for (
int t = 0; t < num_tokens; ++t) {
4110 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
4111 for (
int h = 0; h < H; ++h) {
4112 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
4113 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
4115 (
size_t)aligned_head_dim *
sizeof(
float));
4118 proj_in = proj_scratch;
4120 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
4136 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
4137 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
4138 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
4153 if (!model || !tokens || num_tokens <= 0) {
4158 const int aligned_embed_dim = 1024;
4159 const int aligned_head_dim = 64;
4160 const int aligned_intermediate_dim = 4864;
4161 const int aligned_context_window = 131072;
4181 aligned_intermediate_dim,
4182 aligned_context_window);
4189 aligned_intermediate_dim,
4190 aligned_context_window);
4197 aligned_intermediate_dim,
4198 aligned_context_window);
4205 aligned_intermediate_dim,
4206 aligned_context_window);
4213 aligned_intermediate_dim,
4214 aligned_context_window);
4221 aligned_intermediate_dim,
4222 aligned_context_window);
4229 aligned_intermediate_dim,
4230 aligned_context_window);
4237 aligned_intermediate_dim,
4238 aligned_context_window);
4245 aligned_intermediate_dim,
4246 aligned_context_window);
4253 aligned_intermediate_dim,
4254 aligned_context_window);
4261 aligned_intermediate_dim,
4262 aligned_context_window);
4269 aligned_intermediate_dim,
4270 aligned_context_window);
4277 aligned_intermediate_dim,
4278 aligned_context_window);
4285 aligned_intermediate_dim,
4286 aligned_context_window);
4293 aligned_intermediate_dim,
4294 aligned_context_window);
4301 aligned_intermediate_dim,
4302 aligned_context_window);
4309 aligned_intermediate_dim,
4310 aligned_context_window);
4317 aligned_intermediate_dim,
4318 aligned_context_window);
4325 aligned_intermediate_dim,
4326 aligned_context_window);
4333 aligned_intermediate_dim,
4334 aligned_context_window);
4341 aligned_intermediate_dim,
4342 aligned_context_window);
4349 aligned_intermediate_dim,
4350 aligned_context_window);
4357 aligned_intermediate_dim,
4358 aligned_context_window);
4365 aligned_intermediate_dim,
4366 aligned_context_window);
4383 for (
int t = 0; t < num_tokens; ++t) {
4384 uint8_t q8_buf[q8_bytes];
4385 const float *row = final_out + (size_t)t * (
size_t)aligned_embed_dim;
4408 int aligned_embed_dim,
4409 int aligned_head_dim,
4410 int aligned_intermediate_dim,
4411 int aligned_context_window
4444 float q_token[H * aligned_head_dim];
4445 float k_token[H_kv * aligned_head_dim];
4446 float v_token[H_kv * aligned_head_dim];
4447 float attn_token[H * aligned_head_dim];
4450 float fc1_out[2 * aligned_intermediate_dim];
4451 float swiglu_out[aligned_intermediate_dim];
4463 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4467 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4468 uint8_t ln1_q8[ln1_q8_bytes];
4470 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
4471 if (aligned_head_dim > head_dim) {
4472 for (
int h = 0; h < H; ++h) {
4473 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
4474 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4481 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4483 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4485 for (
int h = 0; h < H_kv; ++h) {
4486 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4487 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4488 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
4489 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4495 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4497 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4499 for (
int h = 0; h < H_kv; ++h) {
4500 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4501 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4502 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
4503 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4517 for (
int h = 0; h < H_kv; ++h) {
4518 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4539 aligned_context_window,
4545 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
4546 uint8_t attn_q8[attn_q8_bytes];
4548 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
4565 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4566 uint8_t ln2_q8[ln2_q8_bytes];
4568 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
4571 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4574 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
4575 uint8_t swiglu_q8[swiglu_q8_bytes];
4577 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
4589 int aligned_embed_dim,
4590 int aligned_head_dim,
4591 int aligned_intermediate_dim,
4592 int aligned_context_window
4625 float q_token[H * aligned_head_dim];
4626 float k_token[H_kv * aligned_head_dim];
4627 float v_token[H_kv * aligned_head_dim];
4628 float attn_token[H * aligned_head_dim];
4631 float fc1_out[2 * aligned_intermediate_dim];
4632 float swiglu_out[aligned_intermediate_dim];
4644 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4648 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4649 uint8_t ln1_q8[ln1_q8_bytes];
4651 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
4652 if (aligned_head_dim > head_dim) {
4653 for (
int h = 0; h < H; ++h) {
4654 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
4655 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4662 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4664 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4666 for (
int h = 0; h < H_kv; ++h) {
4667 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4668 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4669 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
4670 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4676 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4678 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4680 for (
int h = 0; h < H_kv; ++h) {
4681 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4682 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4683 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
4684 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4698 for (
int h = 0; h < H_kv; ++h) {
4699 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4720 aligned_context_window,
4726 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
4727 uint8_t attn_q8[attn_q8_bytes];
4729 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
4746 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4747 uint8_t ln2_q8[ln2_q8_bytes];
4749 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
4752 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4755 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
4756 uint8_t swiglu_q8[swiglu_q8_bytes];
4758 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
4770 int aligned_embed_dim,
4771 int aligned_head_dim,
4772 int aligned_intermediate_dim,
4773 int aligned_context_window
4806 float q_token[H * aligned_head_dim];
4807 float k_token[H_kv * aligned_head_dim];
4808 float v_token[H_kv * aligned_head_dim];
4809 float attn_token[H * aligned_head_dim];
4812 float fc1_out[2 * aligned_intermediate_dim];
4813 float swiglu_out[aligned_intermediate_dim];
4825 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4829 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4830 uint8_t ln1_q8[ln1_q8_bytes];
4832 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
4833 if (aligned_head_dim > head_dim) {
4834 for (
int h = 0; h < H; ++h) {
4835 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
4836 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4843 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4845 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4847 for (
int h = 0; h < H_kv; ++h) {
4848 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4849 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4850 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
4851 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4857 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4859 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4861 for (
int h = 0; h < H_kv; ++h) {
4862 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4863 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4864 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
4865 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4879 for (
int h = 0; h < H_kv; ++h) {
4880 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4901 aligned_context_window,
4907 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
4908 uint8_t attn_q8[attn_q8_bytes];
4910 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
4927 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4928 uint8_t ln2_q8[ln2_q8_bytes];
4930 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
4933 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4936 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
4937 uint8_t swiglu_q8[swiglu_q8_bytes];
4939 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
4951 int aligned_embed_dim,
4952 int aligned_head_dim,
4953 int aligned_intermediate_dim,
4954 int aligned_context_window
4987 float q_token[H * aligned_head_dim];
4988 float k_token[H_kv * aligned_head_dim];
4989 float v_token[H_kv * aligned_head_dim];
4990 float attn_token[H * aligned_head_dim];
4993 float fc1_out[2 * aligned_intermediate_dim];
4994 float swiglu_out[aligned_intermediate_dim];
5006 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5010 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5011 uint8_t ln1_q8[ln1_q8_bytes];
5013 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5014 if (aligned_head_dim > head_dim) {
5015 for (
int h = 0; h < H; ++h) {
5016 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5017 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5024 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5026 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5028 for (
int h = 0; h < H_kv; ++h) {
5029 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5030 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5031 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5032 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5038 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5040 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5042 for (
int h = 0; h < H_kv; ++h) {
5043 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5044 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5045 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5046 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5060 for (
int h = 0; h < H_kv; ++h) {
5061 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5082 aligned_context_window,
5088 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5089 uint8_t attn_q8[attn_q8_bytes];
5091 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5108 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5109 uint8_t ln2_q8[ln2_q8_bytes];
5111 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5114 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5117 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5118 uint8_t swiglu_q8[swiglu_q8_bytes];
5120 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5132 int aligned_embed_dim,
5133 int aligned_head_dim,
5134 int aligned_intermediate_dim,
5135 int aligned_context_window
5168 float q_token[H * aligned_head_dim];
5169 float k_token[H_kv * aligned_head_dim];
5170 float v_token[H_kv * aligned_head_dim];
5171 float attn_token[H * aligned_head_dim];
5174 float fc1_out[2 * aligned_intermediate_dim];
5175 float swiglu_out[aligned_intermediate_dim];
5187 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5191 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5192 uint8_t ln1_q8[ln1_q8_bytes];
5194 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5195 if (aligned_head_dim > head_dim) {
5196 for (
int h = 0; h < H; ++h) {
5197 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5198 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5205 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5207 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5209 for (
int h = 0; h < H_kv; ++h) {
5210 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5211 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5212 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5213 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5219 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5221 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5223 for (
int h = 0; h < H_kv; ++h) {
5224 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5225 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5226 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5227 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5241 for (
int h = 0; h < H_kv; ++h) {
5242 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5263 aligned_context_window,
5269 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5270 uint8_t attn_q8[attn_q8_bytes];
5272 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5289 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5290 uint8_t ln2_q8[ln2_q8_bytes];
5292 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5295 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5298 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5299 uint8_t swiglu_q8[swiglu_q8_bytes];
5301 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5313 int aligned_embed_dim,
5314 int aligned_head_dim,
5315 int aligned_intermediate_dim,
5316 int aligned_context_window
5349 float q_token[H * aligned_head_dim];
5350 float k_token[H_kv * aligned_head_dim];
5351 float v_token[H_kv * aligned_head_dim];
5352 float attn_token[H * aligned_head_dim];
5355 float fc1_out[2 * aligned_intermediate_dim];
5356 float swiglu_out[aligned_intermediate_dim];
5368 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5372 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5373 uint8_t ln1_q8[ln1_q8_bytes];
5375 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5376 if (aligned_head_dim > head_dim) {
5377 for (
int h = 0; h < H; ++h) {
5378 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5379 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5386 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5388 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5390 for (
int h = 0; h < H_kv; ++h) {
5391 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5392 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5393 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5394 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5400 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5402 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5404 for (
int h = 0; h < H_kv; ++h) {
5405 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5406 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5407 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5408 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5422 for (
int h = 0; h < H_kv; ++h) {
5423 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5444 aligned_context_window,
5450 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5451 uint8_t attn_q8[attn_q8_bytes];
5453 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5470 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5471 uint8_t ln2_q8[ln2_q8_bytes];
5473 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5476 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5479 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5480 uint8_t swiglu_q8[swiglu_q8_bytes];
5482 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5494 int aligned_embed_dim,
5495 int aligned_head_dim,
5496 int aligned_intermediate_dim,
5497 int aligned_context_window
5530 float q_token[H * aligned_head_dim];
5531 float k_token[H_kv * aligned_head_dim];
5532 float v_token[H_kv * aligned_head_dim];
5533 float attn_token[H * aligned_head_dim];
5536 float fc1_out[2 * aligned_intermediate_dim];
5537 float swiglu_out[aligned_intermediate_dim];
5549 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5553 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5554 uint8_t ln1_q8[ln1_q8_bytes];
5556 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5557 if (aligned_head_dim > head_dim) {
5558 for (
int h = 0; h < H; ++h) {
5559 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5560 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5567 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5569 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5571 for (
int h = 0; h < H_kv; ++h) {
5572 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5573 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5574 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5575 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5581 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5583 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5585 for (
int h = 0; h < H_kv; ++h) {
5586 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5587 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5588 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5589 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5603 for (
int h = 0; h < H_kv; ++h) {
5604 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5625 aligned_context_window,
5631 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5632 uint8_t attn_q8[attn_q8_bytes];
5634 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5651 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5652 uint8_t ln2_q8[ln2_q8_bytes];
5654 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5657 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5660 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5661 uint8_t swiglu_q8[swiglu_q8_bytes];
5663 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5675 int aligned_embed_dim,
5676 int aligned_head_dim,
5677 int aligned_intermediate_dim,
5678 int aligned_context_window
5711 float q_token[H * aligned_head_dim];
5712 float k_token[H_kv * aligned_head_dim];
5713 float v_token[H_kv * aligned_head_dim];
5714 float attn_token[H * aligned_head_dim];
5717 float fc1_out[2 * aligned_intermediate_dim];
5718 float swiglu_out[aligned_intermediate_dim];
5730 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5734 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5735 uint8_t ln1_q8[ln1_q8_bytes];
5737 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5738 if (aligned_head_dim > head_dim) {
5739 for (
int h = 0; h < H; ++h) {
5740 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5741 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5748 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5750 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5752 for (
int h = 0; h < H_kv; ++h) {
5753 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5754 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5755 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5756 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5762 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5764 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5766 for (
int h = 0; h < H_kv; ++h) {
5767 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5768 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5769 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5770 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5784 for (
int h = 0; h < H_kv; ++h) {
5785 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5806 aligned_context_window,
5812 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5813 uint8_t attn_q8[attn_q8_bytes];
5815 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5832 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5833 uint8_t ln2_q8[ln2_q8_bytes];
5835 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5838 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5841 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5842 uint8_t swiglu_q8[swiglu_q8_bytes];
5844 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5856 int aligned_embed_dim,
5857 int aligned_head_dim,
5858 int aligned_intermediate_dim,
5859 int aligned_context_window
5892 float q_token[H * aligned_head_dim];
5893 float k_token[H_kv * aligned_head_dim];
5894 float v_token[H_kv * aligned_head_dim];
5895 float attn_token[H * aligned_head_dim];
5898 float fc1_out[2 * aligned_intermediate_dim];
5899 float swiglu_out[aligned_intermediate_dim];
5911 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5915 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5916 uint8_t ln1_q8[ln1_q8_bytes];
5918 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5919 if (aligned_head_dim > head_dim) {
5920 for (
int h = 0; h < H; ++h) {
5921 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5922 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5929 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5931 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5933 for (
int h = 0; h < H_kv; ++h) {
5934 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5935 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5936 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5937 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5943 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5945 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5947 for (
int h = 0; h < H_kv; ++h) {
5948 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5949 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5950 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5951 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5965 for (
int h = 0; h < H_kv; ++h) {
5966 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5987 aligned_context_window,
5993 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5994 uint8_t attn_q8[attn_q8_bytes];
5996 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6013 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6014 uint8_t ln2_q8[ln2_q8_bytes];
6016 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6019 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6022 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6023 uint8_t swiglu_q8[swiglu_q8_bytes];
6025 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6037 int aligned_embed_dim,
6038 int aligned_head_dim,
6039 int aligned_intermediate_dim,
6040 int aligned_context_window
6073 float q_token[H * aligned_head_dim];
6074 float k_token[H_kv * aligned_head_dim];
6075 float v_token[H_kv * aligned_head_dim];
6076 float attn_token[H * aligned_head_dim];
6079 float fc1_out[2 * aligned_intermediate_dim];
6080 float swiglu_out[aligned_intermediate_dim];
6092 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6096 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6097 uint8_t ln1_q8[ln1_q8_bytes];
6099 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6100 if (aligned_head_dim > head_dim) {
6101 for (
int h = 0; h < H; ++h) {
6102 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6103 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6110 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6112 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6114 for (
int h = 0; h < H_kv; ++h) {
6115 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6116 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6117 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6118 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6124 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6126 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6128 for (
int h = 0; h < H_kv; ++h) {
6129 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6130 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6131 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6132 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6146 for (
int h = 0; h < H_kv; ++h) {
6147 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6168 aligned_context_window,
6174 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6175 uint8_t attn_q8[attn_q8_bytes];
6177 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6194 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6195 uint8_t ln2_q8[ln2_q8_bytes];
6197 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6200 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6203 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6204 uint8_t swiglu_q8[swiglu_q8_bytes];
6206 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6218 int aligned_embed_dim,
6219 int aligned_head_dim,
6220 int aligned_intermediate_dim,
6221 int aligned_context_window
6254 float q_token[H * aligned_head_dim];
6255 float k_token[H_kv * aligned_head_dim];
6256 float v_token[H_kv * aligned_head_dim];
6257 float attn_token[H * aligned_head_dim];
6260 float fc1_out[2 * aligned_intermediate_dim];
6261 float swiglu_out[aligned_intermediate_dim];
6273 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6277 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6278 uint8_t ln1_q8[ln1_q8_bytes];
6280 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6281 if (aligned_head_dim > head_dim) {
6282 for (
int h = 0; h < H; ++h) {
6283 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6284 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6291 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6293 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6295 for (
int h = 0; h < H_kv; ++h) {
6296 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6297 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6298 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6299 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6305 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6307 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6309 for (
int h = 0; h < H_kv; ++h) {
6310 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6311 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6312 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6313 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6327 for (
int h = 0; h < H_kv; ++h) {
6328 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6349 aligned_context_window,
6355 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6356 uint8_t attn_q8[attn_q8_bytes];
6358 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6375 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6376 uint8_t ln2_q8[ln2_q8_bytes];
6378 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6381 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6384 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6385 uint8_t swiglu_q8[swiglu_q8_bytes];
6387 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6399 int aligned_embed_dim,
6400 int aligned_head_dim,
6401 int aligned_intermediate_dim,
6402 int aligned_context_window
6435 float q_token[H * aligned_head_dim];
6436 float k_token[H_kv * aligned_head_dim];
6437 float v_token[H_kv * aligned_head_dim];
6438 float attn_token[H * aligned_head_dim];
6441 float fc1_out[2 * aligned_intermediate_dim];
6442 float swiglu_out[aligned_intermediate_dim];
6454 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6458 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6459 uint8_t ln1_q8[ln1_q8_bytes];
6461 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6462 if (aligned_head_dim > head_dim) {
6463 for (
int h = 0; h < H; ++h) {
6464 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6465 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6472 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6474 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6476 for (
int h = 0; h < H_kv; ++h) {
6477 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6478 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6479 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6480 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6486 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6488 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6490 for (
int h = 0; h < H_kv; ++h) {
6491 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6492 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6493 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6494 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6508 for (
int h = 0; h < H_kv; ++h) {
6509 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6530 aligned_context_window,
6536 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6537 uint8_t attn_q8[attn_q8_bytes];
6539 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6556 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6557 uint8_t ln2_q8[ln2_q8_bytes];
6559 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6562 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6565 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6566 uint8_t swiglu_q8[swiglu_q8_bytes];
6568 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6580 int aligned_embed_dim,
6581 int aligned_head_dim,
6582 int aligned_intermediate_dim,
6583 int aligned_context_window
6616 float q_token[H * aligned_head_dim];
6617 float k_token[H_kv * aligned_head_dim];
6618 float v_token[H_kv * aligned_head_dim];
6619 float attn_token[H * aligned_head_dim];
6622 float fc1_out[2 * aligned_intermediate_dim];
6623 float swiglu_out[aligned_intermediate_dim];
6635 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6639 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6640 uint8_t ln1_q8[ln1_q8_bytes];
6642 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6643 if (aligned_head_dim > head_dim) {
6644 for (
int h = 0; h < H; ++h) {
6645 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6646 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6653 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6655 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6657 for (
int h = 0; h < H_kv; ++h) {
6658 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6659 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6660 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6661 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6667 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6669 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6671 for (
int h = 0; h < H_kv; ++h) {
6672 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6673 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6674 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6675 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6689 for (
int h = 0; h < H_kv; ++h) {
6690 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6711 aligned_context_window,
6717 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6718 uint8_t attn_q8[attn_q8_bytes];
6720 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6737 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6738 uint8_t ln2_q8[ln2_q8_bytes];
6740 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6743 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6746 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6747 uint8_t swiglu_q8[swiglu_q8_bytes];
6749 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6761 int aligned_embed_dim,
6762 int aligned_head_dim,
6763 int aligned_intermediate_dim,
6764 int aligned_context_window
6797 float q_token[H * aligned_head_dim];
6798 float k_token[H_kv * aligned_head_dim];
6799 float v_token[H_kv * aligned_head_dim];
6800 float attn_token[H * aligned_head_dim];
6803 float fc1_out[2 * aligned_intermediate_dim];
6804 float swiglu_out[aligned_intermediate_dim];
6816 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6820 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6821 uint8_t ln1_q8[ln1_q8_bytes];
6823 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6824 if (aligned_head_dim > head_dim) {
6825 for (
int h = 0; h < H; ++h) {
6826 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6827 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6834 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6836 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6838 for (
int h = 0; h < H_kv; ++h) {
6839 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6840 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6841 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6842 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6848 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6850 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6852 for (
int h = 0; h < H_kv; ++h) {
6853 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6854 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6855 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6856 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6870 for (
int h = 0; h < H_kv; ++h) {
6871 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6892 aligned_context_window,
6898 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6899 uint8_t attn_q8[attn_q8_bytes];
6901 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6918 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6919 uint8_t ln2_q8[ln2_q8_bytes];
6921 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6924 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6927 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6928 uint8_t swiglu_q8[swiglu_q8_bytes];
6930 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6942 int aligned_embed_dim,
6943 int aligned_head_dim,
6944 int aligned_intermediate_dim,
6945 int aligned_context_window
6978 float q_token[H * aligned_head_dim];
6979 float k_token[H_kv * aligned_head_dim];
6980 float v_token[H_kv * aligned_head_dim];
6981 float attn_token[H * aligned_head_dim];
6984 float fc1_out[2 * aligned_intermediate_dim];
6985 float swiglu_out[aligned_intermediate_dim];
6997 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7001 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7002 uint8_t ln1_q8[ln1_q8_bytes];
7004 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7005 if (aligned_head_dim > head_dim) {
7006 for (
int h = 0; h < H; ++h) {
7007 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7008 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7015 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7017 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7019 for (
int h = 0; h < H_kv; ++h) {
7020 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7021 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7022 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7023 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7029 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7031 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7033 for (
int h = 0; h < H_kv; ++h) {
7034 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7035 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7036 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7037 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7051 for (
int h = 0; h < H_kv; ++h) {
7052 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7073 aligned_context_window,
7079 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7080 uint8_t attn_q8[attn_q8_bytes];
7082 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7099 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7100 uint8_t ln2_q8[ln2_q8_bytes];
7102 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7105 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7108 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7109 uint8_t swiglu_q8[swiglu_q8_bytes];
7111 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7123 int aligned_embed_dim,
7124 int aligned_head_dim,
7125 int aligned_intermediate_dim,
7126 int aligned_context_window
7159 float q_token[H * aligned_head_dim];
7160 float k_token[H_kv * aligned_head_dim];
7161 float v_token[H_kv * aligned_head_dim];
7162 float attn_token[H * aligned_head_dim];
7165 float fc1_out[2 * aligned_intermediate_dim];
7166 float swiglu_out[aligned_intermediate_dim];
7178 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7182 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7183 uint8_t ln1_q8[ln1_q8_bytes];
7185 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7186 if (aligned_head_dim > head_dim) {
7187 for (
int h = 0; h < H; ++h) {
7188 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7189 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7196 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7198 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7200 for (
int h = 0; h < H_kv; ++h) {
7201 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7202 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7203 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7204 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7210 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7212 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7214 for (
int h = 0; h < H_kv; ++h) {
7215 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7216 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7217 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7218 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7232 for (
int h = 0; h < H_kv; ++h) {
7233 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7254 aligned_context_window,
7260 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7261 uint8_t attn_q8[attn_q8_bytes];
7263 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7280 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7281 uint8_t ln2_q8[ln2_q8_bytes];
7283 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7286 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7289 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7290 uint8_t swiglu_q8[swiglu_q8_bytes];
7292 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7304 int aligned_embed_dim,
7305 int aligned_head_dim,
7306 int aligned_intermediate_dim,
7307 int aligned_context_window
7340 float q_token[H * aligned_head_dim];
7341 float k_token[H_kv * aligned_head_dim];
7342 float v_token[H_kv * aligned_head_dim];
7343 float attn_token[H * aligned_head_dim];
7346 float fc1_out[2 * aligned_intermediate_dim];
7347 float swiglu_out[aligned_intermediate_dim];
7359 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7363 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7364 uint8_t ln1_q8[ln1_q8_bytes];
7366 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7367 if (aligned_head_dim > head_dim) {
7368 for (
int h = 0; h < H; ++h) {
7369 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7370 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7377 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7379 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7381 for (
int h = 0; h < H_kv; ++h) {
7382 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7383 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7384 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7385 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7391 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7393 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7395 for (
int h = 0; h < H_kv; ++h) {
7396 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7397 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7398 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7399 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7413 for (
int h = 0; h < H_kv; ++h) {
7414 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7435 aligned_context_window,
7441 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7442 uint8_t attn_q8[attn_q8_bytes];
7444 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7461 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7462 uint8_t ln2_q8[ln2_q8_bytes];
7464 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7467 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7470 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7471 uint8_t swiglu_q8[swiglu_q8_bytes];
7473 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7485 int aligned_embed_dim,
7486 int aligned_head_dim,
7487 int aligned_intermediate_dim,
7488 int aligned_context_window
7521 float q_token[H * aligned_head_dim];
7522 float k_token[H_kv * aligned_head_dim];
7523 float v_token[H_kv * aligned_head_dim];
7524 float attn_token[H * aligned_head_dim];
7527 float fc1_out[2 * aligned_intermediate_dim];
7528 float swiglu_out[aligned_intermediate_dim];
7540 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7544 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7545 uint8_t ln1_q8[ln1_q8_bytes];
7547 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7548 if (aligned_head_dim > head_dim) {
7549 for (
int h = 0; h < H; ++h) {
7550 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7551 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7558 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7560 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7562 for (
int h = 0; h < H_kv; ++h) {
7563 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7564 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7565 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7566 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7572 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7574 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7576 for (
int h = 0; h < H_kv; ++h) {
7577 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7578 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7579 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7580 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7594 for (
int h = 0; h < H_kv; ++h) {
7595 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7616 aligned_context_window,
7622 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7623 uint8_t attn_q8[attn_q8_bytes];
7625 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7642 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7643 uint8_t ln2_q8[ln2_q8_bytes];
7645 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7648 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7651 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7652 uint8_t swiglu_q8[swiglu_q8_bytes];
7654 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7666 int aligned_embed_dim,
7667 int aligned_head_dim,
7668 int aligned_intermediate_dim,
7669 int aligned_context_window
7702 float q_token[H * aligned_head_dim];
7703 float k_token[H_kv * aligned_head_dim];
7704 float v_token[H_kv * aligned_head_dim];
7705 float attn_token[H * aligned_head_dim];
7708 float fc1_out[2 * aligned_intermediate_dim];
7709 float swiglu_out[aligned_intermediate_dim];
7721 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7725 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7726 uint8_t ln1_q8[ln1_q8_bytes];
7728 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7729 if (aligned_head_dim > head_dim) {
7730 for (
int h = 0; h < H; ++h) {
7731 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7732 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7739 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7741 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7743 for (
int h = 0; h < H_kv; ++h) {
7744 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7745 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7746 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7747 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7753 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7755 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7757 for (
int h = 0; h < H_kv; ++h) {
7758 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7759 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7760 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7761 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7775 for (
int h = 0; h < H_kv; ++h) {
7776 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7797 aligned_context_window,
7803 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7804 uint8_t attn_q8[attn_q8_bytes];
7806 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7823 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7824 uint8_t ln2_q8[ln2_q8_bytes];
7826 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7829 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7832 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7833 uint8_t swiglu_q8[swiglu_q8_bytes];
7835 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7847 int aligned_embed_dim,
7848 int aligned_head_dim,
7849 int aligned_intermediate_dim,
7850 int aligned_context_window
7883 float q_token[H * aligned_head_dim];
7884 float k_token[H_kv * aligned_head_dim];
7885 float v_token[H_kv * aligned_head_dim];
7886 float attn_token[H * aligned_head_dim];
7889 float fc1_out[2 * aligned_intermediate_dim];
7890 float swiglu_out[aligned_intermediate_dim];
7902 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7906 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7907 uint8_t ln1_q8[ln1_q8_bytes];
7909 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7910 if (aligned_head_dim > head_dim) {
7911 for (
int h = 0; h < H; ++h) {
7912 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7913 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7920 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7922 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7924 for (
int h = 0; h < H_kv; ++h) {
7925 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7926 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7927 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7928 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7934 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7936 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7938 for (
int h = 0; h < H_kv; ++h) {
7939 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7940 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7941 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7942 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7956 for (
int h = 0; h < H_kv; ++h) {
7957 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7978 aligned_context_window,
7984 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7985 uint8_t attn_q8[attn_q8_bytes];
7987 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8004 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8005 uint8_t ln2_q8[ln2_q8_bytes];
8007 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8010 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8013 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8014 uint8_t swiglu_q8[swiglu_q8_bytes];
8016 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8028 int aligned_embed_dim,
8029 int aligned_head_dim,
8030 int aligned_intermediate_dim,
8031 int aligned_context_window
8064 float q_token[H * aligned_head_dim];
8065 float k_token[H_kv * aligned_head_dim];
8066 float v_token[H_kv * aligned_head_dim];
8067 float attn_token[H * aligned_head_dim];
8070 float fc1_out[2 * aligned_intermediate_dim];
8071 float swiglu_out[aligned_intermediate_dim];
8083 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
8087 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8088 uint8_t ln1_q8[ln1_q8_bytes];
8090 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8091 if (aligned_head_dim > head_dim) {
8092 for (
int h = 0; h < H; ++h) {
8093 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8094 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8101 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8103 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8105 for (
int h = 0; h < H_kv; ++h) {
8106 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8107 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8108 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8109 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8115 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8117 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8119 for (
int h = 0; h < H_kv; ++h) {
8120 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8121 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8122 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8123 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8137 for (
int h = 0; h < H_kv; ++h) {
8138 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8159 aligned_context_window,
8165 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8166 uint8_t attn_q8[attn_q8_bytes];
8168 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8185 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8186 uint8_t ln2_q8[ln2_q8_bytes];
8188 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8191 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8194 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8195 uint8_t swiglu_q8[swiglu_q8_bytes];
8197 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8209 int aligned_embed_dim,
8210 int aligned_head_dim,
8211 int aligned_intermediate_dim,
8212 int aligned_context_window
8245 float q_token[H * aligned_head_dim];
8246 float k_token[H_kv * aligned_head_dim];
8247 float v_token[H_kv * aligned_head_dim];
8248 float attn_token[H * aligned_head_dim];
8251 float fc1_out[2 * aligned_intermediate_dim];
8252 float swiglu_out[aligned_intermediate_dim];
8264 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
8268 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8269 uint8_t ln1_q8[ln1_q8_bytes];
8271 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8272 if (aligned_head_dim > head_dim) {
8273 for (
int h = 0; h < H; ++h) {
8274 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8275 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8282 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8284 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8286 for (
int h = 0; h < H_kv; ++h) {
8287 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8288 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8289 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8290 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8296 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8298 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8300 for (
int h = 0; h < H_kv; ++h) {
8301 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8302 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8303 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8304 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8318 for (
int h = 0; h < H_kv; ++h) {
8319 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8340 aligned_context_window,
8346 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8347 uint8_t attn_q8[attn_q8_bytes];
8349 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8366 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8367 uint8_t ln2_q8[ln2_q8_bytes];
8369 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8372 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8375 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8376 uint8_t swiglu_q8[swiglu_q8_bytes];
8378 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8390 int aligned_embed_dim,
8391 int aligned_head_dim,
8392 int aligned_intermediate_dim,
8393 int aligned_context_window
8426 float q_token[H * aligned_head_dim];
8427 float k_token[H_kv * aligned_head_dim];
8428 float v_token[H_kv * aligned_head_dim];
8429 float attn_token[H * aligned_head_dim];
8432 float fc1_out[2 * aligned_intermediate_dim];
8433 float swiglu_out[aligned_intermediate_dim];
8445 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
8449 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8450 uint8_t ln1_q8[ln1_q8_bytes];
8452 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8453 if (aligned_head_dim > head_dim) {
8454 for (
int h = 0; h < H; ++h) {
8455 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8456 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8463 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8465 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8467 for (
int h = 0; h < H_kv; ++h) {
8468 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8469 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8470 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8471 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8477 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8479 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8481 for (
int h = 0; h < H_kv; ++h) {
8482 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8483 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8484 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8485 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8499 for (
int h = 0; h < H_kv; ++h) {
8500 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8521 aligned_context_window,
8527 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8528 uint8_t attn_q8[attn_q8_bytes];
8530 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8547 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8548 uint8_t ln2_q8[ln2_q8_bytes];
8550 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8553 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8556 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8557 uint8_t swiglu_q8[swiglu_q8_bytes];
8559 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8571 int aligned_embed_dim,
8572 int aligned_head_dim,
8573 int aligned_intermediate_dim,
8574 int aligned_context_window
8607 float q_token[H * aligned_head_dim];
8608 float k_token[H_kv * aligned_head_dim];
8609 float v_token[H_kv * aligned_head_dim];
8610 float attn_token[H * aligned_head_dim];
8613 float fc1_out[2 * aligned_intermediate_dim];
8614 float swiglu_out[aligned_intermediate_dim];
8626 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
8630 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8631 uint8_t ln1_q8[ln1_q8_bytes];
8633 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8634 if (aligned_head_dim > head_dim) {
8635 for (
int h = 0; h < H; ++h) {
8636 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8637 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8644 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8646 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8648 for (
int h = 0; h < H_kv; ++h) {
8649 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8650 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8651 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8652 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8658 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8660 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8662 for (
int h = 0; h < H_kv; ++h) {
8663 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8664 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8665 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8666 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8680 for (
int h = 0; h < H_kv; ++h) {
8681 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8702 aligned_context_window,
8708 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8709 uint8_t attn_q8[attn_q8_bytes];
8711 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8728 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8729 uint8_t ln2_q8[ln2_q8_bytes];
8731 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8734 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8737 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8738 uint8_t swiglu_q8[swiglu_q8_bytes];
8740 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8755 if (!model || !
token)
return;
8757 const int aligned_embed_dim = 1024;
8758 const int aligned_head_dim = 64;
8759 const int aligned_intermediate_dim = 4864;
8760 const int aligned_context_window = 131072;
8762 if (token_index < 0 || token_index >= aligned_context_window)
return;
8822 const size_t final_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8823 uint8_t final_q8[final_q8_bytes];
8837 if (!model || !tokens || num_tokens <= 0)
return;
8865 .model_name =
"qwen2_0.5b_decode",
8866 .model_family =
"qwen2",
8875 if (!model)
return NULL;
Generic Model API - Model-agnostic interface for CK-Engine.
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void embedding_forward_q4_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
void quantize_row_q8_k(const float *x, void *y, int k)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
#define QWEN2_0_5B_DECODE_TOTAL_BYTES
#define QWEN2_0_5B_DECODE_HEAD_DIM
#define QWEN2_0_5B_DECODE_PTR(model, offset)
#define QWEN2_0_5B_DECODE_ACTIVATION_BYTES
static const QWEN2_0_5B_DECODEFooterOffsets QWEN2_0_5B_DECODE_FOOTER
#define QWEN2_0_5B_DECODE_INTERMEDIATE
static const QWEN2_0_5B_DECODELayerOffsets QWEN2_0_5B_DECODE_LAYERS[24]
#define QWEN2_0_5B_DECODE_DTYPE_BYTES
#define QWEN2_0_5B_DECODE_EMBED_DIM
#define QWEN2_0_5B_DECODE_MAGIC
#define QWEN2_0_5B_DECODE_CANARY_COUNT
#define QWEN2_0_5B_DECODE_MAX_SEQ_LEN
#define QWEN2_0_5B_DECODE_NUM_LAYERS
#define QWEN2_0_5B_DECODE_WEIGHT_BYTES
#define QWEN2_0_5B_DECODE_CANARY_VALUE
static const QWEN2_0_5B_DECODEGlobalOffsets QWEN2_0_5B_DECODE_GLOBALS
#define QWEN2_0_5B_DECODE_NUM_KV_HEADS
#define QWEN2_0_5B_DECODE_NUM_HEADS
#define QWEN2_0_5B_DECODE_CANARY_SIZE
#define QWEN2_0_5B_DECODE_VOCAB_SIZE
static const QWEN2_0_5B_DECODEHeaderOffsets QWEN2_0_5B_DECODE_HEADER
static const QWEN2_0_5B_DECODECanary QWEN2_0_5B_DECODE_CANARIES[]
static void qwen2_0_5b_decode_layer_18_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_2_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_1_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_8_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_3_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
const CKModelConfig * ck_model_get_config(void)
static void qwen2_0_5b_decode_layer_18_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_16_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_20_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void qwen2_0_5b_decode_precompute_rope(QWEN2_0_5B_DECODEModel *model)
static void qwen2_0_5b_decode_layer_0_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
int qwen2_0_5b_decode_verify_canaries(QWEN2_0_5B_DECODEModel *model)
static void qwen2_0_5b_decode_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
static void qwen2_0_5b_decode_layer_4_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void * ck_model_create(void)
static void qwen2_0_5b_decode_layer_7_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static int qwen2_0_5b_decode_align_elems(int elems, int elem_bytes, int align_bytes)
static CKModelConfig g_model_config
static void qwen2_0_5b_decode_layer_6_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_8_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_5_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_13_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_19_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
int qwen2_0_5b_decode_model_allocate(QWEN2_0_5B_DECODEModel *model)
void ck_model_precompute_rope(void *model)
static void qwen2_0_5b_decode_layer_0_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_19_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_6_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_7_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void qwen2_0_5b_decode_decode(QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)
void ck_model_forward(void *model, const int *tokens, int num_tokens)
static void qwen2_0_5b_decode_layer_15_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_2_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
int ck_model_verify_canaries(void *model)
static void qwen2_0_5b_decode_layer_17_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_22_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_3_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_forward_prefill_impl(QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)
static void qwen2_0_5b_decode_layer_13_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void ck_model_free(void *model)
static void qwen2_0_5b_decode_layer_21_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_4_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void * ck_model_get_base(void *model)
void ck_model_decode(void *model, const int *token, int token_index)
static void qwen2_0_5b_decode_layer_9_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_10_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_decode_token(QWEN2_0_5B_DECODEModel *model, const int *token, int token_index)
static void qwen2_0_5b_decode_layer_21_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_22_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_11_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_12_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
size_t ck_model_get_total_bytes(void *model)
static void qwen2_0_5b_decode_layer_23_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
float * ck_model_get_logits(void *model)
static void qwen2_0_5b_decode_layer_16_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
struct __attribute__((packed))
static void qwen2_0_5b_decode_layer_1_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void qwen2_0_5b_decode_forward(QWEN2_0_5B_DECODEModel *model, const int *tokens, int num_tokens)
static void qwen2_0_5b_decode_layer_5_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_14_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_15_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_11_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_14_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_10_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_12_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_9_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void qwen2_0_5b_decode_layer_23_prefill(QWEN2_0_5B_DECODEModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void qwen2_0_5b_decode_model_free(QWEN2_0_5B_DECODEModel *model)
static void qwen2_0_5b_decode_layer_20_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
_Static_assert(sizeof(MagicHeader)==64, "MagicHeader must be 64 bytes")
static void qwen2_0_5b_decode_layer_17_decode(QWEN2_0_5B_DECODEModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
AUTO-GENERATED: qwen2_0.5b_decode Memory Layout.