35 #if MODEL_DTYPE_BYTES != 4
36 #error "model: v6.5 codegen currently supports fp32 only. Use --dtype=fp32."
50 if (!a || !b || !out) {
53 for (
int t = 0; t < tokens; ++t) {
54 const float *pa = a + (size_t)t * (
size_t)aligned_embed_dim;
55 const float *pb = b + (size_t)t * (
size_t)aligned_embed_dim;
56 float *pc = out + (size_t)t * (
size_t)aligned_embed_dim;
57 for (
int d = 0; d < aligned_embed_dim; ++d) {
58 pc[d] = pa[d] + pb[d];
71 uint64_t weight_bytes;
72 uint64_t activation_bytes;
78 uint32_t canary_count;
89 size_t total = MODEL_TOTAL_BYTES;
92 model->base = mmap(NULL, total,
93 PROT_READ | PROT_WRITE,
94 MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB,
96 if (model->base == MAP_FAILED) {
97 model->base = mmap(NULL, total,
98 PROT_READ | PROT_WRITE,
99 MAP_PRIVATE | MAP_ANONYMOUS,
102 if (model->base == MAP_FAILED) {
103 perror(
"mmap failed");
107 model->base = aligned_alloc(64, total);
109 perror(
"aligned_alloc failed");
114 model->total_bytes = total;
118 header->magic = MODEL_MAGIC;
120 header->total_bytes = MODEL_TOTAL_BYTES;
121 header->weight_bytes = MODEL_WEIGHT_BYTES;
122 header->activation_bytes = MODEL_ACTIVATION_BYTES;
128 header->canary_count = MODEL_CANARY_COUNT;
131 for (
int i = 0; i < MODEL_CANARY_COUNT; i++) {
132 uint32_t *ptr = (uint32_t*)((
char*)model->base + MODEL_CANARIES[i].offset);
133 for (
int j = 0; j < (MODEL_CANARY_SIZE / 4); j++) {
134 ptr[j] = MODEL_CANARY_VALUE;
142 if (!model || !model->base)
return;
144 munmap(model->base, model->total_bytes);
149 model->total_bytes = 0;
156 for (
int i = 0; i < MODEL_CANARY_COUNT; i++) {
157 ptr = (uint32_t*)((
char*)model->base + MODEL_CANARIES[i].offset);
158 for (
int j = 0; j < 4; j++) {
159 if (ptr[j] != MODEL_CANARY_VALUE) {
160 fprintf(stderr,
"CANARY CORRUPTION: %s at offset 0x%lX\n",
161 MODEL_CANARIES[i].name,
162 MODEL_CANARIES[i].offset);
177 int bytes = elems * elem_bytes;
178 int aligned = (bytes + align_bytes - 1) / align_bytes * align_bytes;
179 return aligned / elem_bytes;
189 const float theta = 1000000.0f;
191 float *cos_ptr = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
192 float *sin_ptr = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
194 for (
int pos = 0; pos < T; pos++) {
195 for (
int i = 0; i < D; i++) {
196 float freq = 1.0f / powf(theta, (
float)(2 * i) / (
float)(D * 2));
197 float angle = (float)pos * freq;
198 cos_ptr[pos * D + i] = cosf(angle);
199 sin_ptr[pos * D + i] = sinf(angle);
214 int aligned_embed_dim,
215 int aligned_head_dim,
216 int aligned_intermediate_dim,
217 int aligned_context_window
219 const MODELLayerOffsets *L = &MODEL_LAYERS[0];
221 float *input = MODEL_PTR(model, MODEL_HEADER.embedded_input);
222 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
223 float *ln1_out = MODEL_PTR(model, L->ln1_out);
224 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
225 float *ln2_out = MODEL_PTR(model, L->ln2_out);
226 float *q = MODEL_PTR(model, L->q);
227 float *k = MODEL_PTR(model, L->k);
228 float *v = MODEL_PTR(model, L->v);
229 float *attn_out = MODEL_PTR(model, L->attn_out);
230 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
231 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
232 float *residual1 = MODEL_PTR(model, L->residual1);
233 float *fc1_out = MODEL_PTR(model, L->fc1_out);
234 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
235 float *mlp_out = MODEL_PTR(model, L->mlp_out);
236 float *output = MODEL_PTR(model, L->output);
238 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
239 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
240 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
241 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
242 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
243 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
244 const float *BQ = NULL;
245 const float *BK = NULL;
246 const float *BV = NULL;
247 const float *BO = NULL;
248 const float *B1 = NULL;
249 const float *B2 = NULL;
251 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
252 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
257 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
258 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
259 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
273 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
274 for (
int h = 0; h < H; ++h) {
275 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
276 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
277 float *q_h = q + (size_t)h * q_head_stride;
278 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
283 const uint8_t *WK_bytes = (
const uint8_t *)WK;
284 for (
int h = 0; h < H_kv; ++h) {
285 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
286 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
287 float *k_h = k + (size_t)h * kv_head_stride;
288 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
293 const uint8_t *WV_bytes = (
const uint8_t *)WV;
294 for (
int h = 0; h < H_kv; ++h) {
295 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
296 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
297 float *v_h = v + (size_t)h * kv_head_stride;
298 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
313 aligned_context_window);
325 aligned_context_window);
328 const int K = H * aligned_head_dim;
329 if (K != aligned_embed_dim) {
332 const float *proj_in = attn_out;
337 for (
int t = 0; t < num_tokens; ++t) {
338 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
339 for (
int h = 0; h < H; ++h) {
340 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
341 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
343 (
size_t)aligned_head_dim *
sizeof(
float));
346 proj_in = proj_scratch;
348 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
364 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
365 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
366 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
378 int aligned_embed_dim,
379 int aligned_head_dim,
380 int aligned_intermediate_dim,
381 int aligned_context_window
383 const MODELLayerOffsets *L = &MODEL_LAYERS[1];
385 float *input = MODEL_PTR(model, MODEL_LAYERS[0].output);
386 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
387 float *ln1_out = MODEL_PTR(model, L->ln1_out);
388 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
389 float *ln2_out = MODEL_PTR(model, L->ln2_out);
390 float *q = MODEL_PTR(model, L->q);
391 float *k = MODEL_PTR(model, L->k);
392 float *v = MODEL_PTR(model, L->v);
393 float *attn_out = MODEL_PTR(model, L->attn_out);
394 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
395 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
396 float *residual1 = MODEL_PTR(model, L->residual1);
397 float *fc1_out = MODEL_PTR(model, L->fc1_out);
398 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
399 float *mlp_out = MODEL_PTR(model, L->mlp_out);
400 float *output = MODEL_PTR(model, L->output);
402 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
403 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
404 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
405 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
406 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
407 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
408 const float *BQ = NULL;
409 const float *BK = NULL;
410 const float *BV = NULL;
411 const float *BO = NULL;
412 const float *B1 = NULL;
413 const float *B2 = NULL;
415 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
416 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
421 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
422 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
423 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
437 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
438 for (
int h = 0; h < H; ++h) {
439 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
440 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
441 float *q_h = q + (size_t)h * q_head_stride;
442 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
447 const uint8_t *WK_bytes = (
const uint8_t *)WK;
448 for (
int h = 0; h < H_kv; ++h) {
449 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
450 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
451 float *k_h = k + (size_t)h * kv_head_stride;
452 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
457 const uint8_t *WV_bytes = (
const uint8_t *)WV;
458 for (
int h = 0; h < H_kv; ++h) {
459 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
460 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
461 float *v_h = v + (size_t)h * kv_head_stride;
462 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
477 aligned_context_window);
489 aligned_context_window);
492 const int K = H * aligned_head_dim;
493 if (K != aligned_embed_dim) {
496 const float *proj_in = attn_out;
501 for (
int t = 0; t < num_tokens; ++t) {
502 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
503 for (
int h = 0; h < H; ++h) {
504 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
505 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
507 (
size_t)aligned_head_dim *
sizeof(
float));
510 proj_in = proj_scratch;
512 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
528 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
529 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
530 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
542 int aligned_embed_dim,
543 int aligned_head_dim,
544 int aligned_intermediate_dim,
545 int aligned_context_window
547 const MODELLayerOffsets *L = &MODEL_LAYERS[2];
549 float *input = MODEL_PTR(model, MODEL_LAYERS[1].output);
550 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
551 float *ln1_out = MODEL_PTR(model, L->ln1_out);
552 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
553 float *ln2_out = MODEL_PTR(model, L->ln2_out);
554 float *q = MODEL_PTR(model, L->q);
555 float *k = MODEL_PTR(model, L->k);
556 float *v = MODEL_PTR(model, L->v);
557 float *attn_out = MODEL_PTR(model, L->attn_out);
558 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
559 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
560 float *residual1 = MODEL_PTR(model, L->residual1);
561 float *fc1_out = MODEL_PTR(model, L->fc1_out);
562 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
563 float *mlp_out = MODEL_PTR(model, L->mlp_out);
564 float *output = MODEL_PTR(model, L->output);
566 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
567 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
568 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
569 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
570 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
571 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
572 const float *BQ = NULL;
573 const float *BK = NULL;
574 const float *BV = NULL;
575 const float *BO = NULL;
576 const float *B1 = NULL;
577 const float *B2 = NULL;
579 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
580 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
585 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
586 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
587 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
601 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
602 for (
int h = 0; h < H; ++h) {
603 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
604 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
605 float *q_h = q + (size_t)h * q_head_stride;
606 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
611 const uint8_t *WK_bytes = (
const uint8_t *)WK;
612 for (
int h = 0; h < H_kv; ++h) {
613 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
614 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
615 float *k_h = k + (size_t)h * kv_head_stride;
616 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
621 const uint8_t *WV_bytes = (
const uint8_t *)WV;
622 for (
int h = 0; h < H_kv; ++h) {
623 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
624 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
625 float *v_h = v + (size_t)h * kv_head_stride;
626 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
641 aligned_context_window);
653 aligned_context_window);
656 const int K = H * aligned_head_dim;
657 if (K != aligned_embed_dim) {
660 const float *proj_in = attn_out;
665 for (
int t = 0; t < num_tokens; ++t) {
666 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
667 for (
int h = 0; h < H; ++h) {
668 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
669 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
671 (
size_t)aligned_head_dim *
sizeof(
float));
674 proj_in = proj_scratch;
676 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
692 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
693 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
694 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
706 int aligned_embed_dim,
707 int aligned_head_dim,
708 int aligned_intermediate_dim,
709 int aligned_context_window
711 const MODELLayerOffsets *L = &MODEL_LAYERS[3];
713 float *input = MODEL_PTR(model, MODEL_LAYERS[2].output);
714 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
715 float *ln1_out = MODEL_PTR(model, L->ln1_out);
716 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
717 float *ln2_out = MODEL_PTR(model, L->ln2_out);
718 float *q = MODEL_PTR(model, L->q);
719 float *k = MODEL_PTR(model, L->k);
720 float *v = MODEL_PTR(model, L->v);
721 float *attn_out = MODEL_PTR(model, L->attn_out);
722 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
723 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
724 float *residual1 = MODEL_PTR(model, L->residual1);
725 float *fc1_out = MODEL_PTR(model, L->fc1_out);
726 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
727 float *mlp_out = MODEL_PTR(model, L->mlp_out);
728 float *output = MODEL_PTR(model, L->output);
730 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
731 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
732 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
733 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
734 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
735 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
736 const float *BQ = NULL;
737 const float *BK = NULL;
738 const float *BV = NULL;
739 const float *BO = NULL;
740 const float *B1 = NULL;
741 const float *B2 = NULL;
743 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
744 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
749 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
750 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
751 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
765 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
766 for (
int h = 0; h < H; ++h) {
767 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
768 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
769 float *q_h = q + (size_t)h * q_head_stride;
770 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
775 const uint8_t *WK_bytes = (
const uint8_t *)WK;
776 for (
int h = 0; h < H_kv; ++h) {
777 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
778 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
779 float *k_h = k + (size_t)h * kv_head_stride;
780 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
785 const uint8_t *WV_bytes = (
const uint8_t *)WV;
786 for (
int h = 0; h < H_kv; ++h) {
787 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
788 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
789 float *v_h = v + (size_t)h * kv_head_stride;
790 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
805 aligned_context_window);
817 aligned_context_window);
820 const int K = H * aligned_head_dim;
821 if (K != aligned_embed_dim) {
824 const float *proj_in = attn_out;
829 for (
int t = 0; t < num_tokens; ++t) {
830 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
831 for (
int h = 0; h < H; ++h) {
832 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
833 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
835 (
size_t)aligned_head_dim *
sizeof(
float));
838 proj_in = proj_scratch;
840 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
856 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
857 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
858 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
870 int aligned_embed_dim,
871 int aligned_head_dim,
872 int aligned_intermediate_dim,
873 int aligned_context_window
875 const MODELLayerOffsets *L = &MODEL_LAYERS[4];
877 float *input = MODEL_PTR(model, MODEL_LAYERS[3].output);
878 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
879 float *ln1_out = MODEL_PTR(model, L->ln1_out);
880 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
881 float *ln2_out = MODEL_PTR(model, L->ln2_out);
882 float *q = MODEL_PTR(model, L->q);
883 float *k = MODEL_PTR(model, L->k);
884 float *v = MODEL_PTR(model, L->v);
885 float *attn_out = MODEL_PTR(model, L->attn_out);
886 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
887 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
888 float *residual1 = MODEL_PTR(model, L->residual1);
889 float *fc1_out = MODEL_PTR(model, L->fc1_out);
890 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
891 float *mlp_out = MODEL_PTR(model, L->mlp_out);
892 float *output = MODEL_PTR(model, L->output);
894 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
895 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
896 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
897 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
898 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
899 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
900 const float *BQ = NULL;
901 const float *BK = NULL;
902 const float *BV = NULL;
903 const float *BO = NULL;
904 const float *B1 = NULL;
905 const float *B2 = NULL;
907 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
908 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
913 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
914 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
915 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
929 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
930 for (
int h = 0; h < H; ++h) {
931 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
932 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
933 float *q_h = q + (size_t)h * q_head_stride;
934 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
939 const uint8_t *WK_bytes = (
const uint8_t *)WK;
940 for (
int h = 0; h < H_kv; ++h) {
941 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
942 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
943 float *k_h = k + (size_t)h * kv_head_stride;
944 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
949 const uint8_t *WV_bytes = (
const uint8_t *)WV;
950 for (
int h = 0; h < H_kv; ++h) {
951 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
952 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
953 float *v_h = v + (size_t)h * kv_head_stride;
954 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
969 aligned_context_window);
981 aligned_context_window);
984 const int K = H * aligned_head_dim;
985 if (K != aligned_embed_dim) {
988 const float *proj_in = attn_out;
993 for (
int t = 0; t < num_tokens; ++t) {
994 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
995 for (
int h = 0; h < H; ++h) {
996 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
997 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
999 (
size_t)aligned_head_dim *
sizeof(
float));
1002 proj_in = proj_scratch;
1004 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1020 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1021 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1022 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1034 int aligned_embed_dim,
1035 int aligned_head_dim,
1036 int aligned_intermediate_dim,
1037 int aligned_context_window
1039 const MODELLayerOffsets *L = &MODEL_LAYERS[5];
1041 float *input = MODEL_PTR(model, MODEL_LAYERS[4].output);
1042 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1043 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1044 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1045 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1046 float *q = MODEL_PTR(model, L->q);
1047 float *k = MODEL_PTR(model, L->k);
1048 float *v = MODEL_PTR(model, L->v);
1049 float *attn_out = MODEL_PTR(model, L->attn_out);
1050 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1051 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1052 float *residual1 = MODEL_PTR(model, L->residual1);
1053 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1054 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1055 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1056 float *output = MODEL_PTR(model, L->output);
1058 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1059 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1060 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1061 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1062 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1063 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
1064 const float *BQ = NULL;
1065 const float *BK = NULL;
1066 const float *BV = NULL;
1067 const float *BO = NULL;
1068 const float *B1 = NULL;
1069 const float *B2 = NULL;
1071 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1072 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
1077 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1078 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1079 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
1093 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1094 for (
int h = 0; h < H; ++h) {
1095 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1096 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1097 float *q_h = q + (size_t)h * q_head_stride;
1098 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1103 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1104 for (
int h = 0; h < H_kv; ++h) {
1105 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1106 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1107 float *k_h = k + (size_t)h * kv_head_stride;
1108 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1113 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1114 for (
int h = 0; h < H_kv; ++h) {
1115 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1116 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1117 float *v_h = v + (size_t)h * kv_head_stride;
1118 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1133 aligned_context_window);
1145 aligned_context_window);
1148 const int K = H * aligned_head_dim;
1149 if (K != aligned_embed_dim) {
1152 const float *proj_in = attn_out;
1154 if (!proj_scratch) {
1157 for (
int t = 0; t < num_tokens; ++t) {
1158 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1159 for (
int h = 0; h < H; ++h) {
1160 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1161 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1163 (
size_t)aligned_head_dim *
sizeof(
float));
1166 proj_in = proj_scratch;
1168 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1184 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1185 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1186 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1198 int aligned_embed_dim,
1199 int aligned_head_dim,
1200 int aligned_intermediate_dim,
1201 int aligned_context_window
1203 const MODELLayerOffsets *L = &MODEL_LAYERS[6];
1205 float *input = MODEL_PTR(model, MODEL_LAYERS[5].output);
1206 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1207 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1208 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1209 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1210 float *q = MODEL_PTR(model, L->q);
1211 float *k = MODEL_PTR(model, L->k);
1212 float *v = MODEL_PTR(model, L->v);
1213 float *attn_out = MODEL_PTR(model, L->attn_out);
1214 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1215 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1216 float *residual1 = MODEL_PTR(model, L->residual1);
1217 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1218 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1219 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1220 float *output = MODEL_PTR(model, L->output);
1222 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1223 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1224 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1225 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1226 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1227 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
1228 const float *BQ = NULL;
1229 const float *BK = NULL;
1230 const float *BV = NULL;
1231 const float *BO = NULL;
1232 const float *B1 = NULL;
1233 const float *B2 = NULL;
1235 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1236 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
1241 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1242 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1243 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
1257 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1258 for (
int h = 0; h < H; ++h) {
1259 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1260 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1261 float *q_h = q + (size_t)h * q_head_stride;
1262 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1267 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1268 for (
int h = 0; h < H_kv; ++h) {
1269 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1270 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1271 float *k_h = k + (size_t)h * kv_head_stride;
1272 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1277 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1278 for (
int h = 0; h < H_kv; ++h) {
1279 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1280 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1281 float *v_h = v + (size_t)h * kv_head_stride;
1282 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1297 aligned_context_window);
1309 aligned_context_window);
1312 const int K = H * aligned_head_dim;
1313 if (K != aligned_embed_dim) {
1316 const float *proj_in = attn_out;
1318 if (!proj_scratch) {
1321 for (
int t = 0; t < num_tokens; ++t) {
1322 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1323 for (
int h = 0; h < H; ++h) {
1324 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1325 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1327 (
size_t)aligned_head_dim *
sizeof(
float));
1330 proj_in = proj_scratch;
1332 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1348 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1349 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1350 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Visible portion of the codegen-unrolled forward pass for transformer
 * layer 7.  NOTE(review): this extraction is lossy -- the function header,
 * the layernorm/RoPE/attention calls, the residual additions and the
 * epilogue fall in gaps of the embedded original line numbering and are
 * not visible here.  Code bytes are untouched; comments only.
 */
1362 int aligned_embed_dim,
1363 int aligned_head_dim,
1364 int aligned_intermediate_dim,
1365 int aligned_context_window
1367 const MODELLayerOffsets *L = &MODEL_LAYERS[7];
/* Per-layer activation/scratch buffers resolved from the model arena; the
 * layer input is the previous layer's (layer 6) output buffer. */
1369 float *input = MODEL_PTR(model, MODEL_LAYERS[6].output);
1370 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1371 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1372 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1373 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1374 float *q = MODEL_PTR(model, L->q);
1375 float *k = MODEL_PTR(model, L->k);
1376 float *v = MODEL_PTR(model, L->v);
1377 float *attn_out = MODEL_PTR(model, L->attn_out);
1378 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1379 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1380 float *residual1 = MODEL_PTR(model, L->residual1);
1381 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1382 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1383 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1384 float *output = MODEL_PTR(model, L->output);
/* Layer weight blobs, consumed by gemm_nt_q4_k (q4_k-quantized -- confirm);
 * every bias pointer is hard-wired to NULL in this build. */
1386 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1387 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1388 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1389 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1390 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1391 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
1392 const float *BQ = NULL;
1393 const float *BK = NULL;
1394 const float *BV = NULL;
1395 const float *BO = NULL;
1396 const float *B1 = NULL;
1397 const float *B2 = NULL;
/* RoPE cos/sin tables are global (shared across layers), not per-layer. */
1399 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1400 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides: Q activations are head-major [head][token][head_dim]; K/V head
 * strides are sized by the full context window rather than num_tokens --
 * presumably a KV-cache layout (confirm). */
1405 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1406 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1407 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x head_dim] GEMM per head against ln1_out.
 * H and wq_head_bytes are defined on lines missing from this view. */
1421 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1422 for (
int h = 0; h < H; ++h) {
1423 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1424 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1425 float *q_h = q + (size_t)h * q_head_stride;
1426 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads; H_kv being independent of H suggests
 * grouped-query attention (confirm). */
1431 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1432 for (
int h = 0; h < H_kv; ++h) {
1433 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1434 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1435 float *k_h = k + (size_t)h * kv_head_stride;
1436 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
1441 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1442 for (
int h = 0; h < H_kv; ++h) {
1443 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1444 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1445 float *v_h = v + (size_t)h * kv_head_stride;
1446 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* Tails of two multi-line calls (RoPE/attention -- their leading lines are
 * missing from this extraction). */
1461 aligned_context_window);
1473 aligned_context_window);
/* Output projection: K is the concatenated head width.  When it differs
 * from the padded embedding width, repack head-major attn_out into a
 * token-major [token][H*head_dim] layout in proj_scratch before the WO GEMM. */
1476 const int K = H * aligned_head_dim;
1477 if (K != aligned_embed_dim) {
1480 const float *proj_in = attn_out;
1482 if (!proj_scratch) {
1485 for (
int t = 0; t < num_tokens; ++t) {
1486 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1487 for (
int h = 0; h < H; ++h) {
1488 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
/* NOTE(review): the memcpy's source argument (presumably `src`) sits on an
 * original line missing from this extraction. */
1489 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1491 (
size_t)aligned_head_dim *
sizeof(
float));
1494 proj_in = proj_scratch;
1496 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up projection (width 2*intermediate), SwiGLU activation,
 * then down-projection back to embed width.  Residual adds occur on lines
 * missing from this view. */
1512 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1513 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1514 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Visible portion of the codegen-unrolled forward pass for transformer
 * layer 8.  NOTE(review): this extraction is lossy -- the function header,
 * the layernorm/RoPE/attention calls, the residual additions and the
 * epilogue fall in gaps of the embedded original line numbering and are
 * not visible here.  Code bytes are untouched; comments only.
 */
1526 int aligned_embed_dim,
1527 int aligned_head_dim,
1528 int aligned_intermediate_dim,
1529 int aligned_context_window
1531 const MODELLayerOffsets *L = &MODEL_LAYERS[8];
/* Per-layer activation/scratch buffers resolved from the model arena; the
 * layer input is the previous layer's (layer 7) output buffer. */
1533 float *input = MODEL_PTR(model, MODEL_LAYERS[7].output);
1534 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1535 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1536 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1537 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1538 float *q = MODEL_PTR(model, L->q);
1539 float *k = MODEL_PTR(model, L->k);
1540 float *v = MODEL_PTR(model, L->v);
1541 float *attn_out = MODEL_PTR(model, L->attn_out);
1542 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1543 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1544 float *residual1 = MODEL_PTR(model, L->residual1);
1545 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1546 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1547 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1548 float *output = MODEL_PTR(model, L->output);
/* Layer weight blobs, consumed by gemm_nt_q4_k (q4_k-quantized -- confirm);
 * every bias pointer is hard-wired to NULL in this build. */
1550 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1551 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1552 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1553 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1554 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1555 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
1556 const float *BQ = NULL;
1557 const float *BK = NULL;
1558 const float *BV = NULL;
1559 const float *BO = NULL;
1560 const float *B1 = NULL;
1561 const float *B2 = NULL;
/* RoPE cos/sin tables are global (shared across layers), not per-layer. */
1563 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1564 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides: Q activations are head-major [head][token][head_dim]; K/V head
 * strides are sized by the full context window rather than num_tokens --
 * presumably a KV-cache layout (confirm). */
1569 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1570 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1571 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x head_dim] GEMM per head against ln1_out.
 * H and wq_head_bytes are defined on lines missing from this view. */
1585 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1586 for (
int h = 0; h < H; ++h) {
1587 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1588 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1589 float *q_h = q + (size_t)h * q_head_stride;
1590 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads; H_kv being independent of H suggests
 * grouped-query attention (confirm). */
1595 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1596 for (
int h = 0; h < H_kv; ++h) {
1597 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1598 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1599 float *k_h = k + (size_t)h * kv_head_stride;
1600 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
1605 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1606 for (
int h = 0; h < H_kv; ++h) {
1607 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1608 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1609 float *v_h = v + (size_t)h * kv_head_stride;
1610 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* Tails of two multi-line calls (RoPE/attention -- their leading lines are
 * missing from this extraction). */
1625 aligned_context_window);
1637 aligned_context_window);
/* Output projection: K is the concatenated head width.  When it differs
 * from the padded embedding width, repack head-major attn_out into a
 * token-major [token][H*head_dim] layout in proj_scratch before the WO GEMM. */
1640 const int K = H * aligned_head_dim;
1641 if (K != aligned_embed_dim) {
1644 const float *proj_in = attn_out;
1646 if (!proj_scratch) {
1649 for (
int t = 0; t < num_tokens; ++t) {
1650 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1651 for (
int h = 0; h < H; ++h) {
1652 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
/* NOTE(review): the memcpy's source argument (presumably `src`) sits on an
 * original line missing from this extraction. */
1653 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1655 (
size_t)aligned_head_dim *
sizeof(
float));
1658 proj_in = proj_scratch;
1660 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up projection (width 2*intermediate), SwiGLU activation,
 * then down-projection back to embed width.  Residual adds occur on lines
 * missing from this view. */
1676 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1677 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1678 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Visible portion of the codegen-unrolled forward pass for transformer
 * layer 9.  NOTE(review): this extraction is lossy -- the function header,
 * the layernorm/RoPE/attention calls, the residual additions and the
 * epilogue fall in gaps of the embedded original line numbering and are
 * not visible here.  Code bytes are untouched; comments only.
 */
1690 int aligned_embed_dim,
1691 int aligned_head_dim,
1692 int aligned_intermediate_dim,
1693 int aligned_context_window
1695 const MODELLayerOffsets *L = &MODEL_LAYERS[9];
/* Per-layer activation/scratch buffers resolved from the model arena; the
 * layer input is the previous layer's (layer 8) output buffer. */
1697 float *input = MODEL_PTR(model, MODEL_LAYERS[8].output);
1698 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1699 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1700 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1701 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1702 float *q = MODEL_PTR(model, L->q);
1703 float *k = MODEL_PTR(model, L->k);
1704 float *v = MODEL_PTR(model, L->v);
1705 float *attn_out = MODEL_PTR(model, L->attn_out);
1706 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1707 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1708 float *residual1 = MODEL_PTR(model, L->residual1);
1709 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1710 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1711 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1712 float *output = MODEL_PTR(model, L->output);
/* Layer weight blobs, consumed by gemm_nt_q4_k (q4_k-quantized -- confirm);
 * every bias pointer is hard-wired to NULL in this build. */
1714 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1715 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1716 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1717 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1718 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1719 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
1720 const float *BQ = NULL;
1721 const float *BK = NULL;
1722 const float *BV = NULL;
1723 const float *BO = NULL;
1724 const float *B1 = NULL;
1725 const float *B2 = NULL;
/* RoPE cos/sin tables are global (shared across layers), not per-layer. */
1727 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1728 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides: Q activations are head-major [head][token][head_dim]; K/V head
 * strides are sized by the full context window rather than num_tokens --
 * presumably a KV-cache layout (confirm). */
1733 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1734 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1735 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x head_dim] GEMM per head against ln1_out.
 * H and wq_head_bytes are defined on lines missing from this view. */
1749 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1750 for (
int h = 0; h < H; ++h) {
1751 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1752 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1753 float *q_h = q + (size_t)h * q_head_stride;
1754 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads; H_kv being independent of H suggests
 * grouped-query attention (confirm). */
1759 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1760 for (
int h = 0; h < H_kv; ++h) {
1761 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1762 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1763 float *k_h = k + (size_t)h * kv_head_stride;
1764 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
1769 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1770 for (
int h = 0; h < H_kv; ++h) {
1771 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1772 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1773 float *v_h = v + (size_t)h * kv_head_stride;
1774 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* Tails of two multi-line calls (RoPE/attention -- their leading lines are
 * missing from this extraction). */
1789 aligned_context_window);
1801 aligned_context_window);
/* Output projection: K is the concatenated head width.  When it differs
 * from the padded embedding width, repack head-major attn_out into a
 * token-major [token][H*head_dim] layout in proj_scratch before the WO GEMM. */
1804 const int K = H * aligned_head_dim;
1805 if (K != aligned_embed_dim) {
1808 const float *proj_in = attn_out;
1810 if (!proj_scratch) {
1813 for (
int t = 0; t < num_tokens; ++t) {
1814 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1815 for (
int h = 0; h < H; ++h) {
1816 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
/* NOTE(review): the memcpy's source argument (presumably `src`) sits on an
 * original line missing from this extraction. */
1817 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1819 (
size_t)aligned_head_dim *
sizeof(
float));
1822 proj_in = proj_scratch;
1824 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up projection (width 2*intermediate), SwiGLU activation,
 * then down-projection back to embed width.  Residual adds occur on lines
 * missing from this view. */
1840 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1841 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1842 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Visible portion of the codegen-unrolled forward pass for transformer
 * layer 10.  NOTE(review): this extraction is lossy -- the function header,
 * the layernorm/RoPE/attention calls, the residual additions and the
 * epilogue fall in gaps of the embedded original line numbering and are
 * not visible here.  Code bytes are untouched; comments only.
 */
1854 int aligned_embed_dim,
1855 int aligned_head_dim,
1856 int aligned_intermediate_dim,
1857 int aligned_context_window
1859 const MODELLayerOffsets *L = &MODEL_LAYERS[10];
/* Per-layer activation/scratch buffers resolved from the model arena; the
 * layer input is the previous layer's (layer 9) output buffer. */
1861 float *input = MODEL_PTR(model, MODEL_LAYERS[9].output);
1862 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1863 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1864 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1865 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1866 float *q = MODEL_PTR(model, L->q);
1867 float *k = MODEL_PTR(model, L->k);
1868 float *v = MODEL_PTR(model, L->v);
1869 float *attn_out = MODEL_PTR(model, L->attn_out);
1870 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1871 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1872 float *residual1 = MODEL_PTR(model, L->residual1);
1873 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1874 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1875 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1876 float *output = MODEL_PTR(model, L->output);
/* Layer weight blobs, consumed by gemm_nt_q4_k (q4_k-quantized -- confirm);
 * every bias pointer is hard-wired to NULL in this build. */
1878 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1879 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1880 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1881 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1882 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1883 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
1884 const float *BQ = NULL;
1885 const float *BK = NULL;
1886 const float *BV = NULL;
1887 const float *BO = NULL;
1888 const float *B1 = NULL;
1889 const float *B2 = NULL;
/* RoPE cos/sin tables are global (shared across layers), not per-layer. */
1891 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1892 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides: Q activations are head-major [head][token][head_dim]; K/V head
 * strides are sized by the full context window rather than num_tokens --
 * presumably a KV-cache layout (confirm). */
1897 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1898 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1899 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x head_dim] GEMM per head against ln1_out.
 * H and wq_head_bytes are defined on lines missing from this view. */
1913 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1914 for (
int h = 0; h < H; ++h) {
1915 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1916 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1917 float *q_h = q + (size_t)h * q_head_stride;
1918 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads; H_kv being independent of H suggests
 * grouped-query attention (confirm). */
1923 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1924 for (
int h = 0; h < H_kv; ++h) {
1925 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1926 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1927 float *k_h = k + (size_t)h * kv_head_stride;
1928 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
1933 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1934 for (
int h = 0; h < H_kv; ++h) {
1935 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1936 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1937 float *v_h = v + (size_t)h * kv_head_stride;
1938 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* Tails of two multi-line calls (RoPE/attention -- their leading lines are
 * missing from this extraction). */
1953 aligned_context_window);
1965 aligned_context_window);
/* Output projection: K is the concatenated head width.  When it differs
 * from the padded embedding width, repack head-major attn_out into a
 * token-major [token][H*head_dim] layout in proj_scratch before the WO GEMM. */
1968 const int K = H * aligned_head_dim;
1969 if (K != aligned_embed_dim) {
1972 const float *proj_in = attn_out;
1974 if (!proj_scratch) {
1977 for (
int t = 0; t < num_tokens; ++t) {
1978 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1979 for (
int h = 0; h < H; ++h) {
1980 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
/* NOTE(review): the memcpy's source argument (presumably `src`) sits on an
 * original line missing from this extraction. */
1981 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1983 (
size_t)aligned_head_dim *
sizeof(
float));
1986 proj_in = proj_scratch;
1988 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up projection (width 2*intermediate), SwiGLU activation,
 * then down-projection back to embed width.  Residual adds occur on lines
 * missing from this view. */
2004 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2005 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2006 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Visible portion of the codegen-unrolled forward pass for transformer
 * layer 11.  NOTE(review): this extraction is lossy -- the function header,
 * the layernorm/RoPE/attention calls, the residual additions and the
 * epilogue fall in gaps of the embedded original line numbering and are
 * not visible here.  Code bytes are untouched; comments only.
 */
2018 int aligned_embed_dim,
2019 int aligned_head_dim,
2020 int aligned_intermediate_dim,
2021 int aligned_context_window
2023 const MODELLayerOffsets *L = &MODEL_LAYERS[11];
/* Per-layer activation/scratch buffers resolved from the model arena; the
 * layer input is the previous layer's (layer 10) output buffer. */
2025 float *input = MODEL_PTR(model, MODEL_LAYERS[10].output);
2026 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2027 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2028 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2029 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2030 float *q = MODEL_PTR(model, L->q);
2031 float *k = MODEL_PTR(model, L->k);
2032 float *v = MODEL_PTR(model, L->v);
2033 float *attn_out = MODEL_PTR(model, L->attn_out);
2034 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2035 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2036 float *residual1 = MODEL_PTR(model, L->residual1);
2037 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2038 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2039 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2040 float *output = MODEL_PTR(model, L->output);
/* Layer weight blobs, consumed by gemm_nt_q4_k (q4_k-quantized -- confirm);
 * every bias pointer is hard-wired to NULL in this build. */
2042 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2043 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2044 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2045 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2046 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2047 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
2048 const float *BQ = NULL;
2049 const float *BK = NULL;
2050 const float *BV = NULL;
2051 const float *BO = NULL;
2052 const float *B1 = NULL;
2053 const float *B2 = NULL;
/* RoPE cos/sin tables are global (shared across layers), not per-layer. */
2055 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2056 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides: Q activations are head-major [head][token][head_dim]; K/V head
 * strides are sized by the full context window rather than num_tokens --
 * presumably a KV-cache layout (confirm). */
2061 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2062 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2063 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x head_dim] GEMM per head against ln1_out.
 * H and wq_head_bytes are defined on lines missing from this view. */
2077 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2078 for (
int h = 0; h < H; ++h) {
2079 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2080 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2081 float *q_h = q + (size_t)h * q_head_stride;
2082 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads; H_kv being independent of H suggests
 * grouped-query attention (confirm). */
2087 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2088 for (
int h = 0; h < H_kv; ++h) {
2089 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2090 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2091 float *k_h = k + (size_t)h * kv_head_stride;
2092 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
2097 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2098 for (
int h = 0; h < H_kv; ++h) {
2099 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2100 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2101 float *v_h = v + (size_t)h * kv_head_stride;
2102 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* Tails of two multi-line calls (RoPE/attention -- their leading lines are
 * missing from this extraction). */
2117 aligned_context_window);
2129 aligned_context_window);
/* Output projection: K is the concatenated head width.  When it differs
 * from the padded embedding width, repack head-major attn_out into a
 * token-major [token][H*head_dim] layout in proj_scratch before the WO GEMM. */
2132 const int K = H * aligned_head_dim;
2133 if (K != aligned_embed_dim) {
2136 const float *proj_in = attn_out;
2138 if (!proj_scratch) {
2141 for (
int t = 0; t < num_tokens; ++t) {
2142 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2143 for (
int h = 0; h < H; ++h) {
2144 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
/* NOTE(review): the memcpy's source argument (presumably `src`) sits on an
 * original line missing from this extraction. */
2145 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2147 (
size_t)aligned_head_dim *
sizeof(
float));
2150 proj_in = proj_scratch;
2152 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up projection (width 2*intermediate), SwiGLU activation,
 * then down-projection back to embed width.  Residual adds occur on lines
 * missing from this view. */
2168 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2169 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2170 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Visible portion of the codegen-unrolled forward pass for transformer
 * layer 12.  NOTE(review): this extraction is lossy -- the function header,
 * the layernorm/RoPE/attention calls, the residual additions and the
 * epilogue fall in gaps of the embedded original line numbering and are
 * not visible here.  Code bytes are untouched; comments only.
 */
2182 int aligned_embed_dim,
2183 int aligned_head_dim,
2184 int aligned_intermediate_dim,
2185 int aligned_context_window
2187 const MODELLayerOffsets *L = &MODEL_LAYERS[12];
/* Per-layer activation/scratch buffers resolved from the model arena; the
 * layer input is the previous layer's (layer 11) output buffer. */
2189 float *input = MODEL_PTR(model, MODEL_LAYERS[11].output);
2190 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2191 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2192 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2193 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2194 float *q = MODEL_PTR(model, L->q);
2195 float *k = MODEL_PTR(model, L->k);
2196 float *v = MODEL_PTR(model, L->v);
2197 float *attn_out = MODEL_PTR(model, L->attn_out);
2198 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2199 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2200 float *residual1 = MODEL_PTR(model, L->residual1);
2201 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2202 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2203 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2204 float *output = MODEL_PTR(model, L->output);
/* Layer weight blobs, consumed by gemm_nt_q4_k (q4_k-quantized -- confirm);
 * every bias pointer is hard-wired to NULL in this build. */
2206 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2207 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2208 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2209 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2210 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2211 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
2212 const float *BQ = NULL;
2213 const float *BK = NULL;
2214 const float *BV = NULL;
2215 const float *BO = NULL;
2216 const float *B1 = NULL;
2217 const float *B2 = NULL;
/* RoPE cos/sin tables are global (shared across layers), not per-layer. */
2219 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2220 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides: Q activations are head-major [head][token][head_dim]; K/V head
 * strides are sized by the full context window rather than num_tokens --
 * presumably a KV-cache layout (confirm). */
2225 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2226 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2227 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x head_dim] GEMM per head against ln1_out.
 * H and wq_head_bytes are defined on lines missing from this view. */
2241 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2242 for (
int h = 0; h < H; ++h) {
2243 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2244 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2245 float *q_h = q + (size_t)h * q_head_stride;
2246 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads; H_kv being independent of H suggests
 * grouped-query attention (confirm). */
2251 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2252 for (
int h = 0; h < H_kv; ++h) {
2253 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2254 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2255 float *k_h = k + (size_t)h * kv_head_stride;
2256 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
2261 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2262 for (
int h = 0; h < H_kv; ++h) {
2263 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2264 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2265 float *v_h = v + (size_t)h * kv_head_stride;
2266 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* Tails of two multi-line calls (RoPE/attention -- their leading lines are
 * missing from this extraction). */
2281 aligned_context_window);
2293 aligned_context_window);
/* Output projection: K is the concatenated head width.  When it differs
 * from the padded embedding width, repack head-major attn_out into a
 * token-major [token][H*head_dim] layout in proj_scratch before the WO GEMM. */
2296 const int K = H * aligned_head_dim;
2297 if (K != aligned_embed_dim) {
2300 const float *proj_in = attn_out;
2302 if (!proj_scratch) {
2305 for (
int t = 0; t < num_tokens; ++t) {
2306 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2307 for (
int h = 0; h < H; ++h) {
2308 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
/* NOTE(review): the memcpy's source argument (presumably `src`) sits on an
 * original line missing from this extraction. */
2309 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2311 (
size_t)aligned_head_dim *
sizeof(
float));
2314 proj_in = proj_scratch;
2316 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up projection (width 2*intermediate), SwiGLU activation,
 * then down-projection back to embed width.  Residual adds occur on lines
 * missing from this view. */
2332 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2333 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2334 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Visible portion of the codegen-unrolled forward pass for transformer
 * layer 13.  NOTE(review): this extraction is lossy -- the function header,
 * the layernorm/RoPE/attention calls, the residual additions and the
 * epilogue fall in gaps of the embedded original line numbering and are
 * not visible here.  Code bytes are untouched; comments only.
 */
2346 int aligned_embed_dim,
2347 int aligned_head_dim,
2348 int aligned_intermediate_dim,
2349 int aligned_context_window
2351 const MODELLayerOffsets *L = &MODEL_LAYERS[13];
/* Per-layer activation/scratch buffers resolved from the model arena; the
 * layer input is the previous layer's (layer 12) output buffer. */
2353 float *input = MODEL_PTR(model, MODEL_LAYERS[12].output);
2354 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2355 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2356 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2357 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2358 float *q = MODEL_PTR(model, L->q);
2359 float *k = MODEL_PTR(model, L->k);
2360 float *v = MODEL_PTR(model, L->v);
2361 float *attn_out = MODEL_PTR(model, L->attn_out);
2362 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2363 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2364 float *residual1 = MODEL_PTR(model, L->residual1);
2365 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2366 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2367 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2368 float *output = MODEL_PTR(model, L->output);
/* Layer weight blobs, consumed by gemm_nt_q4_k (q4_k-quantized -- confirm);
 * every bias pointer is hard-wired to NULL in this build. */
2370 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2371 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2372 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2373 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2374 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2375 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
2376 const float *BQ = NULL;
2377 const float *BK = NULL;
2378 const float *BV = NULL;
2379 const float *BO = NULL;
2380 const float *B1 = NULL;
2381 const float *B2 = NULL;
/* RoPE cos/sin tables are global (shared across layers), not per-layer. */
2383 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2384 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides: Q activations are head-major [head][token][head_dim]; K/V head
 * strides are sized by the full context window rather than num_tokens --
 * presumably a KV-cache layout (confirm). */
2389 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2390 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2391 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x head_dim] GEMM per head against ln1_out.
 * H and wq_head_bytes are defined on lines missing from this view. */
2405 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2406 for (
int h = 0; h < H; ++h) {
2407 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2408 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2409 float *q_h = q + (size_t)h * q_head_stride;
2410 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads; H_kv being independent of H suggests
 * grouped-query attention (confirm). */
2415 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2416 for (
int h = 0; h < H_kv; ++h) {
2417 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2418 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2419 float *k_h = k + (size_t)h * kv_head_stride;
2420 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
2425 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2426 for (
int h = 0; h < H_kv; ++h) {
2427 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2428 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2429 float *v_h = v + (size_t)h * kv_head_stride;
2430 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* Tails of two multi-line calls (RoPE/attention -- their leading lines are
 * missing from this extraction). */
2445 aligned_context_window);
2457 aligned_context_window);
/* Output projection: K is the concatenated head width.  When it differs
 * from the padded embedding width, repack head-major attn_out into a
 * token-major [token][H*head_dim] layout in proj_scratch before the WO GEMM. */
2460 const int K = H * aligned_head_dim;
2461 if (K != aligned_embed_dim) {
2464 const float *proj_in = attn_out;
2466 if (!proj_scratch) {
2469 for (
int t = 0; t < num_tokens; ++t) {
2470 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2471 for (
int h = 0; h < H; ++h) {
2472 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
/* NOTE(review): the memcpy's source argument (presumably `src`) sits on an
 * original line missing from this extraction. */
2473 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2475 (
size_t)aligned_head_dim *
sizeof(
float));
2478 proj_in = proj_scratch;
2480 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up projection (width 2*intermediate), SwiGLU activation,
 * then down-projection back to embed width.  Residual adds occur on lines
 * missing from this view. */
2496 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2497 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2498 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Visible portion of the codegen-unrolled forward pass for transformer
 * layer 14.  NOTE(review): this extraction is lossy -- the function header,
 * the layernorm/RoPE/attention calls, the residual additions and the
 * epilogue fall in gaps of the embedded original line numbering and are
 * not visible here.  Code bytes are untouched; comments only.
 */
2510 int aligned_embed_dim,
2511 int aligned_head_dim,
2512 int aligned_intermediate_dim,
2513 int aligned_context_window
2515 const MODELLayerOffsets *L = &MODEL_LAYERS[14];
/* Per-layer activation/scratch buffers resolved from the model arena; the
 * layer input is the previous layer's (layer 13) output buffer. */
2517 float *input = MODEL_PTR(model, MODEL_LAYERS[13].output);
2518 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2519 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2520 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2521 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2522 float *q = MODEL_PTR(model, L->q);
2523 float *k = MODEL_PTR(model, L->k);
2524 float *v = MODEL_PTR(model, L->v);
2525 float *attn_out = MODEL_PTR(model, L->attn_out);
2526 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2527 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2528 float *residual1 = MODEL_PTR(model, L->residual1);
2529 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2530 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2531 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2532 float *output = MODEL_PTR(model, L->output);
/* Layer weight blobs, consumed by gemm_nt_q4_k (q4_k-quantized -- confirm);
 * every bias pointer is hard-wired to NULL in this build. */
2534 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2535 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2536 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2537 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2538 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2539 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
2540 const float *BQ = NULL;
2541 const float *BK = NULL;
2542 const float *BV = NULL;
2543 const float *BO = NULL;
2544 const float *B1 = NULL;
2545 const float *B2 = NULL;
/* RoPE cos/sin tables are global (shared across layers), not per-layer. */
2547 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2548 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides: Q activations are head-major [head][token][head_dim]; K/V head
 * strides are sized by the full context window rather than num_tokens --
 * presumably a KV-cache layout (confirm). */
2553 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2554 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2555 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x head_dim] GEMM per head against ln1_out.
 * H and wq_head_bytes are defined on lines missing from this view. */
2569 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2570 for (
int h = 0; h < H; ++h) {
2571 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2572 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2573 float *q_h = q + (size_t)h * q_head_stride;
2574 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads; H_kv being independent of H suggests
 * grouped-query attention (confirm). */
2579 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2580 for (
int h = 0; h < H_kv; ++h) {
2581 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2582 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2583 float *k_h = k + (size_t)h * kv_head_stride;
2584 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
2589 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2590 for (
int h = 0; h < H_kv; ++h) {
2591 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2592 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2593 float *v_h = v + (size_t)h * kv_head_stride;
2594 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* Tails of two multi-line calls (RoPE/attention -- their leading lines are
 * missing from this extraction). */
2609 aligned_context_window);
2621 aligned_context_window);
/* Output projection: K is the concatenated head width.  When it differs
 * from the padded embedding width, repack head-major attn_out into a
 * token-major [token][H*head_dim] layout in proj_scratch before the WO GEMM. */
2624 const int K = H * aligned_head_dim;
2625 if (K != aligned_embed_dim) {
2628 const float *proj_in = attn_out;
2630 if (!proj_scratch) {
2633 for (
int t = 0; t < num_tokens; ++t) {
2634 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2635 for (
int h = 0; h < H; ++h) {
2636 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
/* NOTE(review): the memcpy's source argument (presumably `src`) sits on an
 * original line missing from this extraction. */
2637 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2639 (
size_t)aligned_head_dim *
sizeof(
float));
2642 proj_in = proj_scratch;
2644 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up projection (width 2*intermediate), SwiGLU activation,
 * then down-projection back to embed width.  Residual adds occur on lines
 * missing from this view. */
2660 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2661 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2662 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 15 forward pass (machine-generated; one near-identical
 * copy exists per layer).  Visible work: resolve this layer's activation and
 * weight pointers from MODEL_LAYERS[15], run per-head Q/K/V projections as
 * q4_k-quantized GEMMs over ln1_out, optionally repack the head-major
 * attention output into token-major scratch for the output projection, then
 * run the SwiGLU MLP (fused gate+up GEMM -> swiglu_forward -> down GEMM).
 * NOTE(review): extraction artifact -- statements are hard-wrapped and
 * prefixed with original source line numbers, and several lines (function
 * name/signature start, H / H_kv / *_head_bytes definitions, LayerNorm,
 * RoPE and attention calls, residual adds, closing braces) are outside this
 * chunk; confirm against the generator output before editing.
 */
2674 int aligned_embed_dim,
2675 int aligned_head_dim,
2676 int aligned_intermediate_dim,
2677 int aligned_context_window
/* Per-layer offset table; input is the previous layer's (layer 14) output. */
2679 const MODELLayerOffsets *L = &MODEL_LAYERS[15];
2681 float *input = MODEL_PTR(model, MODEL_LAYERS[14].output);
2682 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2683 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2684 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2685 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2686 float *q = MODEL_PTR(model, L->q);
2687 float *k = MODEL_PTR(model, L->k);
2688 float *v = MODEL_PTR(model, L->v);
2689 float *attn_out = MODEL_PTR(model, L->attn_out);
2690 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2691 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2692 float *residual1 = MODEL_PTR(model, L->residual1);
2693 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2694 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2695 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2696 float *output = MODEL_PTR(model, L->output);
/* Opaque q4_k-quantized weight blobs; every bias pointer below is NULL. */
2698 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2699 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2700 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2701 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2702 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2703 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
2704 const float *BQ = NULL;
2705 const float *BK = NULL;
2706 const float *BV = NULL;
2707 const float *BO = NULL;
2708 const float *B1 = NULL;
2709 const float *B2 = NULL;
/* RoPE cos/sin caches (consumed by attention code not visible in this chunk). */
2711 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2712 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Row strides: q holds num_tokens rows per head; k/v are sized by
   aligned_context_window per head -- presumably a KV-cache layout (confirm). */
2717 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2718 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2719 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one quantized GEMM per head over H heads (H and
   wq_head_bytes are defined on lines outside this view). */
2733 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2734 for (
int h = 0; h < H; ++h) {
2735 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2736 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2737 float *q_h = q + (size_t)h * q_head_stride;
2738 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads -- H_kv may differ from H
   (grouped-query attention? confirm). */
2743 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2744 for (
int h = 0; h < H_kv; ++h) {
2745 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2746 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2747 float *k_h = k + (size_t)h * kv_head_stride;
2748 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
2753 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2754 for (
int h = 0; h < H_kv; ++h) {
2755 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2756 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2757 float *v_h = v + (size_t)h * kv_head_stride;
2758 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2773 aligned_context_window);
2785 aligned_context_window);
/* Output projection: when the packed head width H*aligned_head_dim differs
   from aligned_embed_dim, repack each head's per-token slice from head-major
   attn_out into token-major proj_scratch before the WO GEMM. */
2788 const int K = H * aligned_head_dim;
2789 if (K != aligned_embed_dim) {
2792 const float *proj_in = attn_out;
2794 if (!proj_scratch) {
2797 for (
int t = 0; t < num_tokens; ++t) {
2798 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2799 for (
int h = 0; h < H; ++h) {
2800 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2801 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2803 (
size_t)aligned_head_dim *
sizeof(
float));
2806 proj_in = proj_scratch;
2808 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up projection (width 2*aligned_intermediate_dim),
   SwiGLU activation, then down projection. */
2824 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2825 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2826 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 16 forward pass (machine-generated; same structure is
 * repeated once per layer).  Visible work: resolve layer pointers from
 * MODEL_LAYERS[16], per-head Q/K/V q4_k GEMMs over ln1_out, optional
 * head-major -> token-major repack, output projection, SwiGLU MLP.
 * NOTE(review): extraction artifact -- hard-wrapped statements with fused
 * original line numbers; signature start, H/H_kv/*_head_bytes, LN/RoPE/
 * attention calls, residual adds and braces are outside this chunk.
 */
2838 int aligned_embed_dim,
2839 int aligned_head_dim,
2840 int aligned_intermediate_dim,
2841 int aligned_context_window
/* Per-layer offset table; input is layer 15's output. */
2843 const MODELLayerOffsets *L = &MODEL_LAYERS[16];
2845 float *input = MODEL_PTR(model, MODEL_LAYERS[15].output);
2846 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2847 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2848 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2849 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2850 float *q = MODEL_PTR(model, L->q);
2851 float *k = MODEL_PTR(model, L->k);
2852 float *v = MODEL_PTR(model, L->v);
2853 float *attn_out = MODEL_PTR(model, L->attn_out);
2854 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2855 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2856 float *residual1 = MODEL_PTR(model, L->residual1);
2857 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2858 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2859 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2860 float *output = MODEL_PTR(model, L->output);
/* Opaque q4_k-quantized weight blobs; every bias pointer below is NULL. */
2862 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2863 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2864 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2865 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2866 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2867 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
2868 const float *BQ = NULL;
2869 const float *BK = NULL;
2870 const float *BV = NULL;
2871 const float *BO = NULL;
2872 const float *B1 = NULL;
2873 const float *B2 = NULL;
/* RoPE cos/sin caches (consumed by attention code not visible here). */
2875 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2876 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Row strides: q is num_tokens rows per head; k/v are sized by
   aligned_context_window per head -- presumably KV-cache layout (confirm). */
2881 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2882 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2883 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one quantized GEMM per head over H heads. */
2897 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2898 for (
int h = 0; h < H; ++h) {
2899 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2900 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2901 float *q_h = q + (size_t)h * q_head_stride;
2902 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (may differ from H -- confirm GQA). */
2907 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2908 for (
int h = 0; h < H_kv; ++h) {
2909 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2910 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2911 float *k_h = k + (size_t)h * kv_head_stride;
2912 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
2917 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2918 for (
int h = 0; h < H_kv; ++h) {
2919 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2920 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2921 float *v_h = v + (size_t)h * kv_head_stride;
2922 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
2937 aligned_context_window);
2949 aligned_context_window);
/* Output projection: repack head-major attn_out into token-major
   proj_scratch when H*aligned_head_dim != aligned_embed_dim. */
2952 const int K = H * aligned_head_dim;
2953 if (K != aligned_embed_dim) {
2956 const float *proj_in = attn_out;
2958 if (!proj_scratch) {
2961 for (
int t = 0; t < num_tokens; ++t) {
2962 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2963 for (
int h = 0; h < H; ++h) {
2964 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2965 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2967 (
size_t)aligned_head_dim *
sizeof(
float));
2970 proj_in = proj_scratch;
2972 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up GEMM (2*aligned_intermediate_dim), SwiGLU, down GEMM. */
2988 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2989 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2990 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 17 forward pass (machine-generated; same structure is
 * repeated once per layer).  Visible work: resolve layer pointers from
 * MODEL_LAYERS[17], per-head Q/K/V q4_k GEMMs over ln1_out, optional
 * head-major -> token-major repack, output projection, SwiGLU MLP.
 * NOTE(review): extraction artifact -- hard-wrapped statements with fused
 * original line numbers; signature start, H/H_kv/*_head_bytes, LN/RoPE/
 * attention calls, residual adds and braces are outside this chunk.
 */
3002 int aligned_embed_dim,
3003 int aligned_head_dim,
3004 int aligned_intermediate_dim,
3005 int aligned_context_window
/* Per-layer offset table; input is layer 16's output. */
3007 const MODELLayerOffsets *L = &MODEL_LAYERS[17];
3009 float *input = MODEL_PTR(model, MODEL_LAYERS[16].output);
3010 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3011 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3012 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3013 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3014 float *q = MODEL_PTR(model, L->q);
3015 float *k = MODEL_PTR(model, L->k);
3016 float *v = MODEL_PTR(model, L->v);
3017 float *attn_out = MODEL_PTR(model, L->attn_out);
3018 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3019 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3020 float *residual1 = MODEL_PTR(model, L->residual1);
3021 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3022 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3023 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3024 float *output = MODEL_PTR(model, L->output);
/* Opaque q4_k-quantized weight blobs; every bias pointer below is NULL. */
3026 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3027 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3028 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3029 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3030 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3031 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
3032 const float *BQ = NULL;
3033 const float *BK = NULL;
3034 const float *BV = NULL;
3035 const float *BO = NULL;
3036 const float *B1 = NULL;
3037 const float *B2 = NULL;
/* RoPE cos/sin caches (consumed by attention code not visible here). */
3039 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3040 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Row strides: q is num_tokens rows per head; k/v are sized by
   aligned_context_window per head -- presumably KV-cache layout (confirm). */
3045 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3046 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3047 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one quantized GEMM per head over H heads. */
3061 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3062 for (
int h = 0; h < H; ++h) {
3063 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3064 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3065 float *q_h = q + (size_t)h * q_head_stride;
3066 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (may differ from H -- confirm GQA). */
3071 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3072 for (
int h = 0; h < H_kv; ++h) {
3073 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3074 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3075 float *k_h = k + (size_t)h * kv_head_stride;
3076 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
3081 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3082 for (
int h = 0; h < H_kv; ++h) {
3083 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3084 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3085 float *v_h = v + (size_t)h * kv_head_stride;
3086 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3101 aligned_context_window);
3113 aligned_context_window);
/* Output projection: repack head-major attn_out into token-major
   proj_scratch when H*aligned_head_dim != aligned_embed_dim. */
3116 const int K = H * aligned_head_dim;
3117 if (K != aligned_embed_dim) {
3120 const float *proj_in = attn_out;
3122 if (!proj_scratch) {
3125 for (
int t = 0; t < num_tokens; ++t) {
3126 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3127 for (
int h = 0; h < H; ++h) {
3128 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3129 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3131 (
size_t)aligned_head_dim *
sizeof(
float));
3134 proj_in = proj_scratch;
3136 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up GEMM (2*aligned_intermediate_dim), SwiGLU, down GEMM. */
3152 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3153 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3154 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 18 forward pass (machine-generated; same structure is
 * repeated once per layer).  Visible work: resolve layer pointers from
 * MODEL_LAYERS[18], per-head Q/K/V q4_k GEMMs over ln1_out, optional
 * head-major -> token-major repack, output projection, SwiGLU MLP.
 * NOTE(review): extraction artifact -- hard-wrapped statements with fused
 * original line numbers; signature start, H/H_kv/*_head_bytes, LN/RoPE/
 * attention calls, residual adds and braces are outside this chunk.
 */
3166 int aligned_embed_dim,
3167 int aligned_head_dim,
3168 int aligned_intermediate_dim,
3169 int aligned_context_window
/* Per-layer offset table; input is layer 17's output. */
3171 const MODELLayerOffsets *L = &MODEL_LAYERS[18];
3173 float *input = MODEL_PTR(model, MODEL_LAYERS[17].output);
3174 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3175 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3176 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3177 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3178 float *q = MODEL_PTR(model, L->q);
3179 float *k = MODEL_PTR(model, L->k);
3180 float *v = MODEL_PTR(model, L->v);
3181 float *attn_out = MODEL_PTR(model, L->attn_out);
3182 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3183 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3184 float *residual1 = MODEL_PTR(model, L->residual1);
3185 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3186 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3187 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3188 float *output = MODEL_PTR(model, L->output);
/* Opaque q4_k-quantized weight blobs; every bias pointer below is NULL. */
3190 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3191 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3192 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3193 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3194 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3195 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
3196 const float *BQ = NULL;
3197 const float *BK = NULL;
3198 const float *BV = NULL;
3199 const float *BO = NULL;
3200 const float *B1 = NULL;
3201 const float *B2 = NULL;
/* RoPE cos/sin caches (consumed by attention code not visible here). */
3203 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3204 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Row strides: q is num_tokens rows per head; k/v are sized by
   aligned_context_window per head -- presumably KV-cache layout (confirm). */
3209 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3210 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3211 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one quantized GEMM per head over H heads. */
3225 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3226 for (
int h = 0; h < H; ++h) {
3227 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3228 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3229 float *q_h = q + (size_t)h * q_head_stride;
3230 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (may differ from H -- confirm GQA). */
3235 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3236 for (
int h = 0; h < H_kv; ++h) {
3237 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3238 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3239 float *k_h = k + (size_t)h * kv_head_stride;
3240 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
3245 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3246 for (
int h = 0; h < H_kv; ++h) {
3247 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3248 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3249 float *v_h = v + (size_t)h * kv_head_stride;
3250 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3265 aligned_context_window);
3277 aligned_context_window);
/* Output projection: repack head-major attn_out into token-major
   proj_scratch when H*aligned_head_dim != aligned_embed_dim. */
3280 const int K = H * aligned_head_dim;
3281 if (K != aligned_embed_dim) {
3284 const float *proj_in = attn_out;
3286 if (!proj_scratch) {
3289 for (
int t = 0; t < num_tokens; ++t) {
3290 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3291 for (
int h = 0; h < H; ++h) {
3292 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3293 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3295 (
size_t)aligned_head_dim *
sizeof(
float));
3298 proj_in = proj_scratch;
3300 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up GEMM (2*aligned_intermediate_dim), SwiGLU, down GEMM. */
3316 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3317 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3318 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 19 forward pass (machine-generated; same structure is
 * repeated once per layer).  Visible work: resolve layer pointers from
 * MODEL_LAYERS[19], per-head Q/K/V q4_k GEMMs over ln1_out, optional
 * head-major -> token-major repack, output projection, SwiGLU MLP.
 * NOTE(review): extraction artifact -- hard-wrapped statements with fused
 * original line numbers; signature start, H/H_kv/*_head_bytes, LN/RoPE/
 * attention calls, residual adds and braces are outside this chunk.
 */
3330 int aligned_embed_dim,
3331 int aligned_head_dim,
3332 int aligned_intermediate_dim,
3333 int aligned_context_window
/* Per-layer offset table; input is layer 18's output. */
3335 const MODELLayerOffsets *L = &MODEL_LAYERS[19];
3337 float *input = MODEL_PTR(model, MODEL_LAYERS[18].output);
3338 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3339 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3340 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3341 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3342 float *q = MODEL_PTR(model, L->q);
3343 float *k = MODEL_PTR(model, L->k);
3344 float *v = MODEL_PTR(model, L->v);
3345 float *attn_out = MODEL_PTR(model, L->attn_out);
3346 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3347 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3348 float *residual1 = MODEL_PTR(model, L->residual1);
3349 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3350 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3351 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3352 float *output = MODEL_PTR(model, L->output);
/* Opaque q4_k-quantized weight blobs; every bias pointer below is NULL. */
3354 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3355 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3356 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3357 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3358 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3359 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
3360 const float *BQ = NULL;
3361 const float *BK = NULL;
3362 const float *BV = NULL;
3363 const float *BO = NULL;
3364 const float *B1 = NULL;
3365 const float *B2 = NULL;
/* RoPE cos/sin caches (consumed by attention code not visible here). */
3367 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3368 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Row strides: q is num_tokens rows per head; k/v are sized by
   aligned_context_window per head -- presumably KV-cache layout (confirm). */
3373 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3374 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3375 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one quantized GEMM per head over H heads. */
3389 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3390 for (
int h = 0; h < H; ++h) {
3391 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3392 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3393 float *q_h = q + (size_t)h * q_head_stride;
3394 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (may differ from H -- confirm GQA). */
3399 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3400 for (
int h = 0; h < H_kv; ++h) {
3401 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3402 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3403 float *k_h = k + (size_t)h * kv_head_stride;
3404 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
3409 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3410 for (
int h = 0; h < H_kv; ++h) {
3411 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3412 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3413 float *v_h = v + (size_t)h * kv_head_stride;
3414 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3429 aligned_context_window);
3441 aligned_context_window);
/* Output projection: repack head-major attn_out into token-major
   proj_scratch when H*aligned_head_dim != aligned_embed_dim. */
3444 const int K = H * aligned_head_dim;
3445 if (K != aligned_embed_dim) {
3448 const float *proj_in = attn_out;
3450 if (!proj_scratch) {
3453 for (
int t = 0; t < num_tokens; ++t) {
3454 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3455 for (
int h = 0; h < H; ++h) {
3456 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3457 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3459 (
size_t)aligned_head_dim *
sizeof(
float));
3462 proj_in = proj_scratch;
3464 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up GEMM (2*aligned_intermediate_dim), SwiGLU, down GEMM. */
3480 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3481 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3482 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 20 forward pass (machine-generated; same structure is
 * repeated once per layer).  Visible work: resolve layer pointers from
 * MODEL_LAYERS[20], per-head Q/K/V q4_k GEMMs over ln1_out, optional
 * head-major -> token-major repack, output projection, SwiGLU MLP.
 * NOTE(review): extraction artifact -- hard-wrapped statements with fused
 * original line numbers; signature start, H/H_kv/*_head_bytes, LN/RoPE/
 * attention calls, residual adds and braces are outside this chunk.
 */
3494 int aligned_embed_dim,
3495 int aligned_head_dim,
3496 int aligned_intermediate_dim,
3497 int aligned_context_window
/* Per-layer offset table; input is layer 19's output. */
3499 const MODELLayerOffsets *L = &MODEL_LAYERS[20];
3501 float *input = MODEL_PTR(model, MODEL_LAYERS[19].output);
3502 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3503 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3504 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3505 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3506 float *q = MODEL_PTR(model, L->q);
3507 float *k = MODEL_PTR(model, L->k);
3508 float *v = MODEL_PTR(model, L->v);
3509 float *attn_out = MODEL_PTR(model, L->attn_out);
3510 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3511 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3512 float *residual1 = MODEL_PTR(model, L->residual1);
3513 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3514 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3515 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3516 float *output = MODEL_PTR(model, L->output);
/* Opaque q4_k-quantized weight blobs; every bias pointer below is NULL. */
3518 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3519 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3520 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3521 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3522 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3523 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
3524 const float *BQ = NULL;
3525 const float *BK = NULL;
3526 const float *BV = NULL;
3527 const float *BO = NULL;
3528 const float *B1 = NULL;
3529 const float *B2 = NULL;
/* RoPE cos/sin caches (consumed by attention code not visible here). */
3531 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3532 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Row strides: q is num_tokens rows per head; k/v are sized by
   aligned_context_window per head -- presumably KV-cache layout (confirm). */
3537 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3538 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3539 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one quantized GEMM per head over H heads. */
3553 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3554 for (
int h = 0; h < H; ++h) {
3555 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3556 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3557 float *q_h = q + (size_t)h * q_head_stride;
3558 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (may differ from H -- confirm GQA). */
3563 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3564 for (
int h = 0; h < H_kv; ++h) {
3565 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3566 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3567 float *k_h = k + (size_t)h * kv_head_stride;
3568 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
3573 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3574 for (
int h = 0; h < H_kv; ++h) {
3575 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3576 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3577 float *v_h = v + (size_t)h * kv_head_stride;
3578 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3593 aligned_context_window);
3605 aligned_context_window);
/* Output projection: repack head-major attn_out into token-major
   proj_scratch when H*aligned_head_dim != aligned_embed_dim. */
3608 const int K = H * aligned_head_dim;
3609 if (K != aligned_embed_dim) {
3612 const float *proj_in = attn_out;
3614 if (!proj_scratch) {
3617 for (
int t = 0; t < num_tokens; ++t) {
3618 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3619 for (
int h = 0; h < H; ++h) {
3620 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3621 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3623 (
size_t)aligned_head_dim *
sizeof(
float));
3626 proj_in = proj_scratch;
3628 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up GEMM (2*aligned_intermediate_dim), SwiGLU, down GEMM. */
3644 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3645 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3646 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 21 forward pass (machine-generated; same structure is
 * repeated once per layer).  Visible work: resolve layer pointers from
 * MODEL_LAYERS[21], per-head Q/K/V q4_k GEMMs over ln1_out, optional
 * head-major -> token-major repack, output projection, SwiGLU MLP.
 * NOTE(review): extraction artifact -- hard-wrapped statements with fused
 * original line numbers; signature start, H/H_kv/*_head_bytes, LN/RoPE/
 * attention calls, residual adds and braces are outside this chunk.
 */
3658 int aligned_embed_dim,
3659 int aligned_head_dim,
3660 int aligned_intermediate_dim,
3661 int aligned_context_window
/* Per-layer offset table; input is layer 20's output. */
3663 const MODELLayerOffsets *L = &MODEL_LAYERS[21];
3665 float *input = MODEL_PTR(model, MODEL_LAYERS[20].output);
3666 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3667 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3668 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3669 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3670 float *q = MODEL_PTR(model, L->q);
3671 float *k = MODEL_PTR(model, L->k);
3672 float *v = MODEL_PTR(model, L->v);
3673 float *attn_out = MODEL_PTR(model, L->attn_out);
3674 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3675 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3676 float *residual1 = MODEL_PTR(model, L->residual1);
3677 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3678 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3679 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3680 float *output = MODEL_PTR(model, L->output);
/* Opaque q4_k-quantized weight blobs; every bias pointer below is NULL. */
3682 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3683 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3684 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3685 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3686 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3687 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
3688 const float *BQ = NULL;
3689 const float *BK = NULL;
3690 const float *BV = NULL;
3691 const float *BO = NULL;
3692 const float *B1 = NULL;
3693 const float *B2 = NULL;
/* RoPE cos/sin caches (consumed by attention code not visible here). */
3695 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3696 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Row strides: q is num_tokens rows per head; k/v are sized by
   aligned_context_window per head -- presumably KV-cache layout (confirm). */
3701 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3702 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3703 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one quantized GEMM per head over H heads. */
3717 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3718 for (
int h = 0; h < H; ++h) {
3719 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3720 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3721 float *q_h = q + (size_t)h * q_head_stride;
3722 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (may differ from H -- confirm GQA). */
3727 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3728 for (
int h = 0; h < H_kv; ++h) {
3729 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3730 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3731 float *k_h = k + (size_t)h * kv_head_stride;
3732 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
3737 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3738 for (
int h = 0; h < H_kv; ++h) {
3739 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3740 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3741 float *v_h = v + (size_t)h * kv_head_stride;
3742 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3757 aligned_context_window);
3769 aligned_context_window);
/* Output projection: repack head-major attn_out into token-major
   proj_scratch when H*aligned_head_dim != aligned_embed_dim. */
3772 const int K = H * aligned_head_dim;
3773 if (K != aligned_embed_dim) {
3776 const float *proj_in = attn_out;
3778 if (!proj_scratch) {
3781 for (
int t = 0; t < num_tokens; ++t) {
3782 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3783 for (
int h = 0; h < H; ++h) {
3784 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3785 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3787 (
size_t)aligned_head_dim *
sizeof(
float));
3790 proj_in = proj_scratch;
3792 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up GEMM (2*aligned_intermediate_dim), SwiGLU, down GEMM. */
3808 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3809 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3810 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 22 forward pass (machine-generated; same structure is
 * repeated once per layer).  Visible work: resolve layer pointers from
 * MODEL_LAYERS[22], per-head Q/K/V q4_k GEMMs over ln1_out, optional
 * head-major -> token-major repack, output projection, SwiGLU MLP.
 * NOTE(review): extraction artifact -- hard-wrapped statements with fused
 * original line numbers; signature start, H/H_kv/*_head_bytes, LN/RoPE/
 * attention calls, residual adds and braces are outside this chunk.
 */
3822 int aligned_embed_dim,
3823 int aligned_head_dim,
3824 int aligned_intermediate_dim,
3825 int aligned_context_window
/* Per-layer offset table; input is layer 21's output. */
3827 const MODELLayerOffsets *L = &MODEL_LAYERS[22];
3829 float *input = MODEL_PTR(model, MODEL_LAYERS[21].output);
3830 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3831 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3832 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3833 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3834 float *q = MODEL_PTR(model, L->q);
3835 float *k = MODEL_PTR(model, L->k);
3836 float *v = MODEL_PTR(model, L->v);
3837 float *attn_out = MODEL_PTR(model, L->attn_out);
3838 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3839 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3840 float *residual1 = MODEL_PTR(model, L->residual1);
3841 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3842 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3843 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3844 float *output = MODEL_PTR(model, L->output);
/* Opaque q4_k-quantized weight blobs; every bias pointer below is NULL. */
3846 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3847 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3848 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3849 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3850 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3851 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
3852 const float *BQ = NULL;
3853 const float *BK = NULL;
3854 const float *BV = NULL;
3855 const float *BO = NULL;
3856 const float *B1 = NULL;
3857 const float *B2 = NULL;
/* RoPE cos/sin caches (consumed by attention code not visible here). */
3859 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3860 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Row strides: q is num_tokens rows per head; k/v are sized by
   aligned_context_window per head -- presumably KV-cache layout (confirm). */
3865 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3866 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3867 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one quantized GEMM per head over H heads. */
3881 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3882 for (
int h = 0; h < H; ++h) {
3883 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3884 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3885 float *q_h = q + (size_t)h * q_head_stride;
3886 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (may differ from H -- confirm GQA). */
3891 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3892 for (
int h = 0; h < H_kv; ++h) {
3893 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3894 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3895 float *k_h = k + (size_t)h * kv_head_stride;
3896 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, mirroring the K loop. */
3901 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3902 for (
int h = 0; h < H_kv; ++h) {
3903 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3904 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3905 float *v_h = v + (size_t)h * kv_head_stride;
3906 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
3921 aligned_context_window);
3933 aligned_context_window);
/* Output projection: repack head-major attn_out into token-major
   proj_scratch when H*aligned_head_dim != aligned_embed_dim. */
3936 const int K = H * aligned_head_dim;
3937 if (K != aligned_embed_dim) {
3940 const float *proj_in = attn_out;
3942 if (!proj_scratch) {
3945 for (
int t = 0; t < num_tokens; ++t) {
3946 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3947 for (
int h = 0; h < H; ++h) {
3948 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3949 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3951 (
size_t)aligned_head_dim *
sizeof(
float));
3954 proj_in = proj_scratch;
3956 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* MLP: fused gate+up GEMM (2*aligned_intermediate_dim), SwiGLU, down GEMM. */
3972 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3973 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3974 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3986 int aligned_embed_dim,
3987 int aligned_head_dim,
3988 int aligned_intermediate_dim,
3989 int aligned_context_window
3991 const MODELLayerOffsets *L = &MODEL_LAYERS[23];
3993 float *input = MODEL_PTR(model, MODEL_LAYERS[22].output);
3994 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3995 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3996 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3997 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3998 float *q = MODEL_PTR(model, L->q);
3999 float *k = MODEL_PTR(model, L->k);
4000 float *v = MODEL_PTR(model, L->v);
4001 float *attn_out = MODEL_PTR(model, L->attn_out);
4002 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
4003 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
4004 float *residual1 = MODEL_PTR(model, L->residual1);
4005 float *fc1_out = MODEL_PTR(model, L->fc1_out);
4006 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
4007 float *mlp_out = MODEL_PTR(model, L->mlp_out);
4008 float *output = MODEL_PTR(model, L->output);
4010 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
4011 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
4012 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
4013 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
4014 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
4015 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
4016 const float *BQ = NULL;
4017 const float *BK = NULL;
4018 const float *BV = NULL;
4019 const float *BO = NULL;
4020 const float *B1 = NULL;
4021 const float *B2 = NULL;
4023 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
4024 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
4029 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4030 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
4031 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4045 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
4046 for (
int h = 0; h < H; ++h) {
4047 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
4048 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4049 float *q_h = q + (size_t)h * q_head_stride;
4050 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4055 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4056 for (
int h = 0; h < H_kv; ++h) {
4057 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4058 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4059 float *k_h = k + (size_t)h * kv_head_stride;
4060 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4065 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4066 for (
int h = 0; h < H_kv; ++h) {
4067 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4068 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4069 float *v_h = v + (size_t)h * kv_head_stride;
4070 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4085 aligned_context_window);
4097 aligned_context_window);
4100 const int K = H * aligned_head_dim;
4101 if (K != aligned_embed_dim) {
4104 const float *proj_in = attn_out;
4106 if (!proj_scratch) {
4109 for (
int t = 0; t < num_tokens; ++t) {
4110 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
4111 for (
int h = 0; h < H; ++h) {
4112 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
4113 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
4115 (
size_t)aligned_head_dim *
sizeof(
float));
4118 proj_in = proj_scratch;
4120 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
4136 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
4137 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
4138 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
4153 if (!model || !tokens || num_tokens <= 0) {
4157 const int elem_bytes = MODEL_DTYPE_BYTES;
4158 const int aligned_embed_dim = 1024;
4159 const int aligned_head_dim = 64;
4160 const int aligned_intermediate_dim = 4864;
4161 const int aligned_context_window = 131072;
4163 float *embed_out = MODEL_PTR(model, MODEL_HEADER.embedded_input);
4164 const void *embed_weight = (
const void *)MODEL_PTR(model, MODEL_HEADER.token_emb);
4181 aligned_intermediate_dim,
4182 aligned_context_window);
4189 aligned_intermediate_dim,
4190 aligned_context_window);
4197 aligned_intermediate_dim,
4198 aligned_context_window);
4205 aligned_intermediate_dim,
4206 aligned_context_window);
4213 aligned_intermediate_dim,
4214 aligned_context_window);
4221 aligned_intermediate_dim,
4222 aligned_context_window);
4229 aligned_intermediate_dim,
4230 aligned_context_window);
4237 aligned_intermediate_dim,
4238 aligned_context_window);
4245 aligned_intermediate_dim,
4246 aligned_context_window);
4253 aligned_intermediate_dim,
4254 aligned_context_window);
4261 aligned_intermediate_dim,
4262 aligned_context_window);
4269 aligned_intermediate_dim,
4270 aligned_context_window);
4277 aligned_intermediate_dim,
4278 aligned_context_window);
4285 aligned_intermediate_dim,
4286 aligned_context_window);
4293 aligned_intermediate_dim,
4294 aligned_context_window);
4301 aligned_intermediate_dim,
4302 aligned_context_window);
4309 aligned_intermediate_dim,
4310 aligned_context_window);
4317 aligned_intermediate_dim,
4318 aligned_context_window);
4325 aligned_intermediate_dim,
4326 aligned_context_window);
4333 aligned_intermediate_dim,
4334 aligned_context_window);
4341 aligned_intermediate_dim,
4342 aligned_context_window);
4349 aligned_intermediate_dim,
4350 aligned_context_window);
4357 aligned_intermediate_dim,
4358 aligned_context_window);
4365 aligned_intermediate_dim,
4366 aligned_context_window);
4368 float *last_hidden = MODEL_PTR(model, MODEL_LAYERS[
MODEL_NUM_LAYERS - 1].output);
4369 float *final_ln_weight = MODEL_PTR(model, MODEL_FOOTER.final_ln_weight);
4370 float *final_out = MODEL_PTR(model, MODEL_FOOTER.final_output);
4380 float *logits = MODEL_PTR(model, MODEL_FOOTER.logits);
4381 const void *lm_head = (
const void *)MODEL_PTR(model, MODEL_FOOTER.lm_head_weight);
4383 for (
int t = 0; t < num_tokens; ++t) {
4384 uint8_t q8_buf[q8_bytes];
4385 const float *row = final_out + (size_t)t * (
size_t)aligned_embed_dim;
4408 int aligned_embed_dim,
4409 int aligned_head_dim,
4410 int aligned_intermediate_dim,
4411 int aligned_context_window
4413 const MODELLayerOffsets *L = &MODEL_LAYERS[0];
4415 float *input = MODEL_PTR(model, MODEL_HEADER.embedded_input);
4417 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
4418 float *ln1_out = MODEL_PTR(model, L->ln1_out);
4419 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
4420 float *ln2_out = MODEL_PTR(model, L->ln2_out);
4421 float *k_cache = MODEL_PTR(model, L->k);
4422 float *v_cache = MODEL_PTR(model, L->v);
4423 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
4424 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
4425 float *residual1 = MODEL_PTR(model, L->residual1);
4426 float *mlp_out = MODEL_PTR(model, L->mlp_out);
4427 float *output = MODEL_PTR(model, L->output);
4430 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
4431 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
4432 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
4433 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
4434 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
4435 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
4437 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
4438 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
4444 float q_token[H * aligned_head_dim];
4445 float k_token[H_kv * aligned_head_dim];
4446 float v_token[H_kv * aligned_head_dim];
4447 float attn_token[H * aligned_head_dim];
4450 float fc1_out[2 * aligned_intermediate_dim];
4451 float swiglu_out[aligned_intermediate_dim];
4463 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4467 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
4468 if (aligned_head_dim > head_dim) {
4469 for (
int h = 0; h < H; ++h) {
4470 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
4471 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4478 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4480 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4481 for (
int h = 0; h < H_kv; ++h) {
4482 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4483 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4484 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4485 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
4486 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4492 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4494 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4495 for (
int h = 0; h < H_kv; ++h) {
4496 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4497 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4498 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4499 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
4500 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4514 for (
int h = 0; h < H_kv; ++h) {
4515 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4536 aligned_context_window,
4542 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
4559 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
4562 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4565 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
4577 int aligned_embed_dim,
4578 int aligned_head_dim,
4579 int aligned_intermediate_dim,
4580 int aligned_context_window
4582 const MODELLayerOffsets *L = &MODEL_LAYERS[1];
4584 float *input = MODEL_PTR(model, MODEL_LAYERS[0].output);
4586 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
4587 float *ln1_out = MODEL_PTR(model, L->ln1_out);
4588 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
4589 float *ln2_out = MODEL_PTR(model, L->ln2_out);
4590 float *k_cache = MODEL_PTR(model, L->k);
4591 float *v_cache = MODEL_PTR(model, L->v);
4592 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
4593 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
4594 float *residual1 = MODEL_PTR(model, L->residual1);
4595 float *mlp_out = MODEL_PTR(model, L->mlp_out);
4596 float *output = MODEL_PTR(model, L->output);
4599 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
4600 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
4601 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
4602 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
4603 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
4604 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
4606 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
4607 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
4613 float q_token[H * aligned_head_dim];
4614 float k_token[H_kv * aligned_head_dim];
4615 float v_token[H_kv * aligned_head_dim];
4616 float attn_token[H * aligned_head_dim];
4619 float fc1_out[2 * aligned_intermediate_dim];
4620 float swiglu_out[aligned_intermediate_dim];
4632 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4636 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
4637 if (aligned_head_dim > head_dim) {
4638 for (
int h = 0; h < H; ++h) {
4639 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
4640 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4647 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4649 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4650 for (
int h = 0; h < H_kv; ++h) {
4651 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4652 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4653 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4654 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
4655 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4661 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4663 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4664 for (
int h = 0; h < H_kv; ++h) {
4665 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4666 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4667 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4668 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
4669 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4683 for (
int h = 0; h < H_kv; ++h) {
4684 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4705 aligned_context_window,
4711 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
4728 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
4731 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4734 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
4746 int aligned_embed_dim,
4747 int aligned_head_dim,
4748 int aligned_intermediate_dim,
4749 int aligned_context_window
4751 const MODELLayerOffsets *L = &MODEL_LAYERS[2];
4753 float *input = MODEL_PTR(model, MODEL_LAYERS[1].output);
4755 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
4756 float *ln1_out = MODEL_PTR(model, L->ln1_out);
4757 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
4758 float *ln2_out = MODEL_PTR(model, L->ln2_out);
4759 float *k_cache = MODEL_PTR(model, L->k);
4760 float *v_cache = MODEL_PTR(model, L->v);
4761 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
4762 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
4763 float *residual1 = MODEL_PTR(model, L->residual1);
4764 float *mlp_out = MODEL_PTR(model, L->mlp_out);
4765 float *output = MODEL_PTR(model, L->output);
4768 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
4769 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
4770 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
4771 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
4772 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
4773 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
4775 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
4776 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
4782 float q_token[H * aligned_head_dim];
4783 float k_token[H_kv * aligned_head_dim];
4784 float v_token[H_kv * aligned_head_dim];
4785 float attn_token[H * aligned_head_dim];
4788 float fc1_out[2 * aligned_intermediate_dim];
4789 float swiglu_out[aligned_intermediate_dim];
4801 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4805 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
4806 if (aligned_head_dim > head_dim) {
4807 for (
int h = 0; h < H; ++h) {
4808 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
4809 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4816 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4818 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4819 for (
int h = 0; h < H_kv; ++h) {
4820 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4821 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4822 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4823 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
4824 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4830 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4832 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4833 for (
int h = 0; h < H_kv; ++h) {
4834 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4835 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4836 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4837 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
4838 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4852 for (
int h = 0; h < H_kv; ++h) {
4853 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4874 aligned_context_window,
4880 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
4897 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
4900 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4903 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
4915 int aligned_embed_dim,
4916 int aligned_head_dim,
4917 int aligned_intermediate_dim,
4918 int aligned_context_window
4920 const MODELLayerOffsets *L = &MODEL_LAYERS[3];
4922 float *input = MODEL_PTR(model, MODEL_LAYERS[2].output);
4924 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
4925 float *ln1_out = MODEL_PTR(model, L->ln1_out);
4926 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
4927 float *ln2_out = MODEL_PTR(model, L->ln2_out);
4928 float *k_cache = MODEL_PTR(model, L->k);
4929 float *v_cache = MODEL_PTR(model, L->v);
4930 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
4931 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
4932 float *residual1 = MODEL_PTR(model, L->residual1);
4933 float *mlp_out = MODEL_PTR(model, L->mlp_out);
4934 float *output = MODEL_PTR(model, L->output);
4937 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
4938 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
4939 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
4940 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
4941 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
4942 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
4944 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
4945 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
4951 float q_token[H * aligned_head_dim];
4952 float k_token[H_kv * aligned_head_dim];
4953 float v_token[H_kv * aligned_head_dim];
4954 float attn_token[H * aligned_head_dim];
4957 float fc1_out[2 * aligned_intermediate_dim];
4958 float swiglu_out[aligned_intermediate_dim];
4970 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4974 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
4975 if (aligned_head_dim > head_dim) {
4976 for (
int h = 0; h < H; ++h) {
4977 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
4978 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4985 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4987 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4988 for (
int h = 0; h < H_kv; ++h) {
4989 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4990 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4991 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4992 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
4993 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4999 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5001 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5002 for (
int h = 0; h < H_kv; ++h) {
5003 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5004 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5005 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5006 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
5007 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5021 for (
int h = 0; h < H_kv; ++h) {
5022 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5043 aligned_context_window,
5049 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5066 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5069 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5072 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5084 int aligned_embed_dim,
5085 int aligned_head_dim,
5086 int aligned_intermediate_dim,
5087 int aligned_context_window
5089 const MODELLayerOffsets *L = &MODEL_LAYERS[4];
5091 float *input = MODEL_PTR(model, MODEL_LAYERS[3].output);
5093 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5094 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5095 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5096 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5097 float *k_cache = MODEL_PTR(model, L->k);
5098 float *v_cache = MODEL_PTR(model, L->v);
5099 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5100 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5101 float *residual1 = MODEL_PTR(model, L->residual1);
5102 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5103 float *output = MODEL_PTR(model, L->output);
5106 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5107 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5108 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5109 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5110 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5111 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5113 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5114 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5120 float q_token[H * aligned_head_dim];
5121 float k_token[H_kv * aligned_head_dim];
5122 float v_token[H_kv * aligned_head_dim];
5123 float attn_token[H * aligned_head_dim];
5126 float fc1_out[2 * aligned_intermediate_dim];
5127 float swiglu_out[aligned_intermediate_dim];
5139 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5143 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5144 if (aligned_head_dim > head_dim) {
5145 for (
int h = 0; h < H; ++h) {
5146 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5147 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5154 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5156 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5157 for (
int h = 0; h < H_kv; ++h) {
5158 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5159 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5160 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5161 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
5162 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5168 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5170 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5171 for (
int h = 0; h < H_kv; ++h) {
5172 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5173 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5174 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5175 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
5176 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5190 for (
int h = 0; h < H_kv; ++h) {
5191 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5212 aligned_context_window,
5218 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5235 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5238 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5241 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5253 int aligned_embed_dim,
5254 int aligned_head_dim,
5255 int aligned_intermediate_dim,
5256 int aligned_context_window
5258 const MODELLayerOffsets *L = &MODEL_LAYERS[5];
5260 float *input = MODEL_PTR(model, MODEL_LAYERS[4].output);
5262 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5263 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5264 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5265 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5266 float *k_cache = MODEL_PTR(model, L->k);
5267 float *v_cache = MODEL_PTR(model, L->v);
5268 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5269 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5270 float *residual1 = MODEL_PTR(model, L->residual1);
5271 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5272 float *output = MODEL_PTR(model, L->output);
5275 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5276 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5277 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5278 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5279 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5280 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5282 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5283 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5289 float q_token[H * aligned_head_dim];
5290 float k_token[H_kv * aligned_head_dim];
5291 float v_token[H_kv * aligned_head_dim];
5292 float attn_token[H * aligned_head_dim];
5295 float fc1_out[2 * aligned_intermediate_dim];
5296 float swiglu_out[aligned_intermediate_dim];
5308 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5312 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5313 if (aligned_head_dim > head_dim) {
5314 for (
int h = 0; h < H; ++h) {
5315 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5316 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5323 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5325 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5326 for (
int h = 0; h < H_kv; ++h) {
5327 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5328 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5329 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5330 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
5331 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5337 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5339 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5340 for (
int h = 0; h < H_kv; ++h) {
5341 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5342 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5343 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5344 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
5345 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5359 for (
int h = 0; h < H_kv; ++h) {
5360 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5381 aligned_context_window,
5387 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5404 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5407 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5410 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5422 int aligned_embed_dim,
5423 int aligned_head_dim,
5424 int aligned_intermediate_dim,
5425 int aligned_context_window
5427 const MODELLayerOffsets *L = &MODEL_LAYERS[6];
5429 float *input = MODEL_PTR(model, MODEL_LAYERS[5].output);
5431 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5432 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5433 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5434 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5435 float *k_cache = MODEL_PTR(model, L->k);
5436 float *v_cache = MODEL_PTR(model, L->v);
5437 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5438 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5439 float *residual1 = MODEL_PTR(model, L->residual1);
5440 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5441 float *output = MODEL_PTR(model, L->output);
5444 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5445 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5446 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5447 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5448 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5449 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5451 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5452 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5458 float q_token[H * aligned_head_dim];
5459 float k_token[H_kv * aligned_head_dim];
5460 float v_token[H_kv * aligned_head_dim];
5461 float attn_token[H * aligned_head_dim];
5464 float fc1_out[2 * aligned_intermediate_dim];
5465 float swiglu_out[aligned_intermediate_dim];
5477 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5481 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5482 if (aligned_head_dim > head_dim) {
5483 for (
int h = 0; h < H; ++h) {
5484 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5485 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5492 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5494 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5495 for (
int h = 0; h < H_kv; ++h) {
5496 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5497 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5498 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5499 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
5500 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5506 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5508 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5509 for (
int h = 0; h < H_kv; ++h) {
5510 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5511 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5512 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5513 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
5514 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5528 for (
int h = 0; h < H_kv; ++h) {
5529 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5550 aligned_context_window,
5556 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5573 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5576 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5579 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5591 int aligned_embed_dim,
5592 int aligned_head_dim,
5593 int aligned_intermediate_dim,
5594 int aligned_context_window
5596 const MODELLayerOffsets *L = &MODEL_LAYERS[7];
5598 float *input = MODEL_PTR(model, MODEL_LAYERS[6].output);
5600 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5601 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5602 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5603 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5604 float *k_cache = MODEL_PTR(model, L->k);
5605 float *v_cache = MODEL_PTR(model, L->v);
5606 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5607 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5608 float *residual1 = MODEL_PTR(model, L->residual1);
5609 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5610 float *output = MODEL_PTR(model, L->output);
5613 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5614 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5615 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5616 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5617 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5618 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5620 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5621 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5627 float q_token[H * aligned_head_dim];
5628 float k_token[H_kv * aligned_head_dim];
5629 float v_token[H_kv * aligned_head_dim];
5630 float attn_token[H * aligned_head_dim];
5633 float fc1_out[2 * aligned_intermediate_dim];
5634 float swiglu_out[aligned_intermediate_dim];
5646 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5650 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5651 if (aligned_head_dim > head_dim) {
5652 for (
int h = 0; h < H; ++h) {
5653 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5654 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5661 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5663 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5664 for (
int h = 0; h < H_kv; ++h) {
5665 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5666 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5667 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5668 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
5669 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5675 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5677 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5678 for (
int h = 0; h < H_kv; ++h) {
5679 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5680 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5681 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5682 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
5683 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5697 for (
int h = 0; h < H_kv; ++h) {
5698 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5719 aligned_context_window,
5725 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5742 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5745 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5748 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5760 int aligned_embed_dim,
5761 int aligned_head_dim,
5762 int aligned_intermediate_dim,
5763 int aligned_context_window
5765 const MODELLayerOffsets *L = &MODEL_LAYERS[8];
5767 float *input = MODEL_PTR(model, MODEL_LAYERS[7].output);
5769 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5770 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5771 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5772 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5773 float *k_cache = MODEL_PTR(model, L->k);
5774 float *v_cache = MODEL_PTR(model, L->v);
5775 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5776 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5777 float *residual1 = MODEL_PTR(model, L->residual1);
5778 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5779 float *output = MODEL_PTR(model, L->output);
5782 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5783 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5784 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5785 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5786 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5787 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5789 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5790 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5796 float q_token[H * aligned_head_dim];
5797 float k_token[H_kv * aligned_head_dim];
5798 float v_token[H_kv * aligned_head_dim];
5799 float attn_token[H * aligned_head_dim];
5802 float fc1_out[2 * aligned_intermediate_dim];
5803 float swiglu_out[aligned_intermediate_dim];
5815 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5819 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5820 if (aligned_head_dim > head_dim) {
5821 for (
int h = 0; h < H; ++h) {
5822 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5823 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5830 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5832 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5833 for (
int h = 0; h < H_kv; ++h) {
5834 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5835 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5836 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5837 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
5838 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5844 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5846 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5847 for (
int h = 0; h < H_kv; ++h) {
5848 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5849 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
5850 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5851 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
5852 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5866 for (
int h = 0; h < H_kv; ++h) {
5867 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5888 aligned_context_window,
5894 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
5911 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
5914 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5917 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
5929 int aligned_embed_dim,
5930 int aligned_head_dim,
5931 int aligned_intermediate_dim,
5932 int aligned_context_window
5934 const MODELLayerOffsets *L = &MODEL_LAYERS[9];
5936 float *input = MODEL_PTR(model, MODEL_LAYERS[8].output);
5938 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5939 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5940 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5941 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5942 float *k_cache = MODEL_PTR(model, L->k);
5943 float *v_cache = MODEL_PTR(model, L->v);
5944 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5945 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5946 float *residual1 = MODEL_PTR(model, L->residual1);
5947 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5948 float *output = MODEL_PTR(model, L->output);
5951 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5952 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5953 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5954 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5955 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5956 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5958 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5959 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5965 float q_token[H * aligned_head_dim];
5966 float k_token[H_kv * aligned_head_dim];
5967 float v_token[H_kv * aligned_head_dim];
5968 float attn_token[H * aligned_head_dim];
5971 float fc1_out[2 * aligned_intermediate_dim];
5972 float swiglu_out[aligned_intermediate_dim];
5984 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5988 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
5989 if (aligned_head_dim > head_dim) {
5990 for (
int h = 0; h < H; ++h) {
5991 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5992 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5999 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6001 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6002 for (
int h = 0; h < H_kv; ++h) {
6003 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6004 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6005 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6006 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
6007 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6013 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6015 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6016 for (
int h = 0; h < H_kv; ++h) {
6017 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6018 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6019 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6020 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
6021 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6035 for (
int h = 0; h < H_kv; ++h) {
6036 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6057 aligned_context_window,
6063 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6080 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6083 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6086 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6098 int aligned_embed_dim,
6099 int aligned_head_dim,
6100 int aligned_intermediate_dim,
6101 int aligned_context_window
6103 const MODELLayerOffsets *L = &MODEL_LAYERS[10];
6105 float *input = MODEL_PTR(model, MODEL_LAYERS[9].output);
6107 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6108 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6109 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6110 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6111 float *k_cache = MODEL_PTR(model, L->k);
6112 float *v_cache = MODEL_PTR(model, L->v);
6113 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6114 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6115 float *residual1 = MODEL_PTR(model, L->residual1);
6116 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6117 float *output = MODEL_PTR(model, L->output);
6120 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6121 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6122 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6123 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6124 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6125 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6127 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6128 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6134 float q_token[H * aligned_head_dim];
6135 float k_token[H_kv * aligned_head_dim];
6136 float v_token[H_kv * aligned_head_dim];
6137 float attn_token[H * aligned_head_dim];
6140 float fc1_out[2 * aligned_intermediate_dim];
6141 float swiglu_out[aligned_intermediate_dim];
6153 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6157 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6158 if (aligned_head_dim > head_dim) {
6159 for (
int h = 0; h < H; ++h) {
6160 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6161 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6168 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6170 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6171 for (
int h = 0; h < H_kv; ++h) {
6172 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6173 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6174 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6175 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
6176 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6182 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6184 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6185 for (
int h = 0; h < H_kv; ++h) {
6186 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6187 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6188 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6189 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
6190 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6204 for (
int h = 0; h < H_kv; ++h) {
6205 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6226 aligned_context_window,
6232 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6249 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6252 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6255 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6267 int aligned_embed_dim,
6268 int aligned_head_dim,
6269 int aligned_intermediate_dim,
6270 int aligned_context_window
6272 const MODELLayerOffsets *L = &MODEL_LAYERS[11];
6274 float *input = MODEL_PTR(model, MODEL_LAYERS[10].output);
6276 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6277 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6278 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6279 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6280 float *k_cache = MODEL_PTR(model, L->k);
6281 float *v_cache = MODEL_PTR(model, L->v);
6282 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6283 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6284 float *residual1 = MODEL_PTR(model, L->residual1);
6285 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6286 float *output = MODEL_PTR(model, L->output);
6289 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6290 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6291 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6292 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6293 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6294 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6296 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6297 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6303 float q_token[H * aligned_head_dim];
6304 float k_token[H_kv * aligned_head_dim];
6305 float v_token[H_kv * aligned_head_dim];
6306 float attn_token[H * aligned_head_dim];
6309 float fc1_out[2 * aligned_intermediate_dim];
6310 float swiglu_out[aligned_intermediate_dim];
6322 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6326 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6327 if (aligned_head_dim > head_dim) {
6328 for (
int h = 0; h < H; ++h) {
6329 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6330 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6337 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6339 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6340 for (
int h = 0; h < H_kv; ++h) {
6341 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6342 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6343 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6344 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
6345 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6351 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6353 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6354 for (
int h = 0; h < H_kv; ++h) {
6355 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6356 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6357 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6358 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
6359 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6373 for (
int h = 0; h < H_kv; ++h) {
6374 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6395 aligned_context_window,
6401 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6418 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6421 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6424 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6436 int aligned_embed_dim,
6437 int aligned_head_dim,
6438 int aligned_intermediate_dim,
6439 int aligned_context_window
6441 const MODELLayerOffsets *L = &MODEL_LAYERS[12];
6443 float *input = MODEL_PTR(model, MODEL_LAYERS[11].output);
6445 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6446 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6447 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6448 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6449 float *k_cache = MODEL_PTR(model, L->k);
6450 float *v_cache = MODEL_PTR(model, L->v);
6451 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6452 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6453 float *residual1 = MODEL_PTR(model, L->residual1);
6454 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6455 float *output = MODEL_PTR(model, L->output);
6458 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6459 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6460 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6461 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6462 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6463 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6465 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6466 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6472 float q_token[H * aligned_head_dim];
6473 float k_token[H_kv * aligned_head_dim];
6474 float v_token[H_kv * aligned_head_dim];
6475 float attn_token[H * aligned_head_dim];
6478 float fc1_out[2 * aligned_intermediate_dim];
6479 float swiglu_out[aligned_intermediate_dim];
6491 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6495 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6496 if (aligned_head_dim > head_dim) {
6497 for (
int h = 0; h < H; ++h) {
6498 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6499 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6506 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6508 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6509 for (
int h = 0; h < H_kv; ++h) {
6510 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6511 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6512 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6513 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
6514 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6520 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6522 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6523 for (
int h = 0; h < H_kv; ++h) {
6524 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6525 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6526 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6527 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
6528 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6542 for (
int h = 0; h < H_kv; ++h) {
6543 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6564 aligned_context_window,
6570 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6587 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6590 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6593 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6605 int aligned_embed_dim,
6606 int aligned_head_dim,
6607 int aligned_intermediate_dim,
6608 int aligned_context_window
6610 const MODELLayerOffsets *L = &MODEL_LAYERS[13];
6612 float *input = MODEL_PTR(model, MODEL_LAYERS[12].output);
6614 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6615 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6616 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6617 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6618 float *k_cache = MODEL_PTR(model, L->k);
6619 float *v_cache = MODEL_PTR(model, L->v);
6620 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6621 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6622 float *residual1 = MODEL_PTR(model, L->residual1);
6623 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6624 float *output = MODEL_PTR(model, L->output);
6627 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6628 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6629 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6630 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6631 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6632 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6634 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6635 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6641 float q_token[H * aligned_head_dim];
6642 float k_token[H_kv * aligned_head_dim];
6643 float v_token[H_kv * aligned_head_dim];
6644 float attn_token[H * aligned_head_dim];
6647 float fc1_out[2 * aligned_intermediate_dim];
6648 float swiglu_out[aligned_intermediate_dim];
6660 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6664 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6665 if (aligned_head_dim > head_dim) {
6666 for (
int h = 0; h < H; ++h) {
6667 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6668 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6675 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6677 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6678 for (
int h = 0; h < H_kv; ++h) {
6679 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6680 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6681 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6682 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
6683 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6689 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6691 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6692 for (
int h = 0; h < H_kv; ++h) {
6693 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6694 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6695 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6696 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
6697 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6711 for (
int h = 0; h < H_kv; ++h) {
6712 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6733 aligned_context_window,
6739 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6756 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6759 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6762 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
6774 int aligned_embed_dim,
6775 int aligned_head_dim,
6776 int aligned_intermediate_dim,
6777 int aligned_context_window
6779 const MODELLayerOffsets *L = &MODEL_LAYERS[14];
6781 float *input = MODEL_PTR(model, MODEL_LAYERS[13].output);
6783 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6784 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6785 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6786 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6787 float *k_cache = MODEL_PTR(model, L->k);
6788 float *v_cache = MODEL_PTR(model, L->v);
6789 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6790 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6791 float *residual1 = MODEL_PTR(model, L->residual1);
6792 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6793 float *output = MODEL_PTR(model, L->output);
6796 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6797 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6798 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6799 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6800 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6801 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6803 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6804 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6810 float q_token[H * aligned_head_dim];
6811 float k_token[H_kv * aligned_head_dim];
6812 float v_token[H_kv * aligned_head_dim];
6813 float attn_token[H * aligned_head_dim];
6816 float fc1_out[2 * aligned_intermediate_dim];
6817 float swiglu_out[aligned_intermediate_dim];
6829 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6833 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
6834 if (aligned_head_dim > head_dim) {
6835 for (
int h = 0; h < H; ++h) {
6836 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6837 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6844 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6846 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6847 for (
int h = 0; h < H_kv; ++h) {
6848 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6849 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6850 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6851 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
6852 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6858 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6860 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6861 for (
int h = 0; h < H_kv; ++h) {
6862 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6863 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
6864 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6865 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
6866 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6880 for (
int h = 0; h < H_kv; ++h) {
6881 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6902 aligned_context_window,
6908 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
6925 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
6928 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6931 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
/*
 * model_layer_15_decode (fragment): single-token decode pass for
 * transformer layer 15 of this generated (per-layer unrolled) model.
 * NOTE(review): this extract is garbled — the function header, the
 * zero-padding loop bodies, RoPE application, the attention core,
 * residual adds and layernorm calls are missing, and stray original
 * line numbers plus hard line wraps remain embedded in the text.
 * Code is left byte-identical; only comments were added.
 */
6943 int aligned_embed_dim,
6944 int aligned_head_dim,
6945 int aligned_intermediate_dim,
6946 int aligned_context_window
/* Offset table for this layer's weights/activations inside the arena. */
6948 const MODELLayerOffsets *L = &MODEL_LAYERS[15];
/* Input activations: previous layer's (layer 14) output buffer. */
6950 float *input = MODEL_PTR(model, MODEL_LAYERS[14].output);
6952 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6953 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6954 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6955 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6956 float *k_cache = MODEL_PTR(model, L->k);
6957 float *v_cache = MODEL_PTR(model, L->v);
6958 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6959 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6960 float *residual1 = MODEL_PTR(model, L->residual1);
6961 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6962 float *output = MODEL_PTR(model, L->output);
/* Quantized (q4_k) weight blobs: Q/K/V/O projections and MLP W1/W2. */
6965 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6966 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6967 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6968 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6969 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6970 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* Global precomputed RoPE cos/sin caches (application code elided here). */
6972 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6973 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Per-token scratch buffers (VLAs sized by padded "aligned" dims). */
6979 float q_token[H * aligned_head_dim];
6980 float k_token[H_kv * aligned_head_dim];
6981 float v_token[H_kv * aligned_head_dim];
6982 float attn_token[H * aligned_head_dim];
6985 float fc1_out[2 * aligned_intermediate_dim];
6986 float swiglu_out[aligned_intermediate_dim];
/* K/V cache stride per head: one padded context window of head vectors. */
6998 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection; tail lanes past head_dim presumably zeroed (body elided). */
7002 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7003 if (aligned_head_dim > head_dim) {
7004 for (
int h = 0; h < H; ++h) {
7005 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7006 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* K projection, one KV head at a time, written into k_cache at token_index. */
7013 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7015 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7016 for (
int h = 0; h < H_kv; ++h) {
7017 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7018 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7019 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7020 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
7021 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* V projection, same per-head layout, into v_cache. */
7027 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7029 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7030 for (
int h = 0; h < H_kv; ++h) {
7031 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7032 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7033 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7034 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
7035 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* Attention over cached K/V — core elided in this extract. */
7049 for (
int h = 0; h < H_kv; ++h) {
7050 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7071 aligned_context_window,
/* Attention output projection into proj_tmp. */
7077 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
/* MLP: fused gate+up GEMM (2*intermediate), SwiGLU, then down projection. */
7094 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7097 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7100 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
/*
 * model_layer_16_decode (fragment): single-token decode pass for
 * transformer layer 16 of this generated (per-layer unrolled) model.
 * NOTE(review): this extract is garbled — the function header, the
 * zero-padding loop bodies, RoPE application, the attention core,
 * residual adds and layernorm calls are missing, and stray original
 * line numbers plus hard line wraps remain embedded in the text.
 * Code is left byte-identical; only comments were added.
 */
7112 int aligned_embed_dim,
7113 int aligned_head_dim,
7114 int aligned_intermediate_dim,
7115 int aligned_context_window
/* Offset table for this layer's weights/activations inside the arena. */
7117 const MODELLayerOffsets *L = &MODEL_LAYERS[16];
/* Input activations: previous layer's (layer 15) output buffer. */
7119 float *input = MODEL_PTR(model, MODEL_LAYERS[15].output);
7121 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7122 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7123 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7124 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7125 float *k_cache = MODEL_PTR(model, L->k);
7126 float *v_cache = MODEL_PTR(model, L->v);
7127 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7128 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7129 float *residual1 = MODEL_PTR(model, L->residual1);
7130 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7131 float *output = MODEL_PTR(model, L->output);
/* Quantized (q4_k) weight blobs: Q/K/V/O projections and MLP W1/W2. */
7134 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7135 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7136 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7137 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7138 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7139 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* Global precomputed RoPE cos/sin caches (application code elided here). */
7141 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7142 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Per-token scratch buffers (VLAs sized by padded "aligned" dims). */
7148 float q_token[H * aligned_head_dim];
7149 float k_token[H_kv * aligned_head_dim];
7150 float v_token[H_kv * aligned_head_dim];
7151 float attn_token[H * aligned_head_dim];
7154 float fc1_out[2 * aligned_intermediate_dim];
7155 float swiglu_out[aligned_intermediate_dim];
/* K/V cache stride per head: one padded context window of head vectors. */
7167 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection; tail lanes past head_dim presumably zeroed (body elided). */
7171 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7172 if (aligned_head_dim > head_dim) {
7173 for (
int h = 0; h < H; ++h) {
7174 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7175 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* K projection, one KV head at a time, written into k_cache at token_index. */
7182 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7184 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7185 for (
int h = 0; h < H_kv; ++h) {
7186 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7187 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7188 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7189 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
7190 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* V projection, same per-head layout, into v_cache. */
7196 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7198 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7199 for (
int h = 0; h < H_kv; ++h) {
7200 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7201 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7202 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7203 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
7204 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* Attention over cached K/V — core elided in this extract. */
7218 for (
int h = 0; h < H_kv; ++h) {
7219 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7240 aligned_context_window,
/* Attention output projection into proj_tmp. */
7246 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
/* MLP: fused gate+up GEMM (2*intermediate), SwiGLU, then down projection. */
7263 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7266 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7269 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
/*
 * model_layer_17_decode (fragment): single-token decode pass for
 * transformer layer 17 of this generated (per-layer unrolled) model.
 * NOTE(review): this extract is garbled — the function header, the
 * zero-padding loop bodies, RoPE application, the attention core,
 * residual adds and layernorm calls are missing, and stray original
 * line numbers plus hard line wraps remain embedded in the text.
 * Code is left byte-identical; only comments were added.
 */
7281 int aligned_embed_dim,
7282 int aligned_head_dim,
7283 int aligned_intermediate_dim,
7284 int aligned_context_window
/* Offset table for this layer's weights/activations inside the arena. */
7286 const MODELLayerOffsets *L = &MODEL_LAYERS[17];
/* Input activations: previous layer's (layer 16) output buffer. */
7288 float *input = MODEL_PTR(model, MODEL_LAYERS[16].output);
7290 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7291 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7292 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7293 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7294 float *k_cache = MODEL_PTR(model, L->k);
7295 float *v_cache = MODEL_PTR(model, L->v);
7296 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7297 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7298 float *residual1 = MODEL_PTR(model, L->residual1);
7299 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7300 float *output = MODEL_PTR(model, L->output);
/* Quantized (q4_k) weight blobs: Q/K/V/O projections and MLP W1/W2. */
7303 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7304 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7305 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7306 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7307 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7308 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* Global precomputed RoPE cos/sin caches (application code elided here). */
7310 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7311 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Per-token scratch buffers (VLAs sized by padded "aligned" dims). */
7317 float q_token[H * aligned_head_dim];
7318 float k_token[H_kv * aligned_head_dim];
7319 float v_token[H_kv * aligned_head_dim];
7320 float attn_token[H * aligned_head_dim];
7323 float fc1_out[2 * aligned_intermediate_dim];
7324 float swiglu_out[aligned_intermediate_dim];
/* K/V cache stride per head: one padded context window of head vectors. */
7336 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection; tail lanes past head_dim presumably zeroed (body elided). */
7340 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7341 if (aligned_head_dim > head_dim) {
7342 for (
int h = 0; h < H; ++h) {
7343 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7344 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* K projection, one KV head at a time, written into k_cache at token_index. */
7351 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7353 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7354 for (
int h = 0; h < H_kv; ++h) {
7355 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7356 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7357 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7358 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
7359 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* V projection, same per-head layout, into v_cache. */
7365 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7367 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7368 for (
int h = 0; h < H_kv; ++h) {
7369 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7370 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7371 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7372 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
7373 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* Attention over cached K/V — core elided in this extract. */
7387 for (
int h = 0; h < H_kv; ++h) {
7388 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7409 aligned_context_window,
/* Attention output projection into proj_tmp. */
7415 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
/* MLP: fused gate+up GEMM (2*intermediate), SwiGLU, then down projection. */
7432 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7435 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7438 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
/*
 * model_layer_18_decode (fragment): single-token decode pass for
 * transformer layer 18 of this generated (per-layer unrolled) model.
 * NOTE(review): this extract is garbled — the function header, the
 * zero-padding loop bodies, RoPE application, the attention core,
 * residual adds and layernorm calls are missing, and stray original
 * line numbers plus hard line wraps remain embedded in the text.
 * Code is left byte-identical; only comments were added.
 */
7450 int aligned_embed_dim,
7451 int aligned_head_dim,
7452 int aligned_intermediate_dim,
7453 int aligned_context_window
/* Offset table for this layer's weights/activations inside the arena. */
7455 const MODELLayerOffsets *L = &MODEL_LAYERS[18];
/* Input activations: previous layer's (layer 17) output buffer. */
7457 float *input = MODEL_PTR(model, MODEL_LAYERS[17].output);
7459 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7460 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7461 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7462 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7463 float *k_cache = MODEL_PTR(model, L->k);
7464 float *v_cache = MODEL_PTR(model, L->v);
7465 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7466 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7467 float *residual1 = MODEL_PTR(model, L->residual1);
7468 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7469 float *output = MODEL_PTR(model, L->output);
/* Quantized (q4_k) weight blobs: Q/K/V/O projections and MLP W1/W2. */
7472 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7473 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7474 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7475 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7476 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7477 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* Global precomputed RoPE cos/sin caches (application code elided here). */
7479 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7480 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Per-token scratch buffers (VLAs sized by padded "aligned" dims). */
7486 float q_token[H * aligned_head_dim];
7487 float k_token[H_kv * aligned_head_dim];
7488 float v_token[H_kv * aligned_head_dim];
7489 float attn_token[H * aligned_head_dim];
7492 float fc1_out[2 * aligned_intermediate_dim];
7493 float swiglu_out[aligned_intermediate_dim];
/* K/V cache stride per head: one padded context window of head vectors. */
7505 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection; tail lanes past head_dim presumably zeroed (body elided). */
7509 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7510 if (aligned_head_dim > head_dim) {
7511 for (
int h = 0; h < H; ++h) {
7512 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7513 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* K projection, one KV head at a time, written into k_cache at token_index. */
7520 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7522 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7523 for (
int h = 0; h < H_kv; ++h) {
7524 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7525 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7526 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7527 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
7528 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* V projection, same per-head layout, into v_cache. */
7534 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7536 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7537 for (
int h = 0; h < H_kv; ++h) {
7538 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7539 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7540 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7541 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
7542 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* Attention over cached K/V — core elided in this extract. */
7556 for (
int h = 0; h < H_kv; ++h) {
7557 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7578 aligned_context_window,
/* Attention output projection into proj_tmp. */
7584 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
/* MLP: fused gate+up GEMM (2*intermediate), SwiGLU, then down projection. */
7601 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7604 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7607 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
/*
 * model_layer_19_decode (fragment): single-token decode pass for
 * transformer layer 19 of this generated (per-layer unrolled) model.
 * NOTE(review): this extract is garbled — the function header, the
 * zero-padding loop bodies, RoPE application, the attention core,
 * residual adds and layernorm calls are missing, and stray original
 * line numbers plus hard line wraps remain embedded in the text.
 * Code is left byte-identical; only comments were added.
 */
7619 int aligned_embed_dim,
7620 int aligned_head_dim,
7621 int aligned_intermediate_dim,
7622 int aligned_context_window
/* Offset table for this layer's weights/activations inside the arena. */
7624 const MODELLayerOffsets *L = &MODEL_LAYERS[19];
/* Input activations: previous layer's (layer 18) output buffer. */
7626 float *input = MODEL_PTR(model, MODEL_LAYERS[18].output);
7628 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7629 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7630 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7631 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7632 float *k_cache = MODEL_PTR(model, L->k);
7633 float *v_cache = MODEL_PTR(model, L->v);
7634 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7635 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7636 float *residual1 = MODEL_PTR(model, L->residual1);
7637 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7638 float *output = MODEL_PTR(model, L->output);
/* Quantized (q4_k) weight blobs: Q/K/V/O projections and MLP W1/W2. */
7641 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7642 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7643 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7644 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7645 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7646 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* Global precomputed RoPE cos/sin caches (application code elided here). */
7648 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7649 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Per-token scratch buffers (VLAs sized by padded "aligned" dims). */
7655 float q_token[H * aligned_head_dim];
7656 float k_token[H_kv * aligned_head_dim];
7657 float v_token[H_kv * aligned_head_dim];
7658 float attn_token[H * aligned_head_dim];
7661 float fc1_out[2 * aligned_intermediate_dim];
7662 float swiglu_out[aligned_intermediate_dim];
/* K/V cache stride per head: one padded context window of head vectors. */
7674 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection; tail lanes past head_dim presumably zeroed (body elided). */
7678 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7679 if (aligned_head_dim > head_dim) {
7680 for (
int h = 0; h < H; ++h) {
7681 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7682 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* K projection, one KV head at a time, written into k_cache at token_index. */
7689 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7691 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7692 for (
int h = 0; h < H_kv; ++h) {
7693 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7694 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7695 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7696 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
7697 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* V projection, same per-head layout, into v_cache. */
7703 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7705 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7706 for (
int h = 0; h < H_kv; ++h) {
7707 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7708 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7709 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7710 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
7711 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* Attention over cached K/V — core elided in this extract. */
7725 for (
int h = 0; h < H_kv; ++h) {
7726 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7747 aligned_context_window,
/* Attention output projection into proj_tmp. */
7753 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
/* MLP: fused gate+up GEMM (2*intermediate), SwiGLU, then down projection. */
7770 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7773 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7776 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
/*
 * model_layer_20_decode (fragment): single-token decode pass for
 * transformer layer 20 of this generated (per-layer unrolled) model.
 * NOTE(review): this extract is garbled — the function header, the
 * zero-padding loop bodies, RoPE application, the attention core,
 * residual adds and layernorm calls are missing, and stray original
 * line numbers plus hard line wraps remain embedded in the text.
 * Code is left byte-identical; only comments were added.
 */
7788 int aligned_embed_dim,
7789 int aligned_head_dim,
7790 int aligned_intermediate_dim,
7791 int aligned_context_window
/* Offset table for this layer's weights/activations inside the arena. */
7793 const MODELLayerOffsets *L = &MODEL_LAYERS[20];
/* Input activations: previous layer's (layer 19) output buffer. */
7795 float *input = MODEL_PTR(model, MODEL_LAYERS[19].output);
7797 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7798 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7799 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7800 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7801 float *k_cache = MODEL_PTR(model, L->k);
7802 float *v_cache = MODEL_PTR(model, L->v);
7803 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7804 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7805 float *residual1 = MODEL_PTR(model, L->residual1);
7806 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7807 float *output = MODEL_PTR(model, L->output);
/* Quantized (q4_k) weight blobs: Q/K/V/O projections and MLP W1/W2. */
7810 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7811 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7812 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7813 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7814 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7815 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* Global precomputed RoPE cos/sin caches (application code elided here). */
7817 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7818 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Per-token scratch buffers (VLAs sized by padded "aligned" dims). */
7824 float q_token[H * aligned_head_dim];
7825 float k_token[H_kv * aligned_head_dim];
7826 float v_token[H_kv * aligned_head_dim];
7827 float attn_token[H * aligned_head_dim];
7830 float fc1_out[2 * aligned_intermediate_dim];
7831 float swiglu_out[aligned_intermediate_dim];
/* K/V cache stride per head: one padded context window of head vectors. */
7843 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection; tail lanes past head_dim presumably zeroed (body elided). */
7847 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
7848 if (aligned_head_dim > head_dim) {
7849 for (
int h = 0; h < H; ++h) {
7850 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7851 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* K projection, one KV head at a time, written into k_cache at token_index. */
7858 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7860 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7861 for (
int h = 0; h < H_kv; ++h) {
7862 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7863 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7864 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7865 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
7866 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* V projection, same per-head layout, into v_cache. */
7872 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7874 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7875 for (
int h = 0; h < H_kv; ++h) {
7876 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7877 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
7878 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7879 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
7880 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* Attention over cached K/V — core elided in this extract. */
7894 for (
int h = 0; h < H_kv; ++h) {
7895 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7916 aligned_context_window,
/* Attention output projection into proj_tmp. */
7922 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
/* MLP: fused gate+up GEMM (2*intermediate), SwiGLU, then down projection. */
7939 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
7942 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7945 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
/*
 * model_layer_21_decode (fragment): single-token decode pass for
 * transformer layer 21 of this generated (per-layer unrolled) model.
 * NOTE(review): this extract is garbled — the function header, the
 * zero-padding loop bodies, RoPE application, the attention core,
 * residual adds and layernorm calls are missing, and stray original
 * line numbers plus hard line wraps remain embedded in the text.
 * Code is left byte-identical; only comments were added.
 */
7957 int aligned_embed_dim,
7958 int aligned_head_dim,
7959 int aligned_intermediate_dim,
7960 int aligned_context_window
/* Offset table for this layer's weights/activations inside the arena. */
7962 const MODELLayerOffsets *L = &MODEL_LAYERS[21];
/* Input activations: previous layer's (layer 20) output buffer. */
7964 float *input = MODEL_PTR(model, MODEL_LAYERS[20].output);
7966 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7967 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7968 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7969 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7970 float *k_cache = MODEL_PTR(model, L->k);
7971 float *v_cache = MODEL_PTR(model, L->v);
7972 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7973 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7974 float *residual1 = MODEL_PTR(model, L->residual1);
7975 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7976 float *output = MODEL_PTR(model, L->output);
/* Quantized (q4_k) weight blobs: Q/K/V/O projections and MLP W1/W2. */
7979 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7980 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7981 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7982 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7983 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7984 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* Global precomputed RoPE cos/sin caches (application code elided here). */
7986 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7987 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Per-token scratch buffers (VLAs sized by padded "aligned" dims). */
7993 float q_token[H * aligned_head_dim];
7994 float k_token[H_kv * aligned_head_dim];
7995 float v_token[H_kv * aligned_head_dim];
7996 float attn_token[H * aligned_head_dim];
7999 float fc1_out[2 * aligned_intermediate_dim];
8000 float swiglu_out[aligned_intermediate_dim];
/* K/V cache stride per head: one padded context window of head vectors. */
8012 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection; tail lanes past head_dim presumably zeroed (body elided). */
8016 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
8017 if (aligned_head_dim > head_dim) {
8018 for (
int h = 0; h < H; ++h) {
8019 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8020 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* K projection, one KV head at a time, written into k_cache at token_index. */
8027 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8029 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8030 for (
int h = 0; h < H_kv; ++h) {
8031 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8032 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
8033 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8034 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
8035 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* V projection, same per-head layout, into v_cache. */
8041 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8043 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8044 for (
int h = 0; h < H_kv; ++h) {
8045 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8046 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
8047 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8048 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
8049 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* Attention over cached K/V — core elided in this extract. */
8063 for (
int h = 0; h < H_kv; ++h) {
8064 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8085 aligned_context_window,
/* Attention output projection into proj_tmp. */
8091 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
/* MLP: fused gate+up GEMM (2*intermediate), SwiGLU, then down projection. */
8108 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
8111 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8114 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
/*
 * model_layer_22_decode (fragment): single-token decode pass for
 * transformer layer 22 of this generated (per-layer unrolled) model.
 * NOTE(review): this extract is garbled — the function header, the
 * zero-padding loop bodies, RoPE application, the attention core,
 * residual adds and layernorm calls are missing, and stray original
 * line numbers plus hard line wraps remain embedded in the text.
 * Code is left byte-identical; only comments were added.
 */
8126 int aligned_embed_dim,
8127 int aligned_head_dim,
8128 int aligned_intermediate_dim,
8129 int aligned_context_window
/* Offset table for this layer's weights/activations inside the arena. */
8131 const MODELLayerOffsets *L = &MODEL_LAYERS[22];
/* Input activations: previous layer's (layer 21) output buffer. */
8133 float *input = MODEL_PTR(model, MODEL_LAYERS[21].output);
8135 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
8136 float *ln1_out = MODEL_PTR(model, L->ln1_out);
8137 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
8138 float *ln2_out = MODEL_PTR(model, L->ln2_out);
8139 float *k_cache = MODEL_PTR(model, L->k);
8140 float *v_cache = MODEL_PTR(model, L->v);
8141 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
8142 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
8143 float *residual1 = MODEL_PTR(model, L->residual1);
8144 float *mlp_out = MODEL_PTR(model, L->mlp_out);
8145 float *output = MODEL_PTR(model, L->output);
/* Quantized (q4_k) weight blobs: Q/K/V/O projections and MLP W1/W2. */
8148 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
8149 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
8150 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
8151 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
8152 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
8153 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* Global precomputed RoPE cos/sin caches (application code elided here). */
8155 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
8156 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Per-token scratch buffers (VLAs sized by padded "aligned" dims). */
8162 float q_token[H * aligned_head_dim];
8163 float k_token[H_kv * aligned_head_dim];
8164 float v_token[H_kv * aligned_head_dim];
8165 float attn_token[H * aligned_head_dim];
8168 float fc1_out[2 * aligned_intermediate_dim];
8169 float swiglu_out[aligned_intermediate_dim];
/* K/V cache stride per head: one padded context window of head vectors. */
8181 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection; tail lanes past head_dim presumably zeroed (body elided). */
8185 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
8186 if (aligned_head_dim > head_dim) {
8187 for (
int h = 0; h < H; ++h) {
8188 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8189 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* K projection, one KV head at a time, written into k_cache at token_index. */
8196 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8198 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8199 for (
int h = 0; h < H_kv; ++h) {
8200 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8201 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
8202 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8203 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
8204 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* V projection, same per-head layout, into v_cache. */
8210 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8212 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8213 for (
int h = 0; h < H_kv; ++h) {
8214 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8215 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
8216 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8217 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
8218 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* Attention over cached K/V — core elided in this extract. */
8232 for (
int h = 0; h < H_kv; ++h) {
8233 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8254 aligned_context_window,
/* Attention output projection into proj_tmp. */
8260 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
/* MLP: fused gate+up GEMM (2*intermediate), SwiGLU, then down projection. */
8277 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
8280 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8283 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
/*
 * model_layer_23_decode (fragment): single-token decode pass for
 * transformer layer 23 of this generated (per-layer unrolled) model.
 * NOTE(review): this extract is garbled — the function header, the
 * zero-padding loop bodies, RoPE application, the attention core,
 * residual adds and layernorm calls are missing, and stray original
 * line numbers plus hard line wraps remain embedded in the text.
 * Code is left byte-identical; only comments were added.
 */
8295 int aligned_embed_dim,
8296 int aligned_head_dim,
8297 int aligned_intermediate_dim,
8298 int aligned_context_window
/* Offset table for this layer's weights/activations inside the arena. */
8300 const MODELLayerOffsets *L = &MODEL_LAYERS[23];
/* Input activations: previous layer's (layer 22) output buffer. */
8302 float *input = MODEL_PTR(model, MODEL_LAYERS[22].output);
8304 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
8305 float *ln1_out = MODEL_PTR(model, L->ln1_out);
8306 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
8307 float *ln2_out = MODEL_PTR(model, L->ln2_out);
8308 float *k_cache = MODEL_PTR(model, L->k);
8309 float *v_cache = MODEL_PTR(model, L->v);
8310 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
8311 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
8312 float *residual1 = MODEL_PTR(model, L->residual1);
8313 float *mlp_out = MODEL_PTR(model, L->mlp_out);
8314 float *output = MODEL_PTR(model, L->output);
/* Quantized (q4_k) weight blobs: Q/K/V/O projections and MLP W1/W2. */
8317 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
8318 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
8319 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
8320 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
8321 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
8322 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* Global precomputed RoPE cos/sin caches (application code elided here). */
8324 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
8325 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Per-token scratch buffers (VLAs sized by padded "aligned" dims). */
8331 float q_token[H * aligned_head_dim];
8332 float k_token[H_kv * aligned_head_dim];
8333 float v_token[H_kv * aligned_head_dim];
8334 float attn_token[H * aligned_head_dim];
8337 float fc1_out[2 * aligned_intermediate_dim];
8338 float swiglu_out[aligned_intermediate_dim];
/* K/V cache stride per head: one padded context window of head vectors. */
8350 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection; tail lanes past head_dim presumably zeroed (body elided). */
8354 gemm_nt_q4_k(ln1_out, WQ, NULL, q_token, 1, H * head_dim, aligned_embed_dim);
8355 if (aligned_head_dim > head_dim) {
8356 for (
int h = 0; h < H; ++h) {
8357 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8358 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* K projection, one KV head at a time, written into k_cache at token_index. */
8365 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8367 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8368 for (
int h = 0; h < H_kv; ++h) {
8369 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8370 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
8371 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8372 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_head, 1, head_dim, aligned_embed_dim);
8373 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* V projection, same per-head layout, into v_cache. */
8379 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8381 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8382 for (
int h = 0; h < H_kv; ++h) {
8383 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8384 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
8385 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8386 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_head, 1, head_dim, aligned_embed_dim);
8387 for (
int d = head_dim; d < aligned_head_dim; ++d) {
/* Attention over cached K/V — core elided in this extract. */
8401 for (
int h = 0; h < H_kv; ++h) {
8402 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8423 aligned_context_window,
/* Attention output projection into proj_tmp. */
8429 gemm_nt_q4_k(attn_token, WO, NULL, proj_tmp, 1, aligned_embed_dim, H * head_dim);
/* MLP: fused gate+up GEMM (2*intermediate), SwiGLU, then down projection. */
8446 gemm_nt_q4_k(ln2_out, W1, NULL, fc1_out, 1, 2 * aligned_intermediate_dim, aligned_embed_dim);
8449 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8452 gemm_nt_q4_k(swiglu_out, W2, NULL, mlp_out, 1, aligned_embed_dim, aligned_intermediate_dim);
8467 if (!model || !
token)
return;
8469 const int aligned_embed_dim = 1024;
8470 const int aligned_head_dim = 64;
8471 const int aligned_intermediate_dim = 4864;
8472 const int aligned_context_window = 131072;
8474 if (token_index < 0 || token_index >= aligned_context_window)
return;
8477 float *embed_out = MODEL_PTR(model, MODEL_HEADER.embedded_input);
8478 const void *embed_weight = (
const void *)MODEL_PTR(model, MODEL_HEADER.token_emb);
8492 model_layer_0_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8493 model_layer_1_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8494 model_layer_2_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8495 model_layer_3_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8496 model_layer_4_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8497 model_layer_5_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8498 model_layer_6_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8499 model_layer_7_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8500 model_layer_8_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8501 model_layer_9_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8502 model_layer_10_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8503 model_layer_11_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8504 model_layer_12_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8505 model_layer_13_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8506 model_layer_14_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8507 model_layer_15_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8508 model_layer_16_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8509 model_layer_17_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8510 model_layer_18_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8511 model_layer_19_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8512 model_layer_20_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8513 model_layer_21_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8514 model_layer_22_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8515 model_layer_23_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8518 float *last_hidden = MODEL_PTR(model, MODEL_LAYERS[23].output);
8519 float *final_ln_weight = MODEL_PTR(model, MODEL_FOOTER.final_ln_weight);
8520 float *final_out = MODEL_PTR(model, MODEL_FOOTER.final_output);
8531 float *logits = MODEL_PTR(model, MODEL_FOOTER.logits);
8532 const void *lm_head = (
const void *)MODEL_PTR(model, MODEL_FOOTER.lm_head_weight);
8546 if (!model || !tokens || num_tokens <= 0)
return;
8571 .total_bytes = MODEL_TOTAL_BYTES,
8572 .weight_bytes = MODEL_WEIGHT_BYTES,
8573 .activation_bytes = MODEL_ACTIVATION_BYTES,
8574 .model_name =
"model",
8575 .model_family =
"model",
8583 MODELModel *model = malloc(
sizeof(MODELModel));
8584 if (!model)
return NULL;
8611 MODELModel *m = (MODELModel *)model;
8612 return MODEL_PTR(m, MODEL_FOOTER.logits);
8620 return ((MODELModel *)model)->base;
8624 return ((MODELModel *)model)->total_bytes;
Generic Model API - Model-agnostic interface for CK-Engine.
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void embedding_forward_q4_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
void quantize_row_q8_k(const float *x, void *y, int k)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
static void model_layer_13_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
const CKModelConfig * ck_model_get_config(void)
static void model_layer_17_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_13_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_6_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
int model_model_allocate(MODELModel *model)
static void model_layer_3_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_0_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_15_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_8_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_22_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_forward_prefill_impl(MODELModel *model, const int *tokens, int num_tokens)
static void model_layer_22_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_18_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void * ck_model_create(void)
static void model_layer_8_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static CKModelConfig g_model_config
static void model_layer_7_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void model_decode(MODELModel *model, const int *token, int token_index)
void ck_model_precompute_rope(void *model)
void model_precompute_rope(MODELModel *model)
static void model_layer_4_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_11_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void ck_model_forward(void *model, const int *tokens, int num_tokens)
static void model_layer_12_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_10_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_16_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void model_forward(MODELModel *model, const int *tokens, int num_tokens)
int ck_model_verify_canaries(void *model)
static void model_layer_21_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_16_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_4_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_1_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_14_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
int model_verify_canaries(MODELModel *model)
static void model_layer_20_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_3_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_19_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void ck_model_free(void *model)
void * ck_model_get_base(void *model)
void ck_model_decode(void *model, const int *token, int token_index)
static void model_layer_14_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
void model_model_free(MODELModel *model)
static void model_layer_20_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
size_t ck_model_get_total_bytes(void *model)
float * ck_model_get_logits(void *model)
static void model_layer_11_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_7_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
struct __attribute__((packed))
static void model_layer_23_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_21_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_5_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_18_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_5_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_15_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_10_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_12_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_17_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_0_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static int model_align_elems(int elems, int elem_bytes, int align_bytes)
static void model_layer_1_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_23_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_decode_token(MODELModel *model, const int *token, int token_index)
static void model_layer_19_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
_Static_assert(sizeof(MagicHeader)==64, "MagicHeader must be 64 bytes")
static void model_layer_2_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_2_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_9_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_6_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_9_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
AUTO-GENERATED: qwen2_0.5b_decode Memory Layout.
#define MODEL_INTERMEDIATE
#define MODEL_NUM_KV_HEADS
#define MODEL_MAX_SEQ_LEN