35 #if MODEL_DTYPE_BYTES != 4
36 #error "model: v6.5 codegen currently supports fp32 only. Use --dtype=fp32."
50 if (!a || !b || !out) {
53 for (
int t = 0; t < tokens; ++t) {
54 const float *pa = a + (size_t)t * (
size_t)aligned_embed_dim;
55 const float *pb = b + (size_t)t * (
size_t)aligned_embed_dim;
56 float *pc = out + (size_t)t * (
size_t)aligned_embed_dim;
57 for (
int d = 0; d < aligned_embed_dim; ++d) {
58 pc[d] = pa[d] + pb[d];
71 uint64_t weight_bytes;
72 uint64_t activation_bytes;
78 uint32_t canary_count;
89 size_t total = MODEL_TOTAL_BYTES;
92 model->base = mmap(NULL, total,
93 PROT_READ | PROT_WRITE,
94 MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB,
96 if (model->base == MAP_FAILED) {
97 model->base = mmap(NULL, total,
98 PROT_READ | PROT_WRITE,
99 MAP_PRIVATE | MAP_ANONYMOUS,
102 if (model->base == MAP_FAILED) {
103 perror(
"mmap failed");
107 model->base = aligned_alloc(64, total);
109 perror(
"aligned_alloc failed");
114 model->total_bytes = total;
118 header->magic = MODEL_MAGIC;
120 header->total_bytes = MODEL_TOTAL_BYTES;
121 header->weight_bytes = MODEL_WEIGHT_BYTES;
122 header->activation_bytes = MODEL_ACTIVATION_BYTES;
128 header->canary_count = MODEL_CANARY_COUNT;
131 for (
int i = 0; i < MODEL_CANARY_COUNT; i++) {
132 uint32_t *ptr = (uint32_t*)((
char*)model->base + MODEL_CANARIES[i].offset);
133 for (
int j = 0; j < (MODEL_CANARY_SIZE / 4); j++) {
134 ptr[j] = MODEL_CANARY_VALUE;
142 if (!model || !model->base)
return;
144 munmap(model->base, model->total_bytes);
149 model->total_bytes = 0;
156 for (
int i = 0; i < MODEL_CANARY_COUNT; i++) {
157 ptr = (uint32_t*)((
char*)model->base + MODEL_CANARIES[i].offset);
158 for (
int j = 0; j < 4; j++) {
159 if (ptr[j] != MODEL_CANARY_VALUE) {
160 fprintf(stderr,
"CANARY CORRUPTION: %s at offset 0x%lX\n",
161 MODEL_CANARIES[i].name,
162 MODEL_CANARIES[i].offset);
177 int bytes = elems * elem_bytes;
178 int aligned = (bytes + align_bytes - 1) / align_bytes * align_bytes;
179 return aligned / elem_bytes;
189 const float theta = 1000000.0f;
191 float *cos_ptr = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
192 float *sin_ptr = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
194 for (
int pos = 0; pos < T; pos++) {
195 for (
int i = 0; i < D; i++) {
196 float freq = 1.0f / powf(theta, (
float)(2 * i) / (
float)(D * 2));
197 float angle = (float)pos * freq;
198 cos_ptr[pos * D + i] = cosf(angle);
199 sin_ptr[pos * D + i] = sinf(angle);
214 int aligned_embed_dim,
215 int aligned_head_dim,
216 int aligned_intermediate_dim,
217 int aligned_context_window
219 const MODELLayerOffsets *L = &MODEL_LAYERS[0];
221 float *input = MODEL_PTR(model, MODEL_HEADER.embedded_input);
222 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
223 float *ln1_out = MODEL_PTR(model, L->ln1_out);
224 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
225 float *ln2_out = MODEL_PTR(model, L->ln2_out);
226 float *q = MODEL_PTR(model, L->q);
227 float *k = MODEL_PTR(model, L->k);
228 float *v = MODEL_PTR(model, L->v);
229 float *attn_out = MODEL_PTR(model, L->attn_out);
230 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
231 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
232 float *residual1 = MODEL_PTR(model, L->residual1);
233 float *fc1_out = MODEL_PTR(model, L->fc1_out);
234 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
235 float *mlp_out = MODEL_PTR(model, L->mlp_out);
236 float *output = MODEL_PTR(model, L->output);
238 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
239 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
240 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
241 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
242 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
243 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
244 const float *BQ = NULL;
245 const float *BK = NULL;
246 const float *BV = NULL;
247 const float *BO = NULL;
248 const float *B1 = NULL;
249 const float *B2 = NULL;
251 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
252 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
257 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
258 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
259 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
273 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
274 for (
int h = 0; h < H; ++h) {
275 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
276 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
277 float *q_h = q + (size_t)h * q_head_stride;
278 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
283 const uint8_t *WK_bytes = (
const uint8_t *)WK;
284 for (
int h = 0; h < H_kv; ++h) {
285 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
286 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
287 float *k_h = k + (size_t)h * kv_head_stride;
288 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
293 const uint8_t *WV_bytes = (
const uint8_t *)WV;
294 for (
int h = 0; h < H_kv; ++h) {
295 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
296 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
297 float *v_h = v + (size_t)h * kv_head_stride;
298 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
313 aligned_context_window);
325 aligned_context_window);
328 const int K = H * aligned_head_dim;
329 if (K != aligned_embed_dim) {
332 const float *proj_in = attn_out;
337 for (
int t = 0; t < num_tokens; ++t) {
338 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
339 for (
int h = 0; h < H; ++h) {
340 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
341 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
343 (
size_t)aligned_head_dim *
sizeof(
float));
346 proj_in = proj_scratch;
348 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
364 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
365 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
366 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
378 int aligned_embed_dim,
379 int aligned_head_dim,
380 int aligned_intermediate_dim,
381 int aligned_context_window
383 const MODELLayerOffsets *L = &MODEL_LAYERS[1];
385 float *input = MODEL_PTR(model, MODEL_LAYERS[0].output);
386 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
387 float *ln1_out = MODEL_PTR(model, L->ln1_out);
388 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
389 float *ln2_out = MODEL_PTR(model, L->ln2_out);
390 float *q = MODEL_PTR(model, L->q);
391 float *k = MODEL_PTR(model, L->k);
392 float *v = MODEL_PTR(model, L->v);
393 float *attn_out = MODEL_PTR(model, L->attn_out);
394 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
395 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
396 float *residual1 = MODEL_PTR(model, L->residual1);
397 float *fc1_out = MODEL_PTR(model, L->fc1_out);
398 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
399 float *mlp_out = MODEL_PTR(model, L->mlp_out);
400 float *output = MODEL_PTR(model, L->output);
402 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
403 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
404 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
405 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
406 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
407 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
408 const float *BQ = NULL;
409 const float *BK = NULL;
410 const float *BV = NULL;
411 const float *BO = NULL;
412 const float *B1 = NULL;
413 const float *B2 = NULL;
415 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
416 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
421 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
422 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
423 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
437 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
438 for (
int h = 0; h < H; ++h) {
439 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
440 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
441 float *q_h = q + (size_t)h * q_head_stride;
442 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
447 const uint8_t *WK_bytes = (
const uint8_t *)WK;
448 for (
int h = 0; h < H_kv; ++h) {
449 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
450 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
451 float *k_h = k + (size_t)h * kv_head_stride;
452 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
457 const uint8_t *WV_bytes = (
const uint8_t *)WV;
458 for (
int h = 0; h < H_kv; ++h) {
459 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
460 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
461 float *v_h = v + (size_t)h * kv_head_stride;
462 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
477 aligned_context_window);
489 aligned_context_window);
492 const int K = H * aligned_head_dim;
493 if (K != aligned_embed_dim) {
496 const float *proj_in = attn_out;
501 for (
int t = 0; t < num_tokens; ++t) {
502 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
503 for (
int h = 0; h < H; ++h) {
504 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
505 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
507 (
size_t)aligned_head_dim *
sizeof(
float));
510 proj_in = proj_scratch;
512 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
528 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
529 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
530 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
542 int aligned_embed_dim,
543 int aligned_head_dim,
544 int aligned_intermediate_dim,
545 int aligned_context_window
547 const MODELLayerOffsets *L = &MODEL_LAYERS[2];
549 float *input = MODEL_PTR(model, MODEL_LAYERS[1].output);
550 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
551 float *ln1_out = MODEL_PTR(model, L->ln1_out);
552 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
553 float *ln2_out = MODEL_PTR(model, L->ln2_out);
554 float *q = MODEL_PTR(model, L->q);
555 float *k = MODEL_PTR(model, L->k);
556 float *v = MODEL_PTR(model, L->v);
557 float *attn_out = MODEL_PTR(model, L->attn_out);
558 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
559 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
560 float *residual1 = MODEL_PTR(model, L->residual1);
561 float *fc1_out = MODEL_PTR(model, L->fc1_out);
562 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
563 float *mlp_out = MODEL_PTR(model, L->mlp_out);
564 float *output = MODEL_PTR(model, L->output);
566 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
567 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
568 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
569 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
570 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
571 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
572 const float *BQ = NULL;
573 const float *BK = NULL;
574 const float *BV = NULL;
575 const float *BO = NULL;
576 const float *B1 = NULL;
577 const float *B2 = NULL;
579 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
580 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
585 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
586 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
587 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
601 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
602 for (
int h = 0; h < H; ++h) {
603 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
604 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
605 float *q_h = q + (size_t)h * q_head_stride;
606 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
611 const uint8_t *WK_bytes = (
const uint8_t *)WK;
612 for (
int h = 0; h < H_kv; ++h) {
613 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
614 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
615 float *k_h = k + (size_t)h * kv_head_stride;
616 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
621 const uint8_t *WV_bytes = (
const uint8_t *)WV;
622 for (
int h = 0; h < H_kv; ++h) {
623 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
624 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
625 float *v_h = v + (size_t)h * kv_head_stride;
626 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
641 aligned_context_window);
653 aligned_context_window);
656 const int K = H * aligned_head_dim;
657 if (K != aligned_embed_dim) {
660 const float *proj_in = attn_out;
665 for (
int t = 0; t < num_tokens; ++t) {
666 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
667 for (
int h = 0; h < H; ++h) {
668 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
669 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
671 (
size_t)aligned_head_dim *
sizeof(
float));
674 proj_in = proj_scratch;
676 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
692 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
693 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
694 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
706 int aligned_embed_dim,
707 int aligned_head_dim,
708 int aligned_intermediate_dim,
709 int aligned_context_window
711 const MODELLayerOffsets *L = &MODEL_LAYERS[3];
713 float *input = MODEL_PTR(model, MODEL_LAYERS[2].output);
714 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
715 float *ln1_out = MODEL_PTR(model, L->ln1_out);
716 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
717 float *ln2_out = MODEL_PTR(model, L->ln2_out);
718 float *q = MODEL_PTR(model, L->q);
719 float *k = MODEL_PTR(model, L->k);
720 float *v = MODEL_PTR(model, L->v);
721 float *attn_out = MODEL_PTR(model, L->attn_out);
722 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
723 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
724 float *residual1 = MODEL_PTR(model, L->residual1);
725 float *fc1_out = MODEL_PTR(model, L->fc1_out);
726 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
727 float *mlp_out = MODEL_PTR(model, L->mlp_out);
728 float *output = MODEL_PTR(model, L->output);
730 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
731 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
732 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
733 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
734 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
735 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
736 const float *BQ = NULL;
737 const float *BK = NULL;
738 const float *BV = NULL;
739 const float *BO = NULL;
740 const float *B1 = NULL;
741 const float *B2 = NULL;
743 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
744 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
749 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
750 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
751 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
765 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
766 for (
int h = 0; h < H; ++h) {
767 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
768 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
769 float *q_h = q + (size_t)h * q_head_stride;
770 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
775 const uint8_t *WK_bytes = (
const uint8_t *)WK;
776 for (
int h = 0; h < H_kv; ++h) {
777 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
778 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
779 float *k_h = k + (size_t)h * kv_head_stride;
780 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
785 const uint8_t *WV_bytes = (
const uint8_t *)WV;
786 for (
int h = 0; h < H_kv; ++h) {
787 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
788 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
789 float *v_h = v + (size_t)h * kv_head_stride;
790 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
805 aligned_context_window);
817 aligned_context_window);
820 const int K = H * aligned_head_dim;
821 if (K != aligned_embed_dim) {
824 const float *proj_in = attn_out;
829 for (
int t = 0; t < num_tokens; ++t) {
830 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
831 for (
int h = 0; h < H; ++h) {
832 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
833 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
835 (
size_t)aligned_head_dim *
sizeof(
float));
838 proj_in = proj_scratch;
840 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
856 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
857 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
858 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
870 int aligned_embed_dim,
871 int aligned_head_dim,
872 int aligned_intermediate_dim,
873 int aligned_context_window
875 const MODELLayerOffsets *L = &MODEL_LAYERS[4];
877 float *input = MODEL_PTR(model, MODEL_LAYERS[3].output);
878 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
879 float *ln1_out = MODEL_PTR(model, L->ln1_out);
880 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
881 float *ln2_out = MODEL_PTR(model, L->ln2_out);
882 float *q = MODEL_PTR(model, L->q);
883 float *k = MODEL_PTR(model, L->k);
884 float *v = MODEL_PTR(model, L->v);
885 float *attn_out = MODEL_PTR(model, L->attn_out);
886 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
887 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
888 float *residual1 = MODEL_PTR(model, L->residual1);
889 float *fc1_out = MODEL_PTR(model, L->fc1_out);
890 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
891 float *mlp_out = MODEL_PTR(model, L->mlp_out);
892 float *output = MODEL_PTR(model, L->output);
894 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
895 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
896 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
897 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
898 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
899 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
900 const float *BQ = NULL;
901 const float *BK = NULL;
902 const float *BV = NULL;
903 const float *BO = NULL;
904 const float *B1 = NULL;
905 const float *B2 = NULL;
907 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
908 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
913 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
914 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
915 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
929 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
930 for (
int h = 0; h < H; ++h) {
931 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
932 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
933 float *q_h = q + (size_t)h * q_head_stride;
934 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
939 const uint8_t *WK_bytes = (
const uint8_t *)WK;
940 for (
int h = 0; h < H_kv; ++h) {
941 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
942 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
943 float *k_h = k + (size_t)h * kv_head_stride;
944 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
949 const uint8_t *WV_bytes = (
const uint8_t *)WV;
950 for (
int h = 0; h < H_kv; ++h) {
951 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
952 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
953 float *v_h = v + (size_t)h * kv_head_stride;
954 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
969 aligned_context_window);
981 aligned_context_window);
984 const int K = H * aligned_head_dim;
985 if (K != aligned_embed_dim) {
988 const float *proj_in = attn_out;
993 for (
int t = 0; t < num_tokens; ++t) {
994 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
995 for (
int h = 0; h < H; ++h) {
996 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
997 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
999 (
size_t)aligned_head_dim *
sizeof(
float));
1002 proj_in = proj_scratch;
1004 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1020 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1021 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1022 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1034 int aligned_embed_dim,
1035 int aligned_head_dim,
1036 int aligned_intermediate_dim,
1037 int aligned_context_window
1039 const MODELLayerOffsets *L = &MODEL_LAYERS[5];
1041 float *input = MODEL_PTR(model, MODEL_LAYERS[4].output);
1042 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1043 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1044 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1045 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1046 float *q = MODEL_PTR(model, L->q);
1047 float *k = MODEL_PTR(model, L->k);
1048 float *v = MODEL_PTR(model, L->v);
1049 float *attn_out = MODEL_PTR(model, L->attn_out);
1050 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1051 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1052 float *residual1 = MODEL_PTR(model, L->residual1);
1053 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1054 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1055 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1056 float *output = MODEL_PTR(model, L->output);
1058 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1059 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1060 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1061 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1062 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1063 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
1064 const float *BQ = NULL;
1065 const float *BK = NULL;
1066 const float *BV = NULL;
1067 const float *BO = NULL;
1068 const float *B1 = NULL;
1069 const float *B2 = NULL;
1071 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1072 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
1077 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1078 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1079 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
1093 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1094 for (
int h = 0; h < H; ++h) {
1095 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1096 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1097 float *q_h = q + (size_t)h * q_head_stride;
1098 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1103 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1104 for (
int h = 0; h < H_kv; ++h) {
1105 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1106 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1107 float *k_h = k + (size_t)h * kv_head_stride;
1108 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1113 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1114 for (
int h = 0; h < H_kv; ++h) {
1115 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1116 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1117 float *v_h = v + (size_t)h * kv_head_stride;
1118 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1133 aligned_context_window);
1145 aligned_context_window);
1148 const int K = H * aligned_head_dim;
1149 if (K != aligned_embed_dim) {
1152 const float *proj_in = attn_out;
1154 if (!proj_scratch) {
1157 for (
int t = 0; t < num_tokens; ++t) {
1158 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1159 for (
int h = 0; h < H; ++h) {
1160 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1161 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1163 (
size_t)aligned_head_dim *
sizeof(
float));
1166 proj_in = proj_scratch;
1168 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1184 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1185 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1186 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
1198 int aligned_embed_dim,
1199 int aligned_head_dim,
1200 int aligned_intermediate_dim,
1201 int aligned_context_window
1203 const MODELLayerOffsets *L = &MODEL_LAYERS[6];
1205 float *input = MODEL_PTR(model, MODEL_LAYERS[5].output);
1206 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1207 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1208 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1209 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1210 float *q = MODEL_PTR(model, L->q);
1211 float *k = MODEL_PTR(model, L->k);
1212 float *v = MODEL_PTR(model, L->v);
1213 float *attn_out = MODEL_PTR(model, L->attn_out);
1214 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1215 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1216 float *residual1 = MODEL_PTR(model, L->residual1);
1217 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1218 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1219 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1220 float *output = MODEL_PTR(model, L->output);
1222 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1223 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1224 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1225 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1226 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1227 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
1228 const float *BQ = NULL;
1229 const float *BK = NULL;
1230 const float *BV = NULL;
1231 const float *BO = NULL;
1232 const float *B1 = NULL;
1233 const float *B2 = NULL;
1235 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1236 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
1241 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1242 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1243 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
1257 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1258 for (
int h = 0; h < H; ++h) {
1259 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1260 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1261 float *q_h = q + (size_t)h * q_head_stride;
1262 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1267 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1268 for (
int h = 0; h < H_kv; ++h) {
1269 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1270 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1271 float *k_h = k + (size_t)h * kv_head_stride;
1272 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1277 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1278 for (
int h = 0; h < H_kv; ++h) {
1279 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1280 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1281 float *v_h = v + (size_t)h * kv_head_stride;
1282 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
1297 aligned_context_window);
1309 aligned_context_window);
1312 const int K = H * aligned_head_dim;
1313 if (K != aligned_embed_dim) {
1316 const float *proj_in = attn_out;
1318 if (!proj_scratch) {
1321 for (
int t = 0; t < num_tokens; ++t) {
1322 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1323 for (
int h = 0; h < H; ++h) {
1324 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1325 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1327 (
size_t)aligned_head_dim *
sizeof(
float));
1330 proj_in = proj_scratch;
1332 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
1348 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1349 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1350 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
// ---------------------------------------------------------------------------
// Transformer block forward pass for MODEL_LAYERS[7] (fragment).
// NOTE(review): this region is extraction-damaged — the function's return-type
// and name line plus several interior statements are missing (the embedded
// original line numbers jump). Comments below annotate only the visible code;
// do not treat this fragment as the complete function body.
// ---------------------------------------------------------------------------
1362 int aligned_embed_dim,
1363 int aligned_head_dim,
1364 int aligned_intermediate_dim,
1365 int aligned_context_window
// Per-layer offset table; every buffer below lives in one model arena and is
// materialized as a pointer via MODEL_PTR(model, offset).
1367 const MODELLayerOffsets *L = &MODEL_LAYERS[7];
// This layer consumes the previous layer's output buffer as its input.
1369 float *input = MODEL_PTR(model, MODEL_LAYERS[6].output);
1370 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1371 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1372 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1373 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1374 float *q = MODEL_PTR(model, L->q);
1375 float *k = MODEL_PTR(model, L->k);
1376 float *v = MODEL_PTR(model, L->v);
1377 float *attn_out = MODEL_PTR(model, L->attn_out);
1378 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1379 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1380 float *residual1 = MODEL_PTR(model, L->residual1);
1381 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1382 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1383 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1384 float *output = MODEL_PTR(model, L->output);
// Quantized weight blobs: kept as opaque const void* because gemm_nt_q4_k
// decodes the Q4_K packing internally.
1386 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1387 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1388 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1389 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1390 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1391 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
// No bias tensors in this codegen: all bias pointers are NULL, and the
// per-head bias slices below collapse to NULL as well.
1392 const float *BQ = NULL;
1393 const float *BK = NULL;
1394 const float *BV = NULL;
1395 const float *BO = NULL;
1396 const float *B1 = NULL;
1397 const float *B2 = NULL;
// Shared (global, not per-layer) RoPE cos/sin lookup tables.
1399 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1400 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
// Strides in elements: q is laid out per-head with a num_tokens-row stride;
// k/v reserve the full context window per head (presumably so they double as
// the KV cache — TODO confirm against the full file).
1405 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1406 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1407 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
// Q projection: one quantized GEMM per head, slicing both the packed weight
// blob (wq_head_bytes per head) and the destination buffer.
1421 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1422 for (
int h = 0; h < H; ++h) {
1423 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1424 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1425 float *q_h = q + (size_t)h * q_head_stride;
1426 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// K projection: only H_kv heads (fewer than H would indicate grouped-query
// attention — cannot confirm from this fragment).
1431 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1432 for (
int h = 0; h < H_kv; ++h) {
1433 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1434 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1435 float *k_h = k + (size_t)h * kv_head_stride;
1436 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// V projection, mirroring the K loop.
1441 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1442 for (
int h = 0; h < H_kv; ++h) {
1443 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1444 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1445 float *v_h = v + (size_t)h * kv_head_stride;
1446 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// NOTE(review): the two orphaned argument tails below are the ends of calls
// whose leading lines are missing from this chunk (plausibly RoPE application
// and the attention kernel, given rope_cos/rope_sin above — TODO confirm).
1461 aligned_context_window);
1473 aligned_context_window);
// Output projection: attn_out is head-major; when the concatenated head width
// K differs from the embed dim, heads are repacked token-major into
// proj_scratch before the WO GEMM.
1476 const int K = H * aligned_head_dim;
1477 if (K != aligned_embed_dim) {
1480 const float *proj_in = attn_out;
1482 if (!proj_scratch) {
1485 for (
int t = 0; t < num_tokens; ++t) {
1486 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1487 for (
int h = 0; h < H; ++h) {
1488 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1489 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1491 (
size_t)aligned_head_dim *
sizeof(
float));
1494 proj_in = proj_scratch;
1496 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
// MLP: FC1 produces 2*intermediate columns in one GEMM, swiglu_forward
// combines the two halves, FC2 projects back to the embed dim.
1512 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1513 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1514 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
// ---------------------------------------------------------------------------
// Transformer block forward pass for MODEL_LAYERS[8] (fragment).
// NOTE(review): this region is extraction-damaged — the function's return-type
// and name line plus several interior statements are missing (the embedded
// original line numbers jump). Comments below annotate only the visible code;
// do not treat this fragment as the complete function body.
// ---------------------------------------------------------------------------
1526 int aligned_embed_dim,
1527 int aligned_head_dim,
1528 int aligned_intermediate_dim,
1529 int aligned_context_window
// Per-layer offset table; every buffer below lives in one model arena and is
// materialized as a pointer via MODEL_PTR(model, offset).
1531 const MODELLayerOffsets *L = &MODEL_LAYERS[8];
// This layer consumes the previous layer's output buffer as its input.
1533 float *input = MODEL_PTR(model, MODEL_LAYERS[7].output);
1534 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1535 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1536 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1537 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1538 float *q = MODEL_PTR(model, L->q);
1539 float *k = MODEL_PTR(model, L->k);
1540 float *v = MODEL_PTR(model, L->v);
1541 float *attn_out = MODEL_PTR(model, L->attn_out);
1542 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1543 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1544 float *residual1 = MODEL_PTR(model, L->residual1);
1545 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1546 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1547 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1548 float *output = MODEL_PTR(model, L->output);
// Quantized weight blobs: kept as opaque const void* because gemm_nt_q4_k
// decodes the Q4_K packing internally.
1550 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1551 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1552 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1553 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1554 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1555 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
// No bias tensors in this codegen: all bias pointers are NULL, and the
// per-head bias slices below collapse to NULL as well.
1556 const float *BQ = NULL;
1557 const float *BK = NULL;
1558 const float *BV = NULL;
1559 const float *BO = NULL;
1560 const float *B1 = NULL;
1561 const float *B2 = NULL;
// Shared (global, not per-layer) RoPE cos/sin lookup tables.
1563 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1564 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
// Strides in elements: q is laid out per-head with a num_tokens-row stride;
// k/v reserve the full context window per head (presumably so they double as
// the KV cache — TODO confirm against the full file).
1569 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1570 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1571 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
// Q projection: one quantized GEMM per head, slicing both the packed weight
// blob (wq_head_bytes per head) and the destination buffer.
1585 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1586 for (
int h = 0; h < H; ++h) {
1587 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1588 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1589 float *q_h = q + (size_t)h * q_head_stride;
1590 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// K projection: only H_kv heads (fewer than H would indicate grouped-query
// attention — cannot confirm from this fragment).
1595 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1596 for (
int h = 0; h < H_kv; ++h) {
1597 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1598 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1599 float *k_h = k + (size_t)h * kv_head_stride;
1600 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// V projection, mirroring the K loop.
1605 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1606 for (
int h = 0; h < H_kv; ++h) {
1607 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1608 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1609 float *v_h = v + (size_t)h * kv_head_stride;
1610 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// NOTE(review): the two orphaned argument tails below are the ends of calls
// whose leading lines are missing from this chunk (plausibly RoPE application
// and the attention kernel, given rope_cos/rope_sin above — TODO confirm).
1625 aligned_context_window);
1637 aligned_context_window);
// Output projection: attn_out is head-major; when the concatenated head width
// K differs from the embed dim, heads are repacked token-major into
// proj_scratch before the WO GEMM.
1640 const int K = H * aligned_head_dim;
1641 if (K != aligned_embed_dim) {
1644 const float *proj_in = attn_out;
1646 if (!proj_scratch) {
1649 for (
int t = 0; t < num_tokens; ++t) {
1650 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1651 for (
int h = 0; h < H; ++h) {
1652 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1653 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1655 (
size_t)aligned_head_dim *
sizeof(
float));
1658 proj_in = proj_scratch;
1660 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
// MLP: FC1 produces 2*intermediate columns in one GEMM, swiglu_forward
// combines the two halves, FC2 projects back to the embed dim.
1676 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1677 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1678 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
// ---------------------------------------------------------------------------
// Transformer block forward pass for MODEL_LAYERS[9] (fragment).
// NOTE(review): this region is extraction-damaged — the function's return-type
// and name line plus several interior statements are missing (the embedded
// original line numbers jump). Comments below annotate only the visible code;
// do not treat this fragment as the complete function body.
// ---------------------------------------------------------------------------
1690 int aligned_embed_dim,
1691 int aligned_head_dim,
1692 int aligned_intermediate_dim,
1693 int aligned_context_window
// Per-layer offset table; every buffer below lives in one model arena and is
// materialized as a pointer via MODEL_PTR(model, offset).
1695 const MODELLayerOffsets *L = &MODEL_LAYERS[9];
// This layer consumes the previous layer's output buffer as its input.
1697 float *input = MODEL_PTR(model, MODEL_LAYERS[8].output);
1698 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1699 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1700 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1701 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1702 float *q = MODEL_PTR(model, L->q);
1703 float *k = MODEL_PTR(model, L->k);
1704 float *v = MODEL_PTR(model, L->v);
1705 float *attn_out = MODEL_PTR(model, L->attn_out);
1706 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1707 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1708 float *residual1 = MODEL_PTR(model, L->residual1);
1709 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1710 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1711 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1712 float *output = MODEL_PTR(model, L->output);
// Quantized weight blobs: kept as opaque const void* because gemm_nt_q4_k
// decodes the Q4_K packing internally.
1714 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1715 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1716 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1717 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1718 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1719 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
// No bias tensors in this codegen: all bias pointers are NULL, and the
// per-head bias slices below collapse to NULL as well.
1720 const float *BQ = NULL;
1721 const float *BK = NULL;
1722 const float *BV = NULL;
1723 const float *BO = NULL;
1724 const float *B1 = NULL;
1725 const float *B2 = NULL;
// Shared (global, not per-layer) RoPE cos/sin lookup tables.
1727 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1728 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
// Strides in elements: q is laid out per-head with a num_tokens-row stride;
// k/v reserve the full context window per head (presumably so they double as
// the KV cache — TODO confirm against the full file).
1733 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1734 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1735 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
// Q projection: one quantized GEMM per head, slicing both the packed weight
// blob (wq_head_bytes per head) and the destination buffer.
1749 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1750 for (
int h = 0; h < H; ++h) {
1751 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1752 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1753 float *q_h = q + (size_t)h * q_head_stride;
1754 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// K projection: only H_kv heads (fewer than H would indicate grouped-query
// attention — cannot confirm from this fragment).
1759 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1760 for (
int h = 0; h < H_kv; ++h) {
1761 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1762 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1763 float *k_h = k + (size_t)h * kv_head_stride;
1764 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// V projection, mirroring the K loop.
1769 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1770 for (
int h = 0; h < H_kv; ++h) {
1771 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1772 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1773 float *v_h = v + (size_t)h * kv_head_stride;
1774 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// NOTE(review): the two orphaned argument tails below are the ends of calls
// whose leading lines are missing from this chunk (plausibly RoPE application
// and the attention kernel, given rope_cos/rope_sin above — TODO confirm).
1789 aligned_context_window);
1801 aligned_context_window);
// Output projection: attn_out is head-major; when the concatenated head width
// K differs from the embed dim, heads are repacked token-major into
// proj_scratch before the WO GEMM.
1804 const int K = H * aligned_head_dim;
1805 if (K != aligned_embed_dim) {
1808 const float *proj_in = attn_out;
1810 if (!proj_scratch) {
1813 for (
int t = 0; t < num_tokens; ++t) {
1814 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1815 for (
int h = 0; h < H; ++h) {
1816 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1817 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1819 (
size_t)aligned_head_dim *
sizeof(
float));
1822 proj_in = proj_scratch;
1824 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
// MLP: FC1 produces 2*intermediate columns in one GEMM, swiglu_forward
// combines the two halves, FC2 projects back to the embed dim.
1840 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
1841 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
1842 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
// ---------------------------------------------------------------------------
// Transformer block forward pass for MODEL_LAYERS[10] (fragment).
// NOTE(review): this region is extraction-damaged — the function's return-type
// and name line plus several interior statements are missing (the embedded
// original line numbers jump). Comments below annotate only the visible code;
// do not treat this fragment as the complete function body.
// ---------------------------------------------------------------------------
1854 int aligned_embed_dim,
1855 int aligned_head_dim,
1856 int aligned_intermediate_dim,
1857 int aligned_context_window
// Per-layer offset table; every buffer below lives in one model arena and is
// materialized as a pointer via MODEL_PTR(model, offset).
1859 const MODELLayerOffsets *L = &MODEL_LAYERS[10];
// This layer consumes the previous layer's output buffer as its input.
1861 float *input = MODEL_PTR(model, MODEL_LAYERS[9].output);
1862 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
1863 float *ln1_out = MODEL_PTR(model, L->ln1_out);
1864 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
1865 float *ln2_out = MODEL_PTR(model, L->ln2_out);
1866 float *q = MODEL_PTR(model, L->q);
1867 float *k = MODEL_PTR(model, L->k);
1868 float *v = MODEL_PTR(model, L->v);
1869 float *attn_out = MODEL_PTR(model, L->attn_out);
1870 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
1871 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
1872 float *residual1 = MODEL_PTR(model, L->residual1);
1873 float *fc1_out = MODEL_PTR(model, L->fc1_out);
1874 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
1875 float *mlp_out = MODEL_PTR(model, L->mlp_out);
1876 float *output = MODEL_PTR(model, L->output);
// Quantized weight blobs: kept as opaque const void* because gemm_nt_q4_k
// decodes the Q4_K packing internally.
1878 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
1879 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
1880 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
1881 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
1882 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
1883 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
// No bias tensors in this codegen: all bias pointers are NULL, and the
// per-head bias slices below collapse to NULL as well.
1884 const float *BQ = NULL;
1885 const float *BK = NULL;
1886 const float *BV = NULL;
1887 const float *BO = NULL;
1888 const float *B1 = NULL;
1889 const float *B2 = NULL;
// Shared (global, not per-layer) RoPE cos/sin lookup tables.
1891 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
1892 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
// Strides in elements: q is laid out per-head with a num_tokens-row stride;
// k/v reserve the full context window per head (presumably so they double as
// the KV cache — TODO confirm against the full file).
1897 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
1898 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
1899 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
// Q projection: one quantized GEMM per head, slicing both the packed weight
// blob (wq_head_bytes per head) and the destination buffer.
1913 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
1914 for (
int h = 0; h < H; ++h) {
1915 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
1916 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1917 float *q_h = q + (size_t)h * q_head_stride;
1918 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// K projection: only H_kv heads (fewer than H would indicate grouped-query
// attention — cannot confirm from this fragment).
1923 const uint8_t *WK_bytes = (
const uint8_t *)WK;
1924 for (
int h = 0; h < H_kv; ++h) {
1925 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
1926 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1927 float *k_h = k + (size_t)h * kv_head_stride;
1928 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// V projection, mirroring the K loop.
1933 const uint8_t *WV_bytes = (
const uint8_t *)WV;
1934 for (
int h = 0; h < H_kv; ++h) {
1935 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
1936 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
1937 float *v_h = v + (size_t)h * kv_head_stride;
1938 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// NOTE(review): the two orphaned argument tails below are the ends of calls
// whose leading lines are missing from this chunk (plausibly RoPE application
// and the attention kernel, given rope_cos/rope_sin above — TODO confirm).
1953 aligned_context_window);
1965 aligned_context_window);
// Output projection: attn_out is head-major; when the concatenated head width
// K differs from the embed dim, heads are repacked token-major into
// proj_scratch before the WO GEMM.
1968 const int K = H * aligned_head_dim;
1969 if (K != aligned_embed_dim) {
1972 const float *proj_in = attn_out;
1974 if (!proj_scratch) {
1977 for (
int t = 0; t < num_tokens; ++t) {
1978 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
1979 for (
int h = 0; h < H; ++h) {
1980 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
1981 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
1983 (
size_t)aligned_head_dim *
sizeof(
float));
1986 proj_in = proj_scratch;
1988 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
// MLP: FC1 produces 2*intermediate columns in one GEMM, swiglu_forward
// combines the two halves, FC2 projects back to the embed dim.
2004 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2005 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2006 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
// ---------------------------------------------------------------------------
// Transformer block forward pass for MODEL_LAYERS[11] (fragment).
// NOTE(review): this region is extraction-damaged — the function's return-type
// and name line plus several interior statements are missing (the embedded
// original line numbers jump). Comments below annotate only the visible code;
// do not treat this fragment as the complete function body.
// ---------------------------------------------------------------------------
2018 int aligned_embed_dim,
2019 int aligned_head_dim,
2020 int aligned_intermediate_dim,
2021 int aligned_context_window
// Per-layer offset table; every buffer below lives in one model arena and is
// materialized as a pointer via MODEL_PTR(model, offset).
2023 const MODELLayerOffsets *L = &MODEL_LAYERS[11];
// This layer consumes the previous layer's output buffer as its input.
2025 float *input = MODEL_PTR(model, MODEL_LAYERS[10].output);
2026 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2027 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2028 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2029 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2030 float *q = MODEL_PTR(model, L->q);
2031 float *k = MODEL_PTR(model, L->k);
2032 float *v = MODEL_PTR(model, L->v);
2033 float *attn_out = MODEL_PTR(model, L->attn_out);
2034 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2035 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2036 float *residual1 = MODEL_PTR(model, L->residual1);
2037 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2038 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2039 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2040 float *output = MODEL_PTR(model, L->output);
// Quantized weight blobs: kept as opaque const void* because gemm_nt_q4_k
// decodes the Q4_K packing internally.
2042 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2043 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2044 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2045 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2046 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2047 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
// No bias tensors in this codegen: all bias pointers are NULL, and the
// per-head bias slices below collapse to NULL as well.
2048 const float *BQ = NULL;
2049 const float *BK = NULL;
2050 const float *BV = NULL;
2051 const float *BO = NULL;
2052 const float *B1 = NULL;
2053 const float *B2 = NULL;
// Shared (global, not per-layer) RoPE cos/sin lookup tables.
2055 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2056 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
// Strides in elements: q is laid out per-head with a num_tokens-row stride;
// k/v reserve the full context window per head (presumably so they double as
// the KV cache — TODO confirm against the full file).
2061 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2062 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2063 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
// Q projection: one quantized GEMM per head, slicing both the packed weight
// blob (wq_head_bytes per head) and the destination buffer.
2077 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2078 for (
int h = 0; h < H; ++h) {
2079 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2080 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2081 float *q_h = q + (size_t)h * q_head_stride;
2082 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// K projection: only H_kv heads (fewer than H would indicate grouped-query
// attention — cannot confirm from this fragment).
2087 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2088 for (
int h = 0; h < H_kv; ++h) {
2089 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2090 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2091 float *k_h = k + (size_t)h * kv_head_stride;
2092 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// V projection, mirroring the K loop.
2097 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2098 for (
int h = 0; h < H_kv; ++h) {
2099 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2100 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2101 float *v_h = v + (size_t)h * kv_head_stride;
2102 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// NOTE(review): the two orphaned argument tails below are the ends of calls
// whose leading lines are missing from this chunk (plausibly RoPE application
// and the attention kernel, given rope_cos/rope_sin above — TODO confirm).
2117 aligned_context_window);
2129 aligned_context_window);
// Output projection: attn_out is head-major; when the concatenated head width
// K differs from the embed dim, heads are repacked token-major into
// proj_scratch before the WO GEMM.
2132 const int K = H * aligned_head_dim;
2133 if (K != aligned_embed_dim) {
2136 const float *proj_in = attn_out;
2138 if (!proj_scratch) {
2141 for (
int t = 0; t < num_tokens; ++t) {
2142 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2143 for (
int h = 0; h < H; ++h) {
2144 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2145 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2147 (
size_t)aligned_head_dim *
sizeof(
float));
2150 proj_in = proj_scratch;
2152 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
// MLP: FC1 produces 2*intermediate columns in one GEMM, swiglu_forward
// combines the two halves, FC2 projects back to the embed dim.
2168 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2169 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2170 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
// ---------------------------------------------------------------------------
// Transformer block forward pass for MODEL_LAYERS[12] (fragment).
// NOTE(review): this region is extraction-damaged — the function's return-type
// and name line plus several interior statements are missing (the embedded
// original line numbers jump). Comments below annotate only the visible code;
// do not treat this fragment as the complete function body.
// ---------------------------------------------------------------------------
2182 int aligned_embed_dim,
2183 int aligned_head_dim,
2184 int aligned_intermediate_dim,
2185 int aligned_context_window
// Per-layer offset table; every buffer below lives in one model arena and is
// materialized as a pointer via MODEL_PTR(model, offset).
2187 const MODELLayerOffsets *L = &MODEL_LAYERS[12];
// This layer consumes the previous layer's output buffer as its input.
2189 float *input = MODEL_PTR(model, MODEL_LAYERS[11].output);
2190 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2191 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2192 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2193 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2194 float *q = MODEL_PTR(model, L->q);
2195 float *k = MODEL_PTR(model, L->k);
2196 float *v = MODEL_PTR(model, L->v);
2197 float *attn_out = MODEL_PTR(model, L->attn_out);
2198 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2199 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2200 float *residual1 = MODEL_PTR(model, L->residual1);
2201 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2202 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2203 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2204 float *output = MODEL_PTR(model, L->output);
// Quantized weight blobs: kept as opaque const void* because gemm_nt_q4_k
// decodes the Q4_K packing internally.
2206 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2207 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2208 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2209 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2210 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2211 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
// No bias tensors in this codegen: all bias pointers are NULL, and the
// per-head bias slices below collapse to NULL as well.
2212 const float *BQ = NULL;
2213 const float *BK = NULL;
2214 const float *BV = NULL;
2215 const float *BO = NULL;
2216 const float *B1 = NULL;
2217 const float *B2 = NULL;
// Shared (global, not per-layer) RoPE cos/sin lookup tables.
2219 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2220 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
// Strides in elements: q is laid out per-head with a num_tokens-row stride;
// k/v reserve the full context window per head (presumably so they double as
// the KV cache — TODO confirm against the full file).
2225 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2226 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2227 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
// Q projection: one quantized GEMM per head, slicing both the packed weight
// blob (wq_head_bytes per head) and the destination buffer.
2241 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2242 for (
int h = 0; h < H; ++h) {
2243 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2244 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2245 float *q_h = q + (size_t)h * q_head_stride;
2246 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// K projection: only H_kv heads (fewer than H would indicate grouped-query
// attention — cannot confirm from this fragment).
2251 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2252 for (
int h = 0; h < H_kv; ++h) {
2253 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2254 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2255 float *k_h = k + (size_t)h * kv_head_stride;
2256 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// V projection, mirroring the K loop.
2261 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2262 for (
int h = 0; h < H_kv; ++h) {
2263 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2264 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2265 float *v_h = v + (size_t)h * kv_head_stride;
2266 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// NOTE(review): the two orphaned argument tails below are the ends of calls
// whose leading lines are missing from this chunk (plausibly RoPE application
// and the attention kernel, given rope_cos/rope_sin above — TODO confirm).
2281 aligned_context_window);
2293 aligned_context_window);
// Output projection: attn_out is head-major; when the concatenated head width
// K differs from the embed dim, heads are repacked token-major into
// proj_scratch before the WO GEMM.
2296 const int K = H * aligned_head_dim;
2297 if (K != aligned_embed_dim) {
2300 const float *proj_in = attn_out;
2302 if (!proj_scratch) {
2305 for (
int t = 0; t < num_tokens; ++t) {
2306 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2307 for (
int h = 0; h < H; ++h) {
2308 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2309 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2311 (
size_t)aligned_head_dim *
sizeof(
float));
2314 proj_in = proj_scratch;
2316 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
// MLP: FC1 produces 2*intermediate columns in one GEMM, swiglu_forward
// combines the two halves, FC2 projects back to the embed dim.
2332 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2333 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2334 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
// ---------------------------------------------------------------------------
// Transformer block forward pass for MODEL_LAYERS[13] (fragment).
// NOTE(review): this region is extraction-damaged — the function's return-type
// and name line plus several interior statements are missing (the embedded
// original line numbers jump). Comments below annotate only the visible code;
// do not treat this fragment as the complete function body.
// ---------------------------------------------------------------------------
2346 int aligned_embed_dim,
2347 int aligned_head_dim,
2348 int aligned_intermediate_dim,
2349 int aligned_context_window
// Per-layer offset table; every buffer below lives in one model arena and is
// materialized as a pointer via MODEL_PTR(model, offset).
2351 const MODELLayerOffsets *L = &MODEL_LAYERS[13];
// This layer consumes the previous layer's output buffer as its input.
2353 float *input = MODEL_PTR(model, MODEL_LAYERS[12].output);
2354 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2355 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2356 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2357 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2358 float *q = MODEL_PTR(model, L->q);
2359 float *k = MODEL_PTR(model, L->k);
2360 float *v = MODEL_PTR(model, L->v);
2361 float *attn_out = MODEL_PTR(model, L->attn_out);
2362 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2363 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2364 float *residual1 = MODEL_PTR(model, L->residual1);
2365 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2366 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2367 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2368 float *output = MODEL_PTR(model, L->output);
// Quantized weight blobs: kept as opaque const void* because gemm_nt_q4_k
// decodes the Q4_K packing internally.
2370 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2371 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2372 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2373 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2374 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2375 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
// No bias tensors in this codegen: all bias pointers are NULL, and the
// per-head bias slices below collapse to NULL as well.
2376 const float *BQ = NULL;
2377 const float *BK = NULL;
2378 const float *BV = NULL;
2379 const float *BO = NULL;
2380 const float *B1 = NULL;
2381 const float *B2 = NULL;
// Shared (global, not per-layer) RoPE cos/sin lookup tables.
2383 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2384 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
// Strides in elements: q is laid out per-head with a num_tokens-row stride;
// k/v reserve the full context window per head (presumably so they double as
// the KV cache — TODO confirm against the full file).
2389 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2390 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2391 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
// Q projection: one quantized GEMM per head, slicing both the packed weight
// blob (wq_head_bytes per head) and the destination buffer.
2405 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2406 for (
int h = 0; h < H; ++h) {
2407 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2408 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2409 float *q_h = q + (size_t)h * q_head_stride;
2410 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// K projection: only H_kv heads (fewer than H would indicate grouped-query
// attention — cannot confirm from this fragment).
2415 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2416 for (
int h = 0; h < H_kv; ++h) {
2417 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2418 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2419 float *k_h = k + (size_t)h * kv_head_stride;
2420 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// V projection, mirroring the K loop.
2425 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2426 for (
int h = 0; h < H_kv; ++h) {
2427 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2428 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2429 float *v_h = v + (size_t)h * kv_head_stride;
2430 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// NOTE(review): the two orphaned argument tails below are the ends of calls
// whose leading lines are missing from this chunk (plausibly RoPE application
// and the attention kernel, given rope_cos/rope_sin above — TODO confirm).
2445 aligned_context_window);
2457 aligned_context_window);
// Output projection: attn_out is head-major; when the concatenated head width
// K differs from the embed dim, heads are repacked token-major into
// proj_scratch before the WO GEMM.
2460 const int K = H * aligned_head_dim;
2461 if (K != aligned_embed_dim) {
2464 const float *proj_in = attn_out;
2466 if (!proj_scratch) {
2469 for (
int t = 0; t < num_tokens; ++t) {
2470 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2471 for (
int h = 0; h < H; ++h) {
2472 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2473 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2475 (
size_t)aligned_head_dim *
sizeof(
float));
2478 proj_in = proj_scratch;
2480 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
// MLP: FC1 produces 2*intermediate columns in one GEMM, swiglu_forward
// combines the two halves, FC2 projects back to the embed dim.
2496 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2497 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2498 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
// ---------------------------------------------------------------------------
// Transformer block forward pass for MODEL_LAYERS[14] (fragment).
// NOTE(review): this region is extraction-damaged — the function's return-type
// and name line plus several interior statements are missing (the embedded
// original line numbers jump). Comments below annotate only the visible code;
// do not treat this fragment as the complete function body.
// ---------------------------------------------------------------------------
2510 int aligned_embed_dim,
2511 int aligned_head_dim,
2512 int aligned_intermediate_dim,
2513 int aligned_context_window
// Per-layer offset table; every buffer below lives in one model arena and is
// materialized as a pointer via MODEL_PTR(model, offset).
2515 const MODELLayerOffsets *L = &MODEL_LAYERS[14];
// This layer consumes the previous layer's output buffer as its input.
2517 float *input = MODEL_PTR(model, MODEL_LAYERS[13].output);
2518 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2519 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2520 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2521 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2522 float *q = MODEL_PTR(model, L->q);
2523 float *k = MODEL_PTR(model, L->k);
2524 float *v = MODEL_PTR(model, L->v);
2525 float *attn_out = MODEL_PTR(model, L->attn_out);
2526 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2527 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2528 float *residual1 = MODEL_PTR(model, L->residual1);
2529 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2530 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2531 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2532 float *output = MODEL_PTR(model, L->output);
// Quantized weight blobs: kept as opaque const void* because gemm_nt_q4_k
// decodes the Q4_K packing internally.
2534 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2535 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2536 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2537 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2538 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2539 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
// No bias tensors in this codegen: all bias pointers are NULL, and the
// per-head bias slices below collapse to NULL as well.
2540 const float *BQ = NULL;
2541 const float *BK = NULL;
2542 const float *BV = NULL;
2543 const float *BO = NULL;
2544 const float *B1 = NULL;
2545 const float *B2 = NULL;
// Shared (global, not per-layer) RoPE cos/sin lookup tables.
2547 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2548 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
// Strides in elements: q is laid out per-head with a num_tokens-row stride;
// k/v reserve the full context window per head (presumably so they double as
// the KV cache — TODO confirm against the full file).
2553 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2554 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2555 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
// Q projection: one quantized GEMM per head, slicing both the packed weight
// blob (wq_head_bytes per head) and the destination buffer.
2569 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2570 for (
int h = 0; h < H; ++h) {
2571 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2572 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2573 float *q_h = q + (size_t)h * q_head_stride;
2574 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// K projection: only H_kv heads (fewer than H would indicate grouped-query
// attention — cannot confirm from this fragment).
2579 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2580 for (
int h = 0; h < H_kv; ++h) {
2581 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2582 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2583 float *k_h = k + (size_t)h * kv_head_stride;
2584 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// V projection, mirroring the K loop.
2589 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2590 for (
int h = 0; h < H_kv; ++h) {
2591 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2592 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2593 float *v_h = v + (size_t)h * kv_head_stride;
2594 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
// NOTE(review): the two orphaned argument tails below are the ends of calls
// whose leading lines are missing from this chunk (plausibly RoPE application
// and the attention kernel, given rope_cos/rope_sin above — TODO confirm).
2609 aligned_context_window);
2621 aligned_context_window);
// Output projection: attn_out is head-major; when the concatenated head width
// K differs from the embed dim, heads are repacked token-major into
// proj_scratch before the WO GEMM.
2624 const int K = H * aligned_head_dim;
2625 if (K != aligned_embed_dim) {
2628 const float *proj_in = attn_out;
2630 if (!proj_scratch) {
2633 for (
int t = 0; t < num_tokens; ++t) {
2634 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2635 for (
int h = 0; h < H; ++h) {
2636 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2637 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2639 (
size_t)aligned_head_dim *
sizeof(
float));
2642 proj_in = proj_scratch;
2644 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
// MLP: FC1 produces 2*intermediate columns in one GEMM, swiglu_forward
// combines the two halves, FC2 projects back to the embed dim.
2660 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2661 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2662 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 15 forward pass (auto-generated per-layer code).
 * NOTE(review): this extraction is a fragment — the embedded original line
 * numbers jump (e.g. 2677 -> 2679, 2719 -> 2733, 2758 -> 2773), so the
 * function's name/signature opening, the LayerNorm/RoPE/attention calls and
 * the residual additions fall in missing ranges. Do not compile as-is;
 * recover the full function from the code generator before editing logic.
 */
2674 int aligned_embed_dim,
2675 int aligned_head_dim,
2676 int aligned_intermediate_dim,
2677 int aligned_context_window
/* Resolve this layer's arena offsets; the input activation is layer 14's output. */
2679 const MODELLayerOffsets *L = &MODEL_LAYERS[15];
2681 float *input = MODEL_PTR(model, MODEL_LAYERS[14].output);
2682 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2683 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2684 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2685 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2686 float *q = MODEL_PTR(model, L->q);
2687 float *k = MODEL_PTR(model, L->k);
2688 float *v = MODEL_PTR(model, L->v);
2689 float *attn_out = MODEL_PTR(model, L->attn_out);
2690 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2691 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2692 float *residual1 = MODEL_PTR(model, L->residual1);
2693 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2694 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2695 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2696 float *output = MODEL_PTR(model, L->output);
/* Weight blobs are opaque bytes — presumably Q4_K-packed, since they are
 * consumed by gemm_nt_q4_k(); confirm against the exporter. */
2698 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2699 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2700 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2701 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2702 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2703 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* This model variant has no projection biases: all bias pointers stay NULL. */
2704 const float *BQ = NULL;
2705 const float *BK = NULL;
2706 const float *BV = NULL;
2707 const float *BO = NULL;
2708 const float *B1 = NULL;
2709 const float *B2 = NULL;
/* Shared rotary-embedding cos/sin caches; application site is in missing lines — verify. */
2711 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2712 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides in float elements: Q is [head][token][head_dim]; K/V reserve the
 * full context window per head. NOTE(review): head_w_elems is unused in the
 * visible fragment — presumably consumed by missing lines. */
2717 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2718 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2719 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x aligned_head_dim] GEMM per query head.
 * NOTE(review): wq_head_bytes is defined in a missing line — verify. */
2733 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2734 for (
int h = 0; h < H; ++h) {
2735 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2736 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2737 float *q_h = q + (size_t)h * q_head_stride;
2738 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (H_kv != H would indicate grouped-query
 * attention — confirm against the model config). */
2743 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2744 for (
int h = 0; h < H_kv; ++h) {
2745 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2746 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2747 float *k_h = k + (size_t)h * kv_head_stride;
2748 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, same per-head layout as K. */
2753 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2754 for (
int h = 0; h < H_kv; ++h) {
2755 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2756 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2757 float *v_h = v + (size_t)h * kv_head_stride;
2758 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* NOTE(review): two attention-related calls are truncated here; only their
 * trailing argument survived the extraction. */
2773 aligned_context_window);
2785 aligned_context_window);
/* If the concatenated heads (H * aligned_head_dim) differ from the embed dim,
 * repack attn_out from head-major into token-major rows in proj_scratch
 * before the output projection. NOTE(review): the memcpy's source-pointer
 * argument line is missing from this extraction. */
2788 const int K = H * aligned_head_dim;
2789 if (K != aligned_embed_dim) {
2792 const float *proj_in = attn_out;
2794 if (!proj_scratch) {
2797 for (
int t = 0; t < num_tokens; ++t) {
2798 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2799 for (
int h = 0; h < H; ++h) {
2800 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2801 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2803 (
size_t)aligned_head_dim *
sizeof(
float));
2806 proj_in = proj_scratch;
/* Attention output projection back to aligned_embed_dim. */
2808 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* FFN: fused gate+up GEMM (2 * aligned_intermediate_dim columns), SwiGLU
 * gating, then down projection back to aligned_embed_dim. */
2824 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2825 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2826 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 16 forward pass (auto-generated per-layer code).
 * NOTE(review): this extraction is a fragment — the embedded original line
 * numbers jump (e.g. 2841 -> 2843, 2883 -> 2897, 2922 -> 2937), so the
 * function's name/signature opening, the LayerNorm/RoPE/attention calls and
 * the residual additions fall in missing ranges. Do not compile as-is;
 * recover the full function from the code generator before editing logic.
 */
2838 int aligned_embed_dim,
2839 int aligned_head_dim,
2840 int aligned_intermediate_dim,
2841 int aligned_context_window
/* Resolve this layer's arena offsets; the input activation is layer 15's output. */
2843 const MODELLayerOffsets *L = &MODEL_LAYERS[16];
2845 float *input = MODEL_PTR(model, MODEL_LAYERS[15].output);
2846 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
2847 float *ln1_out = MODEL_PTR(model, L->ln1_out);
2848 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
2849 float *ln2_out = MODEL_PTR(model, L->ln2_out);
2850 float *q = MODEL_PTR(model, L->q);
2851 float *k = MODEL_PTR(model, L->k);
2852 float *v = MODEL_PTR(model, L->v);
2853 float *attn_out = MODEL_PTR(model, L->attn_out);
2854 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
2855 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
2856 float *residual1 = MODEL_PTR(model, L->residual1);
2857 float *fc1_out = MODEL_PTR(model, L->fc1_out);
2858 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
2859 float *mlp_out = MODEL_PTR(model, L->mlp_out);
2860 float *output = MODEL_PTR(model, L->output);
/* Weight blobs are opaque bytes — presumably Q4_K-packed, since they are
 * consumed by gemm_nt_q4_k(); confirm against the exporter. */
2862 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
2863 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
2864 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
2865 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
2866 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
2867 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* This model variant has no projection biases: all bias pointers stay NULL. */
2868 const float *BQ = NULL;
2869 const float *BK = NULL;
2870 const float *BV = NULL;
2871 const float *BO = NULL;
2872 const float *B1 = NULL;
2873 const float *B2 = NULL;
/* Shared rotary-embedding cos/sin caches; application site is in missing lines — verify. */
2875 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
2876 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides in float elements: Q is [head][token][head_dim]; K/V reserve the
 * full context window per head. NOTE(review): head_w_elems is unused in the
 * visible fragment — presumably consumed by missing lines. */
2881 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
2882 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
2883 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x aligned_head_dim] GEMM per query head.
 * NOTE(review): wq_head_bytes is defined in a missing line — verify. */
2897 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
2898 for (
int h = 0; h < H; ++h) {
2899 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
2900 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2901 float *q_h = q + (size_t)h * q_head_stride;
2902 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (H_kv != H would indicate grouped-query
 * attention — confirm against the model config). */
2907 const uint8_t *WK_bytes = (
const uint8_t *)WK;
2908 for (
int h = 0; h < H_kv; ++h) {
2909 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
2910 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2911 float *k_h = k + (size_t)h * kv_head_stride;
2912 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, same per-head layout as K. */
2917 const uint8_t *WV_bytes = (
const uint8_t *)WV;
2918 for (
int h = 0; h < H_kv; ++h) {
2919 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
2920 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
2921 float *v_h = v + (size_t)h * kv_head_stride;
2922 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* NOTE(review): two attention-related calls are truncated here; only their
 * trailing argument survived the extraction. */
2937 aligned_context_window);
2949 aligned_context_window);
/* If the concatenated heads (H * aligned_head_dim) differ from the embed dim,
 * repack attn_out from head-major into token-major rows in proj_scratch
 * before the output projection. NOTE(review): the memcpy's source-pointer
 * argument line is missing from this extraction. */
2952 const int K = H * aligned_head_dim;
2953 if (K != aligned_embed_dim) {
2956 const float *proj_in = attn_out;
2958 if (!proj_scratch) {
2961 for (
int t = 0; t < num_tokens; ++t) {
2962 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
2963 for (
int h = 0; h < H; ++h) {
2964 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
2965 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
2967 (
size_t)aligned_head_dim *
sizeof(
float));
2970 proj_in = proj_scratch;
/* Attention output projection back to aligned_embed_dim. */
2972 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* FFN: fused gate+up GEMM (2 * aligned_intermediate_dim columns), SwiGLU
 * gating, then down projection back to aligned_embed_dim. */
2988 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
2989 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
2990 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 17 forward pass (auto-generated per-layer code).
 * NOTE(review): this extraction is a fragment — the embedded original line
 * numbers jump (e.g. 3005 -> 3007, 3047 -> 3061, 3086 -> 3101), so the
 * function's name/signature opening, the LayerNorm/RoPE/attention calls and
 * the residual additions fall in missing ranges. Do not compile as-is;
 * recover the full function from the code generator before editing logic.
 */
3002 int aligned_embed_dim,
3003 int aligned_head_dim,
3004 int aligned_intermediate_dim,
3005 int aligned_context_window
/* Resolve this layer's arena offsets; the input activation is layer 16's output. */
3007 const MODELLayerOffsets *L = &MODEL_LAYERS[17];
3009 float *input = MODEL_PTR(model, MODEL_LAYERS[16].output);
3010 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3011 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3012 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3013 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3014 float *q = MODEL_PTR(model, L->q);
3015 float *k = MODEL_PTR(model, L->k);
3016 float *v = MODEL_PTR(model, L->v);
3017 float *attn_out = MODEL_PTR(model, L->attn_out);
3018 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3019 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3020 float *residual1 = MODEL_PTR(model, L->residual1);
3021 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3022 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3023 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3024 float *output = MODEL_PTR(model, L->output);
/* Weight blobs are opaque bytes — presumably Q4_K-packed, since they are
 * consumed by gemm_nt_q4_k(); confirm against the exporter. */
3026 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3027 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3028 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3029 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3030 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3031 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* This model variant has no projection biases: all bias pointers stay NULL. */
3032 const float *BQ = NULL;
3033 const float *BK = NULL;
3034 const float *BV = NULL;
3035 const float *BO = NULL;
3036 const float *B1 = NULL;
3037 const float *B2 = NULL;
/* Shared rotary-embedding cos/sin caches; application site is in missing lines — verify. */
3039 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3040 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides in float elements: Q is [head][token][head_dim]; K/V reserve the
 * full context window per head. NOTE(review): head_w_elems is unused in the
 * visible fragment — presumably consumed by missing lines. */
3045 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3046 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3047 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x aligned_head_dim] GEMM per query head.
 * NOTE(review): wq_head_bytes is defined in a missing line — verify. */
3061 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3062 for (
int h = 0; h < H; ++h) {
3063 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3064 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3065 float *q_h = q + (size_t)h * q_head_stride;
3066 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (H_kv != H would indicate grouped-query
 * attention — confirm against the model config). */
3071 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3072 for (
int h = 0; h < H_kv; ++h) {
3073 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3074 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3075 float *k_h = k + (size_t)h * kv_head_stride;
3076 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, same per-head layout as K. */
3081 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3082 for (
int h = 0; h < H_kv; ++h) {
3083 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3084 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3085 float *v_h = v + (size_t)h * kv_head_stride;
3086 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* NOTE(review): two attention-related calls are truncated here; only their
 * trailing argument survived the extraction. */
3101 aligned_context_window);
3113 aligned_context_window);
/* If the concatenated heads (H * aligned_head_dim) differ from the embed dim,
 * repack attn_out from head-major into token-major rows in proj_scratch
 * before the output projection. NOTE(review): the memcpy's source-pointer
 * argument line is missing from this extraction. */
3116 const int K = H * aligned_head_dim;
3117 if (K != aligned_embed_dim) {
3120 const float *proj_in = attn_out;
3122 if (!proj_scratch) {
3125 for (
int t = 0; t < num_tokens; ++t) {
3126 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3127 for (
int h = 0; h < H; ++h) {
3128 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3129 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3131 (
size_t)aligned_head_dim *
sizeof(
float));
3134 proj_in = proj_scratch;
/* Attention output projection back to aligned_embed_dim. */
3136 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* FFN: fused gate+up GEMM (2 * aligned_intermediate_dim columns), SwiGLU
 * gating, then down projection back to aligned_embed_dim. */
3152 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3153 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3154 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 18 forward pass (auto-generated per-layer code).
 * NOTE(review): this extraction is a fragment — the embedded original line
 * numbers jump (e.g. 3169 -> 3171, 3211 -> 3225, 3250 -> 3265), so the
 * function's name/signature opening, the LayerNorm/RoPE/attention calls and
 * the residual additions fall in missing ranges. Do not compile as-is;
 * recover the full function from the code generator before editing logic.
 */
3166 int aligned_embed_dim,
3167 int aligned_head_dim,
3168 int aligned_intermediate_dim,
3169 int aligned_context_window
/* Resolve this layer's arena offsets; the input activation is layer 17's output. */
3171 const MODELLayerOffsets *L = &MODEL_LAYERS[18];
3173 float *input = MODEL_PTR(model, MODEL_LAYERS[17].output);
3174 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3175 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3176 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3177 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3178 float *q = MODEL_PTR(model, L->q);
3179 float *k = MODEL_PTR(model, L->k);
3180 float *v = MODEL_PTR(model, L->v);
3181 float *attn_out = MODEL_PTR(model, L->attn_out);
3182 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3183 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3184 float *residual1 = MODEL_PTR(model, L->residual1);
3185 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3186 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3187 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3188 float *output = MODEL_PTR(model, L->output);
/* Weight blobs are opaque bytes — presumably Q4_K-packed, since they are
 * consumed by gemm_nt_q4_k(); confirm against the exporter. */
3190 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3191 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3192 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3193 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3194 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3195 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* This model variant has no projection biases: all bias pointers stay NULL. */
3196 const float *BQ = NULL;
3197 const float *BK = NULL;
3198 const float *BV = NULL;
3199 const float *BO = NULL;
3200 const float *B1 = NULL;
3201 const float *B2 = NULL;
/* Shared rotary-embedding cos/sin caches; application site is in missing lines — verify. */
3203 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3204 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides in float elements: Q is [head][token][head_dim]; K/V reserve the
 * full context window per head. NOTE(review): head_w_elems is unused in the
 * visible fragment — presumably consumed by missing lines. */
3209 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3210 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3211 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x aligned_head_dim] GEMM per query head.
 * NOTE(review): wq_head_bytes is defined in a missing line — verify. */
3225 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3226 for (
int h = 0; h < H; ++h) {
3227 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3228 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3229 float *q_h = q + (size_t)h * q_head_stride;
3230 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (H_kv != H would indicate grouped-query
 * attention — confirm against the model config). */
3235 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3236 for (
int h = 0; h < H_kv; ++h) {
3237 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3238 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3239 float *k_h = k + (size_t)h * kv_head_stride;
3240 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, same per-head layout as K. */
3245 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3246 for (
int h = 0; h < H_kv; ++h) {
3247 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3248 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3249 float *v_h = v + (size_t)h * kv_head_stride;
3250 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* NOTE(review): two attention-related calls are truncated here; only their
 * trailing argument survived the extraction. */
3265 aligned_context_window);
3277 aligned_context_window);
/* If the concatenated heads (H * aligned_head_dim) differ from the embed dim,
 * repack attn_out from head-major into token-major rows in proj_scratch
 * before the output projection. NOTE(review): the memcpy's source-pointer
 * argument line is missing from this extraction. */
3280 const int K = H * aligned_head_dim;
3281 if (K != aligned_embed_dim) {
3284 const float *proj_in = attn_out;
3286 if (!proj_scratch) {
3289 for (
int t = 0; t < num_tokens; ++t) {
3290 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3291 for (
int h = 0; h < H; ++h) {
3292 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3293 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3295 (
size_t)aligned_head_dim *
sizeof(
float));
3298 proj_in = proj_scratch;
/* Attention output projection back to aligned_embed_dim. */
3300 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* FFN: fused gate+up GEMM (2 * aligned_intermediate_dim columns), SwiGLU
 * gating, then down projection back to aligned_embed_dim. */
3316 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3317 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3318 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 19 forward pass (auto-generated per-layer code).
 * NOTE(review): this extraction is a fragment — the embedded original line
 * numbers jump (e.g. 3333 -> 3335, 3375 -> 3389, 3414 -> 3429), so the
 * function's name/signature opening, the LayerNorm/RoPE/attention calls and
 * the residual additions fall in missing ranges. Do not compile as-is;
 * recover the full function from the code generator before editing logic.
 */
3330 int aligned_embed_dim,
3331 int aligned_head_dim,
3332 int aligned_intermediate_dim,
3333 int aligned_context_window
/* Resolve this layer's arena offsets; the input activation is layer 18's output. */
3335 const MODELLayerOffsets *L = &MODEL_LAYERS[19];
3337 float *input = MODEL_PTR(model, MODEL_LAYERS[18].output);
3338 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3339 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3340 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3341 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3342 float *q = MODEL_PTR(model, L->q);
3343 float *k = MODEL_PTR(model, L->k);
3344 float *v = MODEL_PTR(model, L->v);
3345 float *attn_out = MODEL_PTR(model, L->attn_out);
3346 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3347 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3348 float *residual1 = MODEL_PTR(model, L->residual1);
3349 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3350 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3351 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3352 float *output = MODEL_PTR(model, L->output);
/* Weight blobs are opaque bytes — presumably Q4_K-packed, since they are
 * consumed by gemm_nt_q4_k(); confirm against the exporter. */
3354 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3355 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3356 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3357 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3358 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3359 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* This model variant has no projection biases: all bias pointers stay NULL. */
3360 const float *BQ = NULL;
3361 const float *BK = NULL;
3362 const float *BV = NULL;
3363 const float *BO = NULL;
3364 const float *B1 = NULL;
3365 const float *B2 = NULL;
/* Shared rotary-embedding cos/sin caches; application site is in missing lines — verify. */
3367 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3368 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides in float elements: Q is [head][token][head_dim]; K/V reserve the
 * full context window per head. NOTE(review): head_w_elems is unused in the
 * visible fragment — presumably consumed by missing lines. */
3373 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3374 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3375 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x aligned_head_dim] GEMM per query head.
 * NOTE(review): wq_head_bytes is defined in a missing line — verify. */
3389 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3390 for (
int h = 0; h < H; ++h) {
3391 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3392 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3393 float *q_h = q + (size_t)h * q_head_stride;
3394 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (H_kv != H would indicate grouped-query
 * attention — confirm against the model config). */
3399 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3400 for (
int h = 0; h < H_kv; ++h) {
3401 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3402 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3403 float *k_h = k + (size_t)h * kv_head_stride;
3404 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, same per-head layout as K. */
3409 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3410 for (
int h = 0; h < H_kv; ++h) {
3411 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3412 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3413 float *v_h = v + (size_t)h * kv_head_stride;
3414 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* NOTE(review): two attention-related calls are truncated here; only their
 * trailing argument survived the extraction. */
3429 aligned_context_window);
3441 aligned_context_window);
/* If the concatenated heads (H * aligned_head_dim) differ from the embed dim,
 * repack attn_out from head-major into token-major rows in proj_scratch
 * before the output projection. NOTE(review): the memcpy's source-pointer
 * argument line is missing from this extraction. */
3444 const int K = H * aligned_head_dim;
3445 if (K != aligned_embed_dim) {
3448 const float *proj_in = attn_out;
3450 if (!proj_scratch) {
3453 for (
int t = 0; t < num_tokens; ++t) {
3454 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3455 for (
int h = 0; h < H; ++h) {
3456 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3457 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3459 (
size_t)aligned_head_dim *
sizeof(
float));
3462 proj_in = proj_scratch;
/* Attention output projection back to aligned_embed_dim. */
3464 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* FFN: fused gate+up GEMM (2 * aligned_intermediate_dim columns), SwiGLU
 * gating, then down projection back to aligned_embed_dim. */
3480 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3481 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3482 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 20 forward pass (auto-generated per-layer code).
 * NOTE(review): this extraction is a fragment — the embedded original line
 * numbers jump (e.g. 3497 -> 3499, 3539 -> 3553, 3578 -> 3593), so the
 * function's name/signature opening, the LayerNorm/RoPE/attention calls and
 * the residual additions fall in missing ranges. Do not compile as-is;
 * recover the full function from the code generator before editing logic.
 */
3494 int aligned_embed_dim,
3495 int aligned_head_dim,
3496 int aligned_intermediate_dim,
3497 int aligned_context_window
/* Resolve this layer's arena offsets; the input activation is layer 19's output. */
3499 const MODELLayerOffsets *L = &MODEL_LAYERS[20];
3501 float *input = MODEL_PTR(model, MODEL_LAYERS[19].output);
3502 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3503 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3504 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3505 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3506 float *q = MODEL_PTR(model, L->q);
3507 float *k = MODEL_PTR(model, L->k);
3508 float *v = MODEL_PTR(model, L->v);
3509 float *attn_out = MODEL_PTR(model, L->attn_out);
3510 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3511 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3512 float *residual1 = MODEL_PTR(model, L->residual1);
3513 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3514 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3515 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3516 float *output = MODEL_PTR(model, L->output);
/* Weight blobs are opaque bytes — presumably Q4_K-packed, since they are
 * consumed by gemm_nt_q4_k(); confirm against the exporter. */
3518 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3519 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3520 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3521 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3522 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3523 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* This model variant has no projection biases: all bias pointers stay NULL. */
3524 const float *BQ = NULL;
3525 const float *BK = NULL;
3526 const float *BV = NULL;
3527 const float *BO = NULL;
3528 const float *B1 = NULL;
3529 const float *B2 = NULL;
/* Shared rotary-embedding cos/sin caches; application site is in missing lines — verify. */
3531 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3532 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides in float elements: Q is [head][token][head_dim]; K/V reserve the
 * full context window per head. NOTE(review): head_w_elems is unused in the
 * visible fragment — presumably consumed by missing lines. */
3537 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3538 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3539 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x aligned_head_dim] GEMM per query head.
 * NOTE(review): wq_head_bytes is defined in a missing line — verify. */
3553 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3554 for (
int h = 0; h < H; ++h) {
3555 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3556 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3557 float *q_h = q + (size_t)h * q_head_stride;
3558 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (H_kv != H would indicate grouped-query
 * attention — confirm against the model config). */
3563 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3564 for (
int h = 0; h < H_kv; ++h) {
3565 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3566 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3567 float *k_h = k + (size_t)h * kv_head_stride;
3568 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, same per-head layout as K. */
3573 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3574 for (
int h = 0; h < H_kv; ++h) {
3575 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3576 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3577 float *v_h = v + (size_t)h * kv_head_stride;
3578 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* NOTE(review): two attention-related calls are truncated here; only their
 * trailing argument survived the extraction. */
3593 aligned_context_window);
3605 aligned_context_window);
/* If the concatenated heads (H * aligned_head_dim) differ from the embed dim,
 * repack attn_out from head-major into token-major rows in proj_scratch
 * before the output projection. NOTE(review): the memcpy's source-pointer
 * argument line is missing from this extraction. */
3608 const int K = H * aligned_head_dim;
3609 if (K != aligned_embed_dim) {
3612 const float *proj_in = attn_out;
3614 if (!proj_scratch) {
3617 for (
int t = 0; t < num_tokens; ++t) {
3618 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3619 for (
int h = 0; h < H; ++h) {
3620 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3621 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3623 (
size_t)aligned_head_dim *
sizeof(
float));
3626 proj_in = proj_scratch;
/* Attention output projection back to aligned_embed_dim. */
3628 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* FFN: fused gate+up GEMM (2 * aligned_intermediate_dim columns), SwiGLU
 * gating, then down projection back to aligned_embed_dim. */
3644 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3645 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3646 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 21 forward pass (auto-generated per-layer code).
 * NOTE(review): this extraction is a fragment — the embedded original line
 * numbers jump (e.g. 3661 -> 3663, 3703 -> 3717, 3742 -> 3757), so the
 * function's name/signature opening, the LayerNorm/RoPE/attention calls and
 * the residual additions fall in missing ranges. Do not compile as-is;
 * recover the full function from the code generator before editing logic.
 */
3658 int aligned_embed_dim,
3659 int aligned_head_dim,
3660 int aligned_intermediate_dim,
3661 int aligned_context_window
/* Resolve this layer's arena offsets; the input activation is layer 20's output. */
3663 const MODELLayerOffsets *L = &MODEL_LAYERS[21];
3665 float *input = MODEL_PTR(model, MODEL_LAYERS[20].output);
3666 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3667 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3668 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3669 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3670 float *q = MODEL_PTR(model, L->q);
3671 float *k = MODEL_PTR(model, L->k);
3672 float *v = MODEL_PTR(model, L->v);
3673 float *attn_out = MODEL_PTR(model, L->attn_out);
3674 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3675 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3676 float *residual1 = MODEL_PTR(model, L->residual1);
3677 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3678 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3679 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3680 float *output = MODEL_PTR(model, L->output);
/* Weight blobs are opaque bytes — presumably Q4_K-packed, since they are
 * consumed by gemm_nt_q4_k(); confirm against the exporter. */
3682 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3683 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3684 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3685 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3686 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3687 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* This model variant has no projection biases: all bias pointers stay NULL. */
3688 const float *BQ = NULL;
3689 const float *BK = NULL;
3690 const float *BV = NULL;
3691 const float *BO = NULL;
3692 const float *B1 = NULL;
3693 const float *B2 = NULL;
/* Shared rotary-embedding cos/sin caches; application site is in missing lines — verify. */
3695 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3696 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides in float elements: Q is [head][token][head_dim]; K/V reserve the
 * full context window per head. NOTE(review): head_w_elems is unused in the
 * visible fragment — presumably consumed by missing lines. */
3701 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3702 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3703 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x aligned_head_dim] GEMM per query head.
 * NOTE(review): wq_head_bytes is defined in a missing line — verify. */
3717 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3718 for (
int h = 0; h < H; ++h) {
3719 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3720 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3721 float *q_h = q + (size_t)h * q_head_stride;
3722 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (H_kv != H would indicate grouped-query
 * attention — confirm against the model config). */
3727 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3728 for (
int h = 0; h < H_kv; ++h) {
3729 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3730 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3731 float *k_h = k + (size_t)h * kv_head_stride;
3732 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, same per-head layout as K. */
3737 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3738 for (
int h = 0; h < H_kv; ++h) {
3739 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3740 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3741 float *v_h = v + (size_t)h * kv_head_stride;
3742 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* NOTE(review): two attention-related calls are truncated here; only their
 * trailing argument survived the extraction. */
3757 aligned_context_window);
3769 aligned_context_window);
/* If the concatenated heads (H * aligned_head_dim) differ from the embed dim,
 * repack attn_out from head-major into token-major rows in proj_scratch
 * before the output projection. NOTE(review): the memcpy's source-pointer
 * argument line is missing from this extraction. */
3772 const int K = H * aligned_head_dim;
3773 if (K != aligned_embed_dim) {
3776 const float *proj_in = attn_out;
3778 if (!proj_scratch) {
3781 for (
int t = 0; t < num_tokens; ++t) {
3782 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3783 for (
int h = 0; h < H; ++h) {
3784 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3785 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3787 (
size_t)aligned_head_dim *
sizeof(
float));
3790 proj_in = proj_scratch;
/* Attention output projection back to aligned_embed_dim. */
3792 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* FFN: fused gate+up GEMM (2 * aligned_intermediate_dim columns), SwiGLU
 * gating, then down projection back to aligned_embed_dim. */
3808 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3809 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3810 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
/*
 * Transformer layer 22 forward pass (auto-generated per-layer code).
 * NOTE(review): this extraction is a fragment — the embedded original line
 * numbers jump (e.g. 3825 -> 3827, 3867 -> 3881, 3906 -> 3921), so the
 * function's name/signature opening, the LayerNorm/RoPE/attention calls and
 * the residual additions fall in missing ranges. Do not compile as-is;
 * recover the full function from the code generator before editing logic.
 */
3822 int aligned_embed_dim,
3823 int aligned_head_dim,
3824 int aligned_intermediate_dim,
3825 int aligned_context_window
/* Resolve this layer's arena offsets; the input activation is layer 21's output. */
3827 const MODELLayerOffsets *L = &MODEL_LAYERS[22];
3829 float *input = MODEL_PTR(model, MODEL_LAYERS[21].output);
3830 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3831 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3832 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3833 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3834 float *q = MODEL_PTR(model, L->q);
3835 float *k = MODEL_PTR(model, L->k);
3836 float *v = MODEL_PTR(model, L->v);
3837 float *attn_out = MODEL_PTR(model, L->attn_out);
3838 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
3839 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
3840 float *residual1 = MODEL_PTR(model, L->residual1);
3841 float *fc1_out = MODEL_PTR(model, L->fc1_out);
3842 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
3843 float *mlp_out = MODEL_PTR(model, L->mlp_out);
3844 float *output = MODEL_PTR(model, L->output);
/* Weight blobs are opaque bytes — presumably Q4_K-packed, since they are
 * consumed by gemm_nt_q4_k(); confirm against the exporter. */
3846 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
3847 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
3848 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
3849 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
3850 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
3851 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
/* This model variant has no projection biases: all bias pointers stay NULL. */
3852 const float *BQ = NULL;
3853 const float *BK = NULL;
3854 const float *BV = NULL;
3855 const float *BO = NULL;
3856 const float *B1 = NULL;
3857 const float *B2 = NULL;
/* Shared rotary-embedding cos/sin caches; application site is in missing lines — verify. */
3859 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
3860 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
/* Strides in float elements: Q is [head][token][head_dim]; K/V reserve the
 * full context window per head. NOTE(review): head_w_elems is unused in the
 * visible fragment — presumably consumed by missing lines. */
3865 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
3866 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
3867 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
/* Q projection: one [num_tokens x aligned_head_dim] GEMM per query head.
 * NOTE(review): wq_head_bytes is defined in a missing line — verify. */
3881 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
3882 for (
int h = 0; h < H; ++h) {
3883 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
3884 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3885 float *q_h = q + (size_t)h * q_head_stride;
3886 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* K projection over H_kv heads (H_kv != H would indicate grouped-query
 * attention — confirm against the model config). */
3891 const uint8_t *WK_bytes = (
const uint8_t *)WK;
3892 for (
int h = 0; h < H_kv; ++h) {
3893 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
3894 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3895 float *k_h = k + (size_t)h * kv_head_stride;
3896 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* V projection, same per-head layout as K. */
3901 const uint8_t *WV_bytes = (
const uint8_t *)WV;
3902 for (
int h = 0; h < H_kv; ++h) {
3903 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
3904 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
3905 float *v_h = v + (size_t)h * kv_head_stride;
3906 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
/* NOTE(review): two attention-related calls are truncated here; only their
 * trailing argument survived the extraction. */
3921 aligned_context_window);
3933 aligned_context_window);
/* If the concatenated heads (H * aligned_head_dim) differ from the embed dim,
 * repack attn_out from head-major into token-major rows in proj_scratch
 * before the output projection. NOTE(review): the memcpy's source-pointer
 * argument line is missing from this extraction. */
3936 const int K = H * aligned_head_dim;
3937 if (K != aligned_embed_dim) {
3940 const float *proj_in = attn_out;
3942 if (!proj_scratch) {
3945 for (
int t = 0; t < num_tokens; ++t) {
3946 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
3947 for (
int h = 0; h < H; ++h) {
3948 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
3949 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
3951 (
size_t)aligned_head_dim *
sizeof(
float));
3954 proj_in = proj_scratch;
/* Attention output projection back to aligned_embed_dim. */
3956 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
/* FFN: fused gate+up GEMM (2 * aligned_intermediate_dim columns), SwiGLU
 * gating, then down projection back to aligned_embed_dim. */
3972 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
3973 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
3974 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
3986 int aligned_embed_dim,
3987 int aligned_head_dim,
3988 int aligned_intermediate_dim,
3989 int aligned_context_window
3991 const MODELLayerOffsets *L = &MODEL_LAYERS[23];
3993 float *input = MODEL_PTR(model, MODEL_LAYERS[22].output);
3994 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
3995 float *ln1_out = MODEL_PTR(model, L->ln1_out);
3996 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
3997 float *ln2_out = MODEL_PTR(model, L->ln2_out);
3998 float *q = MODEL_PTR(model, L->q);
3999 float *k = MODEL_PTR(model, L->k);
4000 float *v = MODEL_PTR(model, L->v);
4001 float *attn_out = MODEL_PTR(model, L->attn_out);
4002 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
4003 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
4004 float *residual1 = MODEL_PTR(model, L->residual1);
4005 float *fc1_out = MODEL_PTR(model, L->fc1_out);
4006 float *swiglu_out = MODEL_PTR(model, L->swiglu_out);
4007 float *mlp_out = MODEL_PTR(model, L->mlp_out);
4008 float *output = MODEL_PTR(model, L->output);
4010 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
4011 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
4012 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
4013 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
4014 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
4015 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
4016 const float *BQ = NULL;
4017 const float *BK = NULL;
4018 const float *BV = NULL;
4019 const float *BO = NULL;
4020 const float *B1 = NULL;
4021 const float *B2 = NULL;
4023 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
4024 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
4029 const size_t head_w_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4030 const size_t q_head_stride = (size_t)num_tokens * (
size_t)aligned_head_dim;
4031 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4045 const uint8_t *WQ_bytes = (
const uint8_t *)WQ;
4046 for (
int h = 0; h < H; ++h) {
4047 const void *wq_h = (
const void *)(WQ_bytes + (
size_t)h * wq_head_bytes);
4048 const float *bq_h = BQ ? (BQ + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4049 float *q_h = q + (size_t)h * q_head_stride;
4050 gemm_nt_q4_k(ln1_out, wq_h, bq_h, q_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4055 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4056 for (
int h = 0; h < H_kv; ++h) {
4057 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4058 const float *bk_h = BK ? (BK + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4059 float *k_h = k + (size_t)h * kv_head_stride;
4060 gemm_nt_q4_k(ln1_out, wk_h, bk_h, k_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4065 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4066 for (
int h = 0; h < H_kv; ++h) {
4067 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4068 const float *bv_h = BV ? (BV + (size_t)h * (
size_t)aligned_head_dim) : NULL;
4069 float *v_h = v + (size_t)h * kv_head_stride;
4070 gemm_nt_q4_k(ln1_out, wv_h, bv_h, v_h, num_tokens, aligned_head_dim, aligned_embed_dim);
4085 aligned_context_window);
4097 aligned_context_window);
4100 const int K = H * aligned_head_dim;
4101 if (K != aligned_embed_dim) {
4104 const float *proj_in = attn_out;
4106 if (!proj_scratch) {
4109 for (
int t = 0; t < num_tokens; ++t) {
4110 float *dst = proj_scratch + (size_t)t * (
size_t)aligned_embed_dim;
4111 for (
int h = 0; h < H; ++h) {
4112 const float *src = attn_out + (size_t)h * q_head_stride + (
size_t)t * (size_t)aligned_head_dim;
4113 memcpy(dst + (
size_t)h * (
size_t)aligned_head_dim,
4115 (
size_t)aligned_head_dim *
sizeof(
float));
4118 proj_in = proj_scratch;
4120 gemm_nt_q4_k(proj_in, WO, BO, proj_tmp, num_tokens, aligned_embed_dim, K);
4136 gemm_nt_q4_k(ln2_out, W1, B1, fc1_out, num_tokens, 2 * aligned_intermediate_dim, aligned_embed_dim);
4137 swiglu_forward(fc1_out, swiglu_out, num_tokens, aligned_intermediate_dim);
4138 gemm_nt_q4_k(swiglu_out, W2, B2, mlp_out, num_tokens, aligned_embed_dim, aligned_intermediate_dim);
4153 if (!model || !tokens || num_tokens <= 0) {
4157 const int elem_bytes = MODEL_DTYPE_BYTES;
4158 const int aligned_embed_dim = 1024;
4159 const int aligned_head_dim = 64;
4160 const int aligned_intermediate_dim = 4864;
4161 const int aligned_context_window = 131072;
4163 float *embed_out = MODEL_PTR(model, MODEL_HEADER.embedded_input);
4164 const void *embed_weight = (
const void *)MODEL_PTR(model, MODEL_HEADER.token_emb);
4181 aligned_intermediate_dim,
4182 aligned_context_window);
4189 aligned_intermediate_dim,
4190 aligned_context_window);
4197 aligned_intermediate_dim,
4198 aligned_context_window);
4205 aligned_intermediate_dim,
4206 aligned_context_window);
4213 aligned_intermediate_dim,
4214 aligned_context_window);
4221 aligned_intermediate_dim,
4222 aligned_context_window);
4229 aligned_intermediate_dim,
4230 aligned_context_window);
4237 aligned_intermediate_dim,
4238 aligned_context_window);
4245 aligned_intermediate_dim,
4246 aligned_context_window);
4253 aligned_intermediate_dim,
4254 aligned_context_window);
4261 aligned_intermediate_dim,
4262 aligned_context_window);
4269 aligned_intermediate_dim,
4270 aligned_context_window);
4277 aligned_intermediate_dim,
4278 aligned_context_window);
4285 aligned_intermediate_dim,
4286 aligned_context_window);
4293 aligned_intermediate_dim,
4294 aligned_context_window);
4301 aligned_intermediate_dim,
4302 aligned_context_window);
4309 aligned_intermediate_dim,
4310 aligned_context_window);
4317 aligned_intermediate_dim,
4318 aligned_context_window);
4325 aligned_intermediate_dim,
4326 aligned_context_window);
4333 aligned_intermediate_dim,
4334 aligned_context_window);
4341 aligned_intermediate_dim,
4342 aligned_context_window);
4349 aligned_intermediate_dim,
4350 aligned_context_window);
4357 aligned_intermediate_dim,
4358 aligned_context_window);
4365 aligned_intermediate_dim,
4366 aligned_context_window);
4368 float *last_hidden = MODEL_PTR(model, MODEL_LAYERS[
MODEL_NUM_LAYERS - 1].output);
4369 float *final_ln_weight = MODEL_PTR(model, MODEL_FOOTER.final_ln_weight);
4370 float *final_out = MODEL_PTR(model, MODEL_FOOTER.final_output);
4380 float *logits = MODEL_PTR(model, MODEL_FOOTER.logits);
4381 const void *lm_head = (
const void *)MODEL_PTR(model, MODEL_FOOTER.lm_head_weight);
4383 for (
int t = 0; t < num_tokens; ++t) {
4384 uint8_t q8_buf[q8_bytes];
4385 const float *row = final_out + (size_t)t * (
size_t)aligned_embed_dim;
4408 int aligned_embed_dim,
4409 int aligned_head_dim,
4410 int aligned_intermediate_dim,
4411 int aligned_context_window
4413 const MODELLayerOffsets *L = &MODEL_LAYERS[0];
4415 float *input = MODEL_PTR(model, MODEL_HEADER.embedded_input);
4417 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
4418 float *ln1_out = MODEL_PTR(model, L->ln1_out);
4419 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
4420 float *ln2_out = MODEL_PTR(model, L->ln2_out);
4421 float *k_cache = MODEL_PTR(model, L->k);
4422 float *v_cache = MODEL_PTR(model, L->v);
4423 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
4424 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
4425 float *residual1 = MODEL_PTR(model, L->residual1);
4426 float *mlp_out = MODEL_PTR(model, L->mlp_out);
4427 float *output = MODEL_PTR(model, L->output);
4430 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
4431 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
4432 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
4433 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
4434 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
4435 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
4437 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
4438 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
4444 float q_token[H * aligned_head_dim];
4445 float k_token[H_kv * aligned_head_dim];
4446 float v_token[H_kv * aligned_head_dim];
4447 float attn_token[H * aligned_head_dim];
4450 float fc1_out[2 * aligned_intermediate_dim];
4451 float swiglu_out[aligned_intermediate_dim];
4463 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4467 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4468 uint8_t ln1_q8[ln1_q8_bytes];
4470 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
4471 if (aligned_head_dim > head_dim) {
4472 for (
int h = 0; h < H; ++h) {
4473 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
4474 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4481 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4483 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4485 for (
int h = 0; h < H_kv; ++h) {
4486 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4487 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4488 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
4489 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4495 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4497 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4499 for (
int h = 0; h < H_kv; ++h) {
4500 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4501 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4502 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
4503 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4517 for (
int h = 0; h < H_kv; ++h) {
4518 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4539 aligned_context_window,
4545 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
4546 uint8_t attn_q8[attn_q8_bytes];
4548 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
4565 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4566 uint8_t ln2_q8[ln2_q8_bytes];
4568 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
4571 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4574 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
4575 uint8_t swiglu_q8[swiglu_q8_bytes];
4577 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
4589 int aligned_embed_dim,
4590 int aligned_head_dim,
4591 int aligned_intermediate_dim,
4592 int aligned_context_window
4594 const MODELLayerOffsets *L = &MODEL_LAYERS[1];
4596 float *input = MODEL_PTR(model, MODEL_LAYERS[0].output);
4598 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
4599 float *ln1_out = MODEL_PTR(model, L->ln1_out);
4600 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
4601 float *ln2_out = MODEL_PTR(model, L->ln2_out);
4602 float *k_cache = MODEL_PTR(model, L->k);
4603 float *v_cache = MODEL_PTR(model, L->v);
4604 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
4605 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
4606 float *residual1 = MODEL_PTR(model, L->residual1);
4607 float *mlp_out = MODEL_PTR(model, L->mlp_out);
4608 float *output = MODEL_PTR(model, L->output);
4611 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
4612 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
4613 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
4614 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
4615 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
4616 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
4618 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
4619 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
4625 float q_token[H * aligned_head_dim];
4626 float k_token[H_kv * aligned_head_dim];
4627 float v_token[H_kv * aligned_head_dim];
4628 float attn_token[H * aligned_head_dim];
4631 float fc1_out[2 * aligned_intermediate_dim];
4632 float swiglu_out[aligned_intermediate_dim];
4644 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4648 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4649 uint8_t ln1_q8[ln1_q8_bytes];
4651 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
4652 if (aligned_head_dim > head_dim) {
4653 for (
int h = 0; h < H; ++h) {
4654 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
4655 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4662 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4664 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4666 for (
int h = 0; h < H_kv; ++h) {
4667 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4668 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4669 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
4670 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4676 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4678 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4680 for (
int h = 0; h < H_kv; ++h) {
4681 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4682 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4683 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
4684 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4698 for (
int h = 0; h < H_kv; ++h) {
4699 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4720 aligned_context_window,
4726 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
4727 uint8_t attn_q8[attn_q8_bytes];
4729 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
4746 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4747 uint8_t ln2_q8[ln2_q8_bytes];
4749 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
4752 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4755 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
4756 uint8_t swiglu_q8[swiglu_q8_bytes];
4758 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
4770 int aligned_embed_dim,
4771 int aligned_head_dim,
4772 int aligned_intermediate_dim,
4773 int aligned_context_window
4775 const MODELLayerOffsets *L = &MODEL_LAYERS[2];
4777 float *input = MODEL_PTR(model, MODEL_LAYERS[1].output);
4779 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
4780 float *ln1_out = MODEL_PTR(model, L->ln1_out);
4781 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
4782 float *ln2_out = MODEL_PTR(model, L->ln2_out);
4783 float *k_cache = MODEL_PTR(model, L->k);
4784 float *v_cache = MODEL_PTR(model, L->v);
4785 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
4786 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
4787 float *residual1 = MODEL_PTR(model, L->residual1);
4788 float *mlp_out = MODEL_PTR(model, L->mlp_out);
4789 float *output = MODEL_PTR(model, L->output);
4792 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
4793 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
4794 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
4795 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
4796 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
4797 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
4799 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
4800 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
4806 float q_token[H * aligned_head_dim];
4807 float k_token[H_kv * aligned_head_dim];
4808 float v_token[H_kv * aligned_head_dim];
4809 float attn_token[H * aligned_head_dim];
4812 float fc1_out[2 * aligned_intermediate_dim];
4813 float swiglu_out[aligned_intermediate_dim];
4825 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
4829 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4830 uint8_t ln1_q8[ln1_q8_bytes];
4832 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
4833 if (aligned_head_dim > head_dim) {
4834 for (
int h = 0; h < H; ++h) {
4835 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
4836 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4843 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4845 const uint8_t *WK_bytes = (
const uint8_t *)WK;
4847 for (
int h = 0; h < H_kv; ++h) {
4848 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
4849 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4850 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
4851 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4857 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
4859 const uint8_t *WV_bytes = (
const uint8_t *)WV;
4861 for (
int h = 0; h < H_kv; ++h) {
4862 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
4863 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4864 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
4865 for (
int d = head_dim; d < aligned_head_dim; ++d) {
4879 for (
int h = 0; h < H_kv; ++h) {
4880 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
4901 aligned_context_window,
4907 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
4908 uint8_t attn_q8[attn_q8_bytes];
4910 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
4927 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
4928 uint8_t ln2_q8[ln2_q8_bytes];
4930 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
4933 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
4936 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
4937 uint8_t swiglu_q8[swiglu_q8_bytes];
4939 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
4951 int aligned_embed_dim,
4952 int aligned_head_dim,
4953 int aligned_intermediate_dim,
4954 int aligned_context_window
4956 const MODELLayerOffsets *L = &MODEL_LAYERS[3];
4958 float *input = MODEL_PTR(model, MODEL_LAYERS[2].output);
4960 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
4961 float *ln1_out = MODEL_PTR(model, L->ln1_out);
4962 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
4963 float *ln2_out = MODEL_PTR(model, L->ln2_out);
4964 float *k_cache = MODEL_PTR(model, L->k);
4965 float *v_cache = MODEL_PTR(model, L->v);
4966 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
4967 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
4968 float *residual1 = MODEL_PTR(model, L->residual1);
4969 float *mlp_out = MODEL_PTR(model, L->mlp_out);
4970 float *output = MODEL_PTR(model, L->output);
4973 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
4974 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
4975 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
4976 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
4977 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
4978 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
4980 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
4981 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
4987 float q_token[H * aligned_head_dim];
4988 float k_token[H_kv * aligned_head_dim];
4989 float v_token[H_kv * aligned_head_dim];
4990 float attn_token[H * aligned_head_dim];
4993 float fc1_out[2 * aligned_intermediate_dim];
4994 float swiglu_out[aligned_intermediate_dim];
5006 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5010 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5011 uint8_t ln1_q8[ln1_q8_bytes];
5013 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5014 if (aligned_head_dim > head_dim) {
5015 for (
int h = 0; h < H; ++h) {
5016 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5017 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5024 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5026 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5028 for (
int h = 0; h < H_kv; ++h) {
5029 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5030 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5031 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5032 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5038 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5040 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5042 for (
int h = 0; h < H_kv; ++h) {
5043 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5044 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5045 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5046 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5060 for (
int h = 0; h < H_kv; ++h) {
5061 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5082 aligned_context_window,
5088 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5089 uint8_t attn_q8[attn_q8_bytes];
5091 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5108 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5109 uint8_t ln2_q8[ln2_q8_bytes];
5111 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5114 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5117 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5118 uint8_t swiglu_q8[swiglu_q8_bytes];
5120 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5132 int aligned_embed_dim,
5133 int aligned_head_dim,
5134 int aligned_intermediate_dim,
5135 int aligned_context_window
5137 const MODELLayerOffsets *L = &MODEL_LAYERS[4];
5139 float *input = MODEL_PTR(model, MODEL_LAYERS[3].output);
5141 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5142 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5143 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5144 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5145 float *k_cache = MODEL_PTR(model, L->k);
5146 float *v_cache = MODEL_PTR(model, L->v);
5147 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5148 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5149 float *residual1 = MODEL_PTR(model, L->residual1);
5150 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5151 float *output = MODEL_PTR(model, L->output);
5154 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5155 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5156 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5157 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5158 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5159 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5161 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5162 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5168 float q_token[H * aligned_head_dim];
5169 float k_token[H_kv * aligned_head_dim];
5170 float v_token[H_kv * aligned_head_dim];
5171 float attn_token[H * aligned_head_dim];
5174 float fc1_out[2 * aligned_intermediate_dim];
5175 float swiglu_out[aligned_intermediate_dim];
5187 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5191 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5192 uint8_t ln1_q8[ln1_q8_bytes];
5194 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5195 if (aligned_head_dim > head_dim) {
5196 for (
int h = 0; h < H; ++h) {
5197 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5198 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5205 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5207 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5209 for (
int h = 0; h < H_kv; ++h) {
5210 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5211 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5212 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5213 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5219 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5221 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5223 for (
int h = 0; h < H_kv; ++h) {
5224 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5225 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5226 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5227 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5241 for (
int h = 0; h < H_kv; ++h) {
5242 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5263 aligned_context_window,
5269 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5270 uint8_t attn_q8[attn_q8_bytes];
5272 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5289 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5290 uint8_t ln2_q8[ln2_q8_bytes];
5292 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5295 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5298 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5299 uint8_t swiglu_q8[swiglu_q8_bytes];
5301 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5313 int aligned_embed_dim,
5314 int aligned_head_dim,
5315 int aligned_intermediate_dim,
5316 int aligned_context_window
5318 const MODELLayerOffsets *L = &MODEL_LAYERS[5];
5320 float *input = MODEL_PTR(model, MODEL_LAYERS[4].output);
5322 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5323 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5324 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5325 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5326 float *k_cache = MODEL_PTR(model, L->k);
5327 float *v_cache = MODEL_PTR(model, L->v);
5328 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5329 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5330 float *residual1 = MODEL_PTR(model, L->residual1);
5331 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5332 float *output = MODEL_PTR(model, L->output);
5335 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5336 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5337 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5338 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5339 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5340 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5342 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5343 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5349 float q_token[H * aligned_head_dim];
5350 float k_token[H_kv * aligned_head_dim];
5351 float v_token[H_kv * aligned_head_dim];
5352 float attn_token[H * aligned_head_dim];
5355 float fc1_out[2 * aligned_intermediate_dim];
5356 float swiglu_out[aligned_intermediate_dim];
5368 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5372 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5373 uint8_t ln1_q8[ln1_q8_bytes];
5375 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5376 if (aligned_head_dim > head_dim) {
5377 for (
int h = 0; h < H; ++h) {
5378 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5379 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5386 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5388 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5390 for (
int h = 0; h < H_kv; ++h) {
5391 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5392 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5393 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5394 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5400 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5402 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5404 for (
int h = 0; h < H_kv; ++h) {
5405 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5406 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5407 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5408 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5422 for (
int h = 0; h < H_kv; ++h) {
5423 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5444 aligned_context_window,
5450 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5451 uint8_t attn_q8[attn_q8_bytes];
5453 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5470 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5471 uint8_t ln2_q8[ln2_q8_bytes];
5473 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5476 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5479 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5480 uint8_t swiglu_q8[swiglu_q8_bytes];
5482 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5494 int aligned_embed_dim,
5495 int aligned_head_dim,
5496 int aligned_intermediate_dim,
5497 int aligned_context_window
5499 const MODELLayerOffsets *L = &MODEL_LAYERS[6];
5501 float *input = MODEL_PTR(model, MODEL_LAYERS[5].output);
5503 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5504 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5505 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5506 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5507 float *k_cache = MODEL_PTR(model, L->k);
5508 float *v_cache = MODEL_PTR(model, L->v);
5509 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5510 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5511 float *residual1 = MODEL_PTR(model, L->residual1);
5512 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5513 float *output = MODEL_PTR(model, L->output);
5516 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5517 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5518 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5519 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5520 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5521 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5523 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5524 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5530 float q_token[H * aligned_head_dim];
5531 float k_token[H_kv * aligned_head_dim];
5532 float v_token[H_kv * aligned_head_dim];
5533 float attn_token[H * aligned_head_dim];
5536 float fc1_out[2 * aligned_intermediate_dim];
5537 float swiglu_out[aligned_intermediate_dim];
5549 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5553 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5554 uint8_t ln1_q8[ln1_q8_bytes];
5556 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5557 if (aligned_head_dim > head_dim) {
5558 for (
int h = 0; h < H; ++h) {
5559 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5560 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5567 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5569 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5571 for (
int h = 0; h < H_kv; ++h) {
5572 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5573 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5574 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5575 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5581 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5583 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5585 for (
int h = 0; h < H_kv; ++h) {
5586 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5587 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5588 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5589 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5603 for (
int h = 0; h < H_kv; ++h) {
5604 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5625 aligned_context_window,
5631 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5632 uint8_t attn_q8[attn_q8_bytes];
5634 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5651 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5652 uint8_t ln2_q8[ln2_q8_bytes];
5654 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5657 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5660 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5661 uint8_t swiglu_q8[swiglu_q8_bytes];
5663 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5675 int aligned_embed_dim,
5676 int aligned_head_dim,
5677 int aligned_intermediate_dim,
5678 int aligned_context_window
5680 const MODELLayerOffsets *L = &MODEL_LAYERS[7];
5682 float *input = MODEL_PTR(model, MODEL_LAYERS[6].output);
5684 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5685 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5686 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5687 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5688 float *k_cache = MODEL_PTR(model, L->k);
5689 float *v_cache = MODEL_PTR(model, L->v);
5690 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5691 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5692 float *residual1 = MODEL_PTR(model, L->residual1);
5693 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5694 float *output = MODEL_PTR(model, L->output);
5697 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5698 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5699 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5700 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5701 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5702 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5704 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5705 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5711 float q_token[H * aligned_head_dim];
5712 float k_token[H_kv * aligned_head_dim];
5713 float v_token[H_kv * aligned_head_dim];
5714 float attn_token[H * aligned_head_dim];
5717 float fc1_out[2 * aligned_intermediate_dim];
5718 float swiglu_out[aligned_intermediate_dim];
5730 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5734 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5735 uint8_t ln1_q8[ln1_q8_bytes];
5737 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5738 if (aligned_head_dim > head_dim) {
5739 for (
int h = 0; h < H; ++h) {
5740 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5741 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5748 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5750 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5752 for (
int h = 0; h < H_kv; ++h) {
5753 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5754 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5755 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5756 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5762 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5764 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5766 for (
int h = 0; h < H_kv; ++h) {
5767 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5768 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5769 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5770 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5784 for (
int h = 0; h < H_kv; ++h) {
5785 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5806 aligned_context_window,
5812 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5813 uint8_t attn_q8[attn_q8_bytes];
5815 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
5832 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5833 uint8_t ln2_q8[ln2_q8_bytes];
5835 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
5838 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
5841 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
5842 uint8_t swiglu_q8[swiglu_q8_bytes];
5844 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
5856 int aligned_embed_dim,
5857 int aligned_head_dim,
5858 int aligned_intermediate_dim,
5859 int aligned_context_window
5861 const MODELLayerOffsets *L = &MODEL_LAYERS[8];
5863 float *input = MODEL_PTR(model, MODEL_LAYERS[7].output);
5865 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
5866 float *ln1_out = MODEL_PTR(model, L->ln1_out);
5867 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
5868 float *ln2_out = MODEL_PTR(model, L->ln2_out);
5869 float *k_cache = MODEL_PTR(model, L->k);
5870 float *v_cache = MODEL_PTR(model, L->v);
5871 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
5872 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
5873 float *residual1 = MODEL_PTR(model, L->residual1);
5874 float *mlp_out = MODEL_PTR(model, L->mlp_out);
5875 float *output = MODEL_PTR(model, L->output);
5878 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
5879 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
5880 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
5881 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
5882 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
5883 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
5885 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
5886 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
5892 float q_token[H * aligned_head_dim];
5893 float k_token[H_kv * aligned_head_dim];
5894 float v_token[H_kv * aligned_head_dim];
5895 float attn_token[H * aligned_head_dim];
5898 float fc1_out[2 * aligned_intermediate_dim];
5899 float swiglu_out[aligned_intermediate_dim];
5911 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
5915 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
5916 uint8_t ln1_q8[ln1_q8_bytes];
5918 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
5919 if (aligned_head_dim > head_dim) {
5920 for (
int h = 0; h < H; ++h) {
5921 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
5922 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5929 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5931 const uint8_t *WK_bytes = (
const uint8_t *)WK;
5933 for (
int h = 0; h < H_kv; ++h) {
5934 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
5935 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5936 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
5937 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5943 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
5945 const uint8_t *WV_bytes = (
const uint8_t *)WV;
5947 for (
int h = 0; h < H_kv; ++h) {
5948 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
5949 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5950 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
5951 for (
int d = head_dim; d < aligned_head_dim; ++d) {
5965 for (
int h = 0; h < H_kv; ++h) {
5966 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
5987 aligned_context_window,
5993 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
5994 uint8_t attn_q8[attn_q8_bytes];
5996 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6013 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6014 uint8_t ln2_q8[ln2_q8_bytes];
6016 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6019 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6022 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6023 uint8_t swiglu_q8[swiglu_q8_bytes];
6025 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6037 int aligned_embed_dim,
6038 int aligned_head_dim,
6039 int aligned_intermediate_dim,
6040 int aligned_context_window
6042 const MODELLayerOffsets *L = &MODEL_LAYERS[9];
6044 float *input = MODEL_PTR(model, MODEL_LAYERS[8].output);
6046 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6047 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6048 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6049 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6050 float *k_cache = MODEL_PTR(model, L->k);
6051 float *v_cache = MODEL_PTR(model, L->v);
6052 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6053 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6054 float *residual1 = MODEL_PTR(model, L->residual1);
6055 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6056 float *output = MODEL_PTR(model, L->output);
6059 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6060 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6061 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6062 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6063 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6064 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6066 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6067 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6073 float q_token[H * aligned_head_dim];
6074 float k_token[H_kv * aligned_head_dim];
6075 float v_token[H_kv * aligned_head_dim];
6076 float attn_token[H * aligned_head_dim];
6079 float fc1_out[2 * aligned_intermediate_dim];
6080 float swiglu_out[aligned_intermediate_dim];
6092 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6096 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6097 uint8_t ln1_q8[ln1_q8_bytes];
6099 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6100 if (aligned_head_dim > head_dim) {
6101 for (
int h = 0; h < H; ++h) {
6102 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6103 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6110 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6112 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6114 for (
int h = 0; h < H_kv; ++h) {
6115 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6116 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6117 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6118 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6124 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6126 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6128 for (
int h = 0; h < H_kv; ++h) {
6129 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6130 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6131 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6132 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6146 for (
int h = 0; h < H_kv; ++h) {
6147 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6168 aligned_context_window,
6174 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6175 uint8_t attn_q8[attn_q8_bytes];
6177 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6194 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6195 uint8_t ln2_q8[ln2_q8_bytes];
6197 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6200 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6203 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6204 uint8_t swiglu_q8[swiglu_q8_bytes];
6206 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6218 int aligned_embed_dim,
6219 int aligned_head_dim,
6220 int aligned_intermediate_dim,
6221 int aligned_context_window
6223 const MODELLayerOffsets *L = &MODEL_LAYERS[10];
6225 float *input = MODEL_PTR(model, MODEL_LAYERS[9].output);
6227 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6228 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6229 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6230 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6231 float *k_cache = MODEL_PTR(model, L->k);
6232 float *v_cache = MODEL_PTR(model, L->v);
6233 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6234 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6235 float *residual1 = MODEL_PTR(model, L->residual1);
6236 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6237 float *output = MODEL_PTR(model, L->output);
6240 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6241 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6242 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6243 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6244 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6245 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6247 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6248 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6254 float q_token[H * aligned_head_dim];
6255 float k_token[H_kv * aligned_head_dim];
6256 float v_token[H_kv * aligned_head_dim];
6257 float attn_token[H * aligned_head_dim];
6260 float fc1_out[2 * aligned_intermediate_dim];
6261 float swiglu_out[aligned_intermediate_dim];
6273 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6277 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6278 uint8_t ln1_q8[ln1_q8_bytes];
6280 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6281 if (aligned_head_dim > head_dim) {
6282 for (
int h = 0; h < H; ++h) {
6283 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6284 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6291 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6293 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6295 for (
int h = 0; h < H_kv; ++h) {
6296 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6297 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6298 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6299 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6305 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6307 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6309 for (
int h = 0; h < H_kv; ++h) {
6310 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6311 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6312 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6313 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6327 for (
int h = 0; h < H_kv; ++h) {
6328 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6349 aligned_context_window,
6355 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6356 uint8_t attn_q8[attn_q8_bytes];
6358 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6375 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6376 uint8_t ln2_q8[ln2_q8_bytes];
6378 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6381 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6384 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6385 uint8_t swiglu_q8[swiglu_q8_bytes];
6387 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6399 int aligned_embed_dim,
6400 int aligned_head_dim,
6401 int aligned_intermediate_dim,
6402 int aligned_context_window
6404 const MODELLayerOffsets *L = &MODEL_LAYERS[11];
6406 float *input = MODEL_PTR(model, MODEL_LAYERS[10].output);
6408 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6409 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6410 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6411 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6412 float *k_cache = MODEL_PTR(model, L->k);
6413 float *v_cache = MODEL_PTR(model, L->v);
6414 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6415 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6416 float *residual1 = MODEL_PTR(model, L->residual1);
6417 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6418 float *output = MODEL_PTR(model, L->output);
6421 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6422 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6423 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6424 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6425 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6426 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6428 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6429 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6435 float q_token[H * aligned_head_dim];
6436 float k_token[H_kv * aligned_head_dim];
6437 float v_token[H_kv * aligned_head_dim];
6438 float attn_token[H * aligned_head_dim];
6441 float fc1_out[2 * aligned_intermediate_dim];
6442 float swiglu_out[aligned_intermediate_dim];
6454 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6458 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6459 uint8_t ln1_q8[ln1_q8_bytes];
6461 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6462 if (aligned_head_dim > head_dim) {
6463 for (
int h = 0; h < H; ++h) {
6464 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6465 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6472 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6474 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6476 for (
int h = 0; h < H_kv; ++h) {
6477 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6478 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6479 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6480 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6486 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6488 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6490 for (
int h = 0; h < H_kv; ++h) {
6491 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6492 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6493 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6494 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6508 for (
int h = 0; h < H_kv; ++h) {
6509 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6530 aligned_context_window,
6536 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6537 uint8_t attn_q8[attn_q8_bytes];
6539 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6556 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6557 uint8_t ln2_q8[ln2_q8_bytes];
6559 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6562 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6565 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6566 uint8_t swiglu_q8[swiglu_q8_bytes];
6568 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6580 int aligned_embed_dim,
6581 int aligned_head_dim,
6582 int aligned_intermediate_dim,
6583 int aligned_context_window
6585 const MODELLayerOffsets *L = &MODEL_LAYERS[12];
6587 float *input = MODEL_PTR(model, MODEL_LAYERS[11].output);
6589 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6590 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6591 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6592 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6593 float *k_cache = MODEL_PTR(model, L->k);
6594 float *v_cache = MODEL_PTR(model, L->v);
6595 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6596 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6597 float *residual1 = MODEL_PTR(model, L->residual1);
6598 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6599 float *output = MODEL_PTR(model, L->output);
6602 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6603 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6604 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6605 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6606 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6607 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6609 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6610 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6616 float q_token[H * aligned_head_dim];
6617 float k_token[H_kv * aligned_head_dim];
6618 float v_token[H_kv * aligned_head_dim];
6619 float attn_token[H * aligned_head_dim];
6622 float fc1_out[2 * aligned_intermediate_dim];
6623 float swiglu_out[aligned_intermediate_dim];
6635 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6639 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6640 uint8_t ln1_q8[ln1_q8_bytes];
6642 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6643 if (aligned_head_dim > head_dim) {
6644 for (
int h = 0; h < H; ++h) {
6645 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6646 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6653 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6655 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6657 for (
int h = 0; h < H_kv; ++h) {
6658 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6659 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6660 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6661 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6667 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6669 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6671 for (
int h = 0; h < H_kv; ++h) {
6672 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6673 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6674 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6675 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6689 for (
int h = 0; h < H_kv; ++h) {
6690 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6711 aligned_context_window,
6717 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6718 uint8_t attn_q8[attn_q8_bytes];
6720 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6737 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6738 uint8_t ln2_q8[ln2_q8_bytes];
6740 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6743 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6746 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6747 uint8_t swiglu_q8[swiglu_q8_bytes];
6749 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6761 int aligned_embed_dim,
6762 int aligned_head_dim,
6763 int aligned_intermediate_dim,
6764 int aligned_context_window
6766 const MODELLayerOffsets *L = &MODEL_LAYERS[13];
6768 float *input = MODEL_PTR(model, MODEL_LAYERS[12].output);
6770 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6771 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6772 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6773 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6774 float *k_cache = MODEL_PTR(model, L->k);
6775 float *v_cache = MODEL_PTR(model, L->v);
6776 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6777 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6778 float *residual1 = MODEL_PTR(model, L->residual1);
6779 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6780 float *output = MODEL_PTR(model, L->output);
6783 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6784 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6785 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6786 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6787 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6788 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6790 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6791 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6797 float q_token[H * aligned_head_dim];
6798 float k_token[H_kv * aligned_head_dim];
6799 float v_token[H_kv * aligned_head_dim];
6800 float attn_token[H * aligned_head_dim];
6803 float fc1_out[2 * aligned_intermediate_dim];
6804 float swiglu_out[aligned_intermediate_dim];
6816 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
6820 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6821 uint8_t ln1_q8[ln1_q8_bytes];
6823 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
6824 if (aligned_head_dim > head_dim) {
6825 for (
int h = 0; h < H; ++h) {
6826 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
6827 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6834 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6836 const uint8_t *WK_bytes = (
const uint8_t *)WK;
6838 for (
int h = 0; h < H_kv; ++h) {
6839 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
6840 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6841 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
6842 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6848 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
6850 const uint8_t *WV_bytes = (
const uint8_t *)WV;
6852 for (
int h = 0; h < H_kv; ++h) {
6853 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
6854 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6855 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
6856 for (
int d = head_dim; d < aligned_head_dim; ++d) {
6870 for (
int h = 0; h < H_kv; ++h) {
6871 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
6892 aligned_context_window,
6898 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
6899 uint8_t attn_q8[attn_q8_bytes];
6901 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
6918 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
6919 uint8_t ln2_q8[ln2_q8_bytes];
6921 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
6924 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
6927 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
6928 uint8_t swiglu_q8[swiglu_q8_bytes];
6930 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
6942 int aligned_embed_dim,
6943 int aligned_head_dim,
6944 int aligned_intermediate_dim,
6945 int aligned_context_window
6947 const MODELLayerOffsets *L = &MODEL_LAYERS[14];
6949 float *input = MODEL_PTR(model, MODEL_LAYERS[13].output);
6951 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
6952 float *ln1_out = MODEL_PTR(model, L->ln1_out);
6953 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
6954 float *ln2_out = MODEL_PTR(model, L->ln2_out);
6955 float *k_cache = MODEL_PTR(model, L->k);
6956 float *v_cache = MODEL_PTR(model, L->v);
6957 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
6958 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
6959 float *residual1 = MODEL_PTR(model, L->residual1);
6960 float *mlp_out = MODEL_PTR(model, L->mlp_out);
6961 float *output = MODEL_PTR(model, L->output);
6964 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
6965 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
6966 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
6967 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
6968 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
6969 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
6971 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
6972 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
6978 float q_token[H * aligned_head_dim];
6979 float k_token[H_kv * aligned_head_dim];
6980 float v_token[H_kv * aligned_head_dim];
6981 float attn_token[H * aligned_head_dim];
6984 float fc1_out[2 * aligned_intermediate_dim];
6985 float swiglu_out[aligned_intermediate_dim];
6997 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7001 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7002 uint8_t ln1_q8[ln1_q8_bytes];
7004 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7005 if (aligned_head_dim > head_dim) {
7006 for (
int h = 0; h < H; ++h) {
7007 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7008 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7015 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7017 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7019 for (
int h = 0; h < H_kv; ++h) {
7020 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7021 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7022 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7023 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7029 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7031 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7033 for (
int h = 0; h < H_kv; ++h) {
7034 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7035 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7036 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7037 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7051 for (
int h = 0; h < H_kv; ++h) {
7052 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7073 aligned_context_window,
7079 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7080 uint8_t attn_q8[attn_q8_bytes];
7082 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7099 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7100 uint8_t ln2_q8[ln2_q8_bytes];
7102 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7105 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7108 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7109 uint8_t swiglu_q8[swiglu_q8_bytes];
7111 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7123 int aligned_embed_dim,
7124 int aligned_head_dim,
7125 int aligned_intermediate_dim,
7126 int aligned_context_window
7128 const MODELLayerOffsets *L = &MODEL_LAYERS[15];
7130 float *input = MODEL_PTR(model, MODEL_LAYERS[14].output);
7132 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7133 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7134 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7135 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7136 float *k_cache = MODEL_PTR(model, L->k);
7137 float *v_cache = MODEL_PTR(model, L->v);
7138 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7139 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7140 float *residual1 = MODEL_PTR(model, L->residual1);
7141 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7142 float *output = MODEL_PTR(model, L->output);
7145 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7146 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7147 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7148 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7149 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7150 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
7152 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7153 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
7159 float q_token[H * aligned_head_dim];
7160 float k_token[H_kv * aligned_head_dim];
7161 float v_token[H_kv * aligned_head_dim];
7162 float attn_token[H * aligned_head_dim];
7165 float fc1_out[2 * aligned_intermediate_dim];
7166 float swiglu_out[aligned_intermediate_dim];
7178 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7182 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7183 uint8_t ln1_q8[ln1_q8_bytes];
7185 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7186 if (aligned_head_dim > head_dim) {
7187 for (
int h = 0; h < H; ++h) {
7188 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7189 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7196 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7198 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7200 for (
int h = 0; h < H_kv; ++h) {
7201 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7202 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7203 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7204 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7210 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7212 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7214 for (
int h = 0; h < H_kv; ++h) {
7215 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7216 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7217 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7218 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7232 for (
int h = 0; h < H_kv; ++h) {
7233 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7254 aligned_context_window,
7260 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7261 uint8_t attn_q8[attn_q8_bytes];
7263 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7280 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7281 uint8_t ln2_q8[ln2_q8_bytes];
7283 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7286 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7289 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7290 uint8_t swiglu_q8[swiglu_q8_bytes];
7292 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7304 int aligned_embed_dim,
7305 int aligned_head_dim,
7306 int aligned_intermediate_dim,
7307 int aligned_context_window
7309 const MODELLayerOffsets *L = &MODEL_LAYERS[16];
7311 float *input = MODEL_PTR(model, MODEL_LAYERS[15].output);
7313 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7314 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7315 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7316 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7317 float *k_cache = MODEL_PTR(model, L->k);
7318 float *v_cache = MODEL_PTR(model, L->v);
7319 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7320 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7321 float *residual1 = MODEL_PTR(model, L->residual1);
7322 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7323 float *output = MODEL_PTR(model, L->output);
7326 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7327 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7328 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7329 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7330 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7331 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
7333 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7334 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
7340 float q_token[H * aligned_head_dim];
7341 float k_token[H_kv * aligned_head_dim];
7342 float v_token[H_kv * aligned_head_dim];
7343 float attn_token[H * aligned_head_dim];
7346 float fc1_out[2 * aligned_intermediate_dim];
7347 float swiglu_out[aligned_intermediate_dim];
7359 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7363 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7364 uint8_t ln1_q8[ln1_q8_bytes];
7366 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7367 if (aligned_head_dim > head_dim) {
7368 for (
int h = 0; h < H; ++h) {
7369 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7370 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7377 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7379 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7381 for (
int h = 0; h < H_kv; ++h) {
7382 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7383 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7384 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7385 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7391 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7393 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7395 for (
int h = 0; h < H_kv; ++h) {
7396 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7397 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7398 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7399 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7413 for (
int h = 0; h < H_kv; ++h) {
7414 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7435 aligned_context_window,
7441 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7442 uint8_t attn_q8[attn_q8_bytes];
7444 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7461 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7462 uint8_t ln2_q8[ln2_q8_bytes];
7464 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7467 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7470 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7471 uint8_t swiglu_q8[swiglu_q8_bytes];
7473 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7485 int aligned_embed_dim,
7486 int aligned_head_dim,
7487 int aligned_intermediate_dim,
7488 int aligned_context_window
7490 const MODELLayerOffsets *L = &MODEL_LAYERS[17];
7492 float *input = MODEL_PTR(model, MODEL_LAYERS[16].output);
7494 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7495 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7496 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7497 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7498 float *k_cache = MODEL_PTR(model, L->k);
7499 float *v_cache = MODEL_PTR(model, L->v);
7500 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7501 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7502 float *residual1 = MODEL_PTR(model, L->residual1);
7503 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7504 float *output = MODEL_PTR(model, L->output);
7507 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7508 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7509 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7510 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7511 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7512 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
7514 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7515 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
7521 float q_token[H * aligned_head_dim];
7522 float k_token[H_kv * aligned_head_dim];
7523 float v_token[H_kv * aligned_head_dim];
7524 float attn_token[H * aligned_head_dim];
7527 float fc1_out[2 * aligned_intermediate_dim];
7528 float swiglu_out[aligned_intermediate_dim];
7540 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7544 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7545 uint8_t ln1_q8[ln1_q8_bytes];
7547 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7548 if (aligned_head_dim > head_dim) {
7549 for (
int h = 0; h < H; ++h) {
7550 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7551 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7558 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7560 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7562 for (
int h = 0; h < H_kv; ++h) {
7563 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7564 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7565 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7566 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7572 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7574 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7576 for (
int h = 0; h < H_kv; ++h) {
7577 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7578 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7579 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7580 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7594 for (
int h = 0; h < H_kv; ++h) {
7595 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7616 aligned_context_window,
7622 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7623 uint8_t attn_q8[attn_q8_bytes];
7625 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7642 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7643 uint8_t ln2_q8[ln2_q8_bytes];
7645 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7648 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7651 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7652 uint8_t swiglu_q8[swiglu_q8_bytes];
7654 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7666 int aligned_embed_dim,
7667 int aligned_head_dim,
7668 int aligned_intermediate_dim,
7669 int aligned_context_window
7671 const MODELLayerOffsets *L = &MODEL_LAYERS[18];
7673 float *input = MODEL_PTR(model, MODEL_LAYERS[17].output);
7675 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7676 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7677 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7678 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7679 float *k_cache = MODEL_PTR(model, L->k);
7680 float *v_cache = MODEL_PTR(model, L->v);
7681 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7682 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7683 float *residual1 = MODEL_PTR(model, L->residual1);
7684 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7685 float *output = MODEL_PTR(model, L->output);
7688 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7689 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7690 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7691 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7692 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7693 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
7695 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7696 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
7702 float q_token[H * aligned_head_dim];
7703 float k_token[H_kv * aligned_head_dim];
7704 float v_token[H_kv * aligned_head_dim];
7705 float attn_token[H * aligned_head_dim];
7708 float fc1_out[2 * aligned_intermediate_dim];
7709 float swiglu_out[aligned_intermediate_dim];
7721 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7725 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7726 uint8_t ln1_q8[ln1_q8_bytes];
7728 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7729 if (aligned_head_dim > head_dim) {
7730 for (
int h = 0; h < H; ++h) {
7731 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7732 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7739 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7741 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7743 for (
int h = 0; h < H_kv; ++h) {
7744 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7745 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7746 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7747 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7753 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7755 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7757 for (
int h = 0; h < H_kv; ++h) {
7758 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7759 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7760 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7761 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7775 for (
int h = 0; h < H_kv; ++h) {
7776 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7797 aligned_context_window,
7803 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7804 uint8_t attn_q8[attn_q8_bytes];
7806 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
7823 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7824 uint8_t ln2_q8[ln2_q8_bytes];
7826 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
7829 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
7832 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
7833 uint8_t swiglu_q8[swiglu_q8_bytes];
7835 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
7847 int aligned_embed_dim,
7848 int aligned_head_dim,
7849 int aligned_intermediate_dim,
7850 int aligned_context_window
7852 const MODELLayerOffsets *L = &MODEL_LAYERS[19];
7854 float *input = MODEL_PTR(model, MODEL_LAYERS[18].output);
7856 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
7857 float *ln1_out = MODEL_PTR(model, L->ln1_out);
7858 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
7859 float *ln2_out = MODEL_PTR(model, L->ln2_out);
7860 float *k_cache = MODEL_PTR(model, L->k);
7861 float *v_cache = MODEL_PTR(model, L->v);
7862 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
7863 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
7864 float *residual1 = MODEL_PTR(model, L->residual1);
7865 float *mlp_out = MODEL_PTR(model, L->mlp_out);
7866 float *output = MODEL_PTR(model, L->output);
7869 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
7870 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
7871 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
7872 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
7873 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
7874 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
7876 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
7877 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
7883 float q_token[H * aligned_head_dim];
7884 float k_token[H_kv * aligned_head_dim];
7885 float v_token[H_kv * aligned_head_dim];
7886 float attn_token[H * aligned_head_dim];
7889 float fc1_out[2 * aligned_intermediate_dim];
7890 float swiglu_out[aligned_intermediate_dim];
7902 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
7906 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
7907 uint8_t ln1_q8[ln1_q8_bytes];
7909 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
7910 if (aligned_head_dim > head_dim) {
7911 for (
int h = 0; h < H; ++h) {
7912 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
7913 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7920 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7922 const uint8_t *WK_bytes = (
const uint8_t *)WK;
7924 for (
int h = 0; h < H_kv; ++h) {
7925 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
7926 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7927 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
7928 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7934 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
7936 const uint8_t *WV_bytes = (
const uint8_t *)WV;
7938 for (
int h = 0; h < H_kv; ++h) {
7939 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
7940 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7941 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
7942 for (
int d = head_dim; d < aligned_head_dim; ++d) {
7956 for (
int h = 0; h < H_kv; ++h) {
7957 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
7978 aligned_context_window,
7984 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
7985 uint8_t attn_q8[attn_q8_bytes];
7987 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8004 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8005 uint8_t ln2_q8[ln2_q8_bytes];
8007 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8010 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8013 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8014 uint8_t swiglu_q8[swiglu_q8_bytes];
8016 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8028 int aligned_embed_dim,
8029 int aligned_head_dim,
8030 int aligned_intermediate_dim,
8031 int aligned_context_window
8033 const MODELLayerOffsets *L = &MODEL_LAYERS[20];
8035 float *input = MODEL_PTR(model, MODEL_LAYERS[19].output);
8037 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
8038 float *ln1_out = MODEL_PTR(model, L->ln1_out);
8039 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
8040 float *ln2_out = MODEL_PTR(model, L->ln2_out);
8041 float *k_cache = MODEL_PTR(model, L->k);
8042 float *v_cache = MODEL_PTR(model, L->v);
8043 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
8044 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
8045 float *residual1 = MODEL_PTR(model, L->residual1);
8046 float *mlp_out = MODEL_PTR(model, L->mlp_out);
8047 float *output = MODEL_PTR(model, L->output);
8050 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
8051 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
8052 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
8053 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
8054 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
8055 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
8057 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
8058 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
8064 float q_token[H * aligned_head_dim];
8065 float k_token[H_kv * aligned_head_dim];
8066 float v_token[H_kv * aligned_head_dim];
8067 float attn_token[H * aligned_head_dim];
8070 float fc1_out[2 * aligned_intermediate_dim];
8071 float swiglu_out[aligned_intermediate_dim];
8083 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
8087 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8088 uint8_t ln1_q8[ln1_q8_bytes];
8090 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8091 if (aligned_head_dim > head_dim) {
8092 for (
int h = 0; h < H; ++h) {
8093 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8094 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8101 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8103 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8105 for (
int h = 0; h < H_kv; ++h) {
8106 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8107 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8108 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8109 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8115 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8117 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8119 for (
int h = 0; h < H_kv; ++h) {
8120 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8121 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8122 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8123 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8137 for (
int h = 0; h < H_kv; ++h) {
8138 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8159 aligned_context_window,
8165 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8166 uint8_t attn_q8[attn_q8_bytes];
8168 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8185 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8186 uint8_t ln2_q8[ln2_q8_bytes];
8188 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8191 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8194 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8195 uint8_t swiglu_q8[swiglu_q8_bytes];
8197 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8209 int aligned_embed_dim,
8210 int aligned_head_dim,
8211 int aligned_intermediate_dim,
8212 int aligned_context_window
8214 const MODELLayerOffsets *L = &MODEL_LAYERS[21];
8216 float *input = MODEL_PTR(model, MODEL_LAYERS[20].output);
8218 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
8219 float *ln1_out = MODEL_PTR(model, L->ln1_out);
8220 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
8221 float *ln2_out = MODEL_PTR(model, L->ln2_out);
8222 float *k_cache = MODEL_PTR(model, L->k);
8223 float *v_cache = MODEL_PTR(model, L->v);
8224 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
8225 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
8226 float *residual1 = MODEL_PTR(model, L->residual1);
8227 float *mlp_out = MODEL_PTR(model, L->mlp_out);
8228 float *output = MODEL_PTR(model, L->output);
8231 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
8232 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
8233 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
8234 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
8235 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
8236 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
8238 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
8239 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
8245 float q_token[H * aligned_head_dim];
8246 float k_token[H_kv * aligned_head_dim];
8247 float v_token[H_kv * aligned_head_dim];
8248 float attn_token[H * aligned_head_dim];
8251 float fc1_out[2 * aligned_intermediate_dim];
8252 float swiglu_out[aligned_intermediate_dim];
8264 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
8268 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8269 uint8_t ln1_q8[ln1_q8_bytes];
8271 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8272 if (aligned_head_dim > head_dim) {
8273 for (
int h = 0; h < H; ++h) {
8274 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8275 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8282 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8284 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8286 for (
int h = 0; h < H_kv; ++h) {
8287 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8288 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8289 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8290 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8296 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8298 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8300 for (
int h = 0; h < H_kv; ++h) {
8301 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8302 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8303 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8304 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8318 for (
int h = 0; h < H_kv; ++h) {
8319 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8340 aligned_context_window,
8346 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8347 uint8_t attn_q8[attn_q8_bytes];
8349 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8366 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8367 uint8_t ln2_q8[ln2_q8_bytes];
8369 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8372 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8375 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8376 uint8_t swiglu_q8[swiglu_q8_bytes];
8378 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8390 int aligned_embed_dim,
8391 int aligned_head_dim,
8392 int aligned_intermediate_dim,
8393 int aligned_context_window
8395 const MODELLayerOffsets *L = &MODEL_LAYERS[22];
8397 float *input = MODEL_PTR(model, MODEL_LAYERS[21].output);
8399 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
8400 float *ln1_out = MODEL_PTR(model, L->ln1_out);
8401 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
8402 float *ln2_out = MODEL_PTR(model, L->ln2_out);
8403 float *k_cache = MODEL_PTR(model, L->k);
8404 float *v_cache = MODEL_PTR(model, L->v);
8405 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
8406 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
8407 float *residual1 = MODEL_PTR(model, L->residual1);
8408 float *mlp_out = MODEL_PTR(model, L->mlp_out);
8409 float *output = MODEL_PTR(model, L->output);
8412 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
8413 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
8414 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
8415 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
8416 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
8417 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
8419 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
8420 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
8426 float q_token[H * aligned_head_dim];
8427 float k_token[H_kv * aligned_head_dim];
8428 float v_token[H_kv * aligned_head_dim];
8429 float attn_token[H * aligned_head_dim];
8432 float fc1_out[2 * aligned_intermediate_dim];
8433 float swiglu_out[aligned_intermediate_dim];
8445 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
8449 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8450 uint8_t ln1_q8[ln1_q8_bytes];
8452 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8453 if (aligned_head_dim > head_dim) {
8454 for (
int h = 0; h < H; ++h) {
8455 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8456 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8463 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8465 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8467 for (
int h = 0; h < H_kv; ++h) {
8468 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8469 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8470 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8471 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8477 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8479 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8481 for (
int h = 0; h < H_kv; ++h) {
8482 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8483 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8484 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8485 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8499 for (
int h = 0; h < H_kv; ++h) {
8500 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8521 aligned_context_window,
8527 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8528 uint8_t attn_q8[attn_q8_bytes];
8530 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8547 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8548 uint8_t ln2_q8[ln2_q8_bytes];
8550 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8553 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8556 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8557 uint8_t swiglu_q8[swiglu_q8_bytes];
8559 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8571 int aligned_embed_dim,
8572 int aligned_head_dim,
8573 int aligned_intermediate_dim,
8574 int aligned_context_window
8576 const MODELLayerOffsets *L = &MODEL_LAYERS[23];
8578 float *input = MODEL_PTR(model, MODEL_LAYERS[22].output);
8580 float *ln1_gamma = MODEL_PTR(model, L->ln1_gamma);
8581 float *ln1_out = MODEL_PTR(model, L->ln1_out);
8582 float *ln2_gamma = MODEL_PTR(model, L->ln2_gamma);
8583 float *ln2_out = MODEL_PTR(model, L->ln2_out);
8584 float *k_cache = MODEL_PTR(model, L->k);
8585 float *v_cache = MODEL_PTR(model, L->v);
8586 float *proj_tmp = MODEL_PTR(model, L->proj_tmp);
8587 float *proj_scratch = MODEL_PTR(model, L->proj_scratch);
8588 float *residual1 = MODEL_PTR(model, L->residual1);
8589 float *mlp_out = MODEL_PTR(model, L->mlp_out);
8590 float *output = MODEL_PTR(model, L->output);
8593 const void *WQ = (
const void *)MODEL_PTR(model, L->wq);
8594 const void *WK = (
const void *)MODEL_PTR(model, L->wk);
8595 const void *WV = (
const void *)MODEL_PTR(model, L->wv);
8596 const void *WO = (
const void *)MODEL_PTR(model, L->wo);
8597 const void *W1 = (
const void *)MODEL_PTR(model, L->w1);
8598 const void *W2 = (
const void *)MODEL_PTR(model, L->w2);
8600 float *rope_cos = MODEL_PTR(model, MODEL_GLOBALS.rope_cos_cache);
8601 float *rope_sin = MODEL_PTR(model, MODEL_GLOBALS.rope_sin_cache);
8607 float q_token[H * aligned_head_dim];
8608 float k_token[H_kv * aligned_head_dim];
8609 float v_token[H_kv * aligned_head_dim];
8610 float attn_token[H * aligned_head_dim];
8613 float fc1_out[2 * aligned_intermediate_dim];
8614 float swiglu_out[aligned_intermediate_dim];
8626 const size_t kv_head_stride = (size_t)aligned_context_window * (
size_t)aligned_head_dim;
8630 const size_t ln1_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8631 uint8_t ln1_q8[ln1_q8_bytes];
8633 gemv_q4_k_q8_k(q_token, WQ, ln1_q8, H * head_dim, aligned_embed_dim);
8634 if (aligned_head_dim > head_dim) {
8635 for (
int h = 0; h < H; ++h) {
8636 float *q_head = q_token + (size_t)h * (
size_t)aligned_head_dim;
8637 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8644 const size_t wk_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8646 const uint8_t *WK_bytes = (
const uint8_t *)WK;
8648 for (
int h = 0; h < H_kv; ++h) {
8649 const void *wk_h = (
const void *)(WK_bytes + (
size_t)h * wk_head_bytes);
8650 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8651 gemv_q4_k_q8_k(k_head, wk_h, ln1_q8, head_dim, aligned_embed_dim);
8652 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8658 const size_t wv_head_elems = (size_t)aligned_head_dim * (
size_t)aligned_embed_dim;
8660 const uint8_t *WV_bytes = (
const uint8_t *)WV;
8662 for (
int h = 0; h < H_kv; ++h) {
8663 const void *wv_h = (
const void *)(WV_bytes + (
size_t)h * wv_head_bytes);
8664 float *v_head = v_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8665 gemv_q4_k_q8_k(v_head, wv_h, ln1_q8, head_dim, aligned_embed_dim);
8666 for (
int d = head_dim; d < aligned_head_dim; ++d) {
8680 for (
int h = 0; h < H_kv; ++h) {
8681 float *k_head = k_cache + (size_t)h * kv_head_stride + (
size_t)token_index * (size_t)aligned_head_dim;
8702 aligned_context_window,
8708 const size_t attn_q8_bytes = ((((H * head_dim) + 255) / 256) * 292);
8709 uint8_t attn_q8[attn_q8_bytes];
8711 gemv_q4_k_q8_k(proj_tmp, WO, attn_q8, aligned_embed_dim, H * head_dim);
8728 const size_t ln2_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8729 uint8_t ln2_q8[ln2_q8_bytes];
8731 gemv_q4_k_q8_k(fc1_out, W1, ln2_q8, 2 * aligned_intermediate_dim, aligned_embed_dim);
8734 swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
8737 const size_t swiglu_q8_bytes = ((((aligned_intermediate_dim) + 255) / 256) * 292);
8738 uint8_t swiglu_q8[swiglu_q8_bytes];
8740 gemv_q4_k_q8_k(mlp_out, W2, swiglu_q8, aligned_embed_dim, aligned_intermediate_dim);
8755 if (!model || !
token)
return;
8757 const int aligned_embed_dim = 1024;
8758 const int aligned_head_dim = 64;
8759 const int aligned_intermediate_dim = 4864;
8760 const int aligned_context_window = 131072;
8762 if (token_index < 0 || token_index >= aligned_context_window)
return;
8765 float *embed_out = MODEL_PTR(model, MODEL_HEADER.embedded_input);
8766 const void *embed_weight = (
const void *)MODEL_PTR(model, MODEL_HEADER.token_emb);
8780 model_layer_0_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8781 model_layer_1_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8782 model_layer_2_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8783 model_layer_3_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8784 model_layer_4_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8785 model_layer_5_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8786 model_layer_6_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8787 model_layer_7_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8788 model_layer_8_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8789 model_layer_9_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8790 model_layer_10_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8791 model_layer_11_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8792 model_layer_12_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8793 model_layer_13_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8794 model_layer_14_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8795 model_layer_15_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8796 model_layer_16_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8797 model_layer_17_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8798 model_layer_18_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8799 model_layer_19_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8800 model_layer_20_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8801 model_layer_21_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8802 model_layer_22_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8803 model_layer_23_decode(model, token_index, aligned_embed_dim, aligned_head_dim, aligned_intermediate_dim, aligned_context_window);
8806 float *last_hidden = MODEL_PTR(model, MODEL_LAYERS[23].output);
8807 float *final_ln_weight = MODEL_PTR(model, MODEL_FOOTER.final_ln_weight);
8808 float *final_out = MODEL_PTR(model, MODEL_FOOTER.final_output);
8819 float *logits = MODEL_PTR(model, MODEL_FOOTER.logits);
8820 const void *lm_head = (
const void *)MODEL_PTR(model, MODEL_FOOTER.lm_head_weight);
8822 const size_t final_q8_bytes = ((((aligned_embed_dim) + 255) / 256) * 292);
8823 uint8_t final_q8[final_q8_bytes];
8837 if (!model || !tokens || num_tokens <= 0)
return;
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculates the total number of bytes required to store n_elements of the given dtype.
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void embedding_forward_q4_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
void quantize_row_q8_k(const float *x, void *y, int k)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
static void model_layer_13_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_17_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_13_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_6_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
int model_model_allocate(MODELModel *model)
static void model_layer_3_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_0_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_15_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_8_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_22_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_forward_prefill_impl(MODELModel *model, const int *tokens, int num_tokens)
static void model_layer_22_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_18_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_8_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_7_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void model_decode(MODELModel *model, const int *token, int token_index)
void model_precompute_rope(MODELModel *model)
static void model_layer_4_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_11_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_12_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_10_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_16_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
void model_forward(MODELModel *model, const int *tokens, int num_tokens)
static void model_layer_21_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_16_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_4_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_1_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_14_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
int model_verify_canaries(MODELModel *model)
static void model_layer_20_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_3_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_19_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_14_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
void model_model_free(MODELModel *model)
static void model_layer_20_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_11_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_7_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
struct __attribute__((packed))
static void model_layer_23_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_21_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_5_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_18_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_5_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_15_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_10_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_12_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_17_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_0_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static int model_align_elems(int elems, int elem_bytes, int align_bytes)
static void model_layer_1_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_23_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_decode_token(MODELModel *model, const int *token, int token_index)
static void model_layer_19_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
_Static_assert(sizeof(MagicHeader)==64, "MagicHeader must be 64 bytes")
static void model_layer_2_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_2_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_9_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_6_prefill(MODELModel *model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
static void model_layer_9_decode(MODELModel *model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
AUTO-GENERATED: qwen2_0.5b_decode Memory Layout.
#define MODEL_NUM_KV_HEADS
#define MODEL_MAX_SEQ_LEN