27 default:
return "UNKNOWN";
70 return (strcmp(spec->
name,
"token_emb") == 0) ||
71 (strcmp(spec->
name,
"wq") == 0) ||
72 (strcmp(spec->
name,
"wk") == 0) ||
73 (strcmp(spec->
name,
"wv") == 0) ||
74 (strcmp(spec->
name,
"wo") == 0) ||
75 (strcmp(spec->
name,
"w1") == 0) ||
76 (strcmp(spec->
name,
"w2") == 0) ||
77 (strcmp(spec->
name,
"lm_head_weight") == 0);
86 if (strcmp(spec->
name,
"token_emb") == 0) {
87 return "m->token_emb_dtype";
89 if (strcmp(spec->
name,
"lm_head_weight") == 0) {
90 return "m->lm_head_weight_dtype";
95 if (strcmp(spec->
name,
"wq") == 0)
return "L->wq_dtype";
96 if (strcmp(spec->
name,
"wk") == 0)
return "L->wk_dtype";
97 if (strcmp(spec->
name,
"wv") == 0)
return "L->wv_dtype";
98 if (strcmp(spec->
name,
"wo") == 0)
return "L->wo_dtype";
99 if (strcmp(spec->
name,
"w1") == 0)
return "L->w1_dtype";
100 if (strcmp(spec->
name,
"w2") == 0)
return "L->w2_dtype";
108 fprintf(out,
" size_t %s_offset;\n", name);
113 fprintf(out,
"typedef struct {\n");
125 " CKDataType wq_dtype;\n"
126 " CKDataType wk_dtype;\n"
127 " CKDataType wv_dtype;\n"
128 " CKDataType wo_dtype;\n"
129 " CKDataType w1_dtype;\n"
130 " CKDataType w2_dtype;\n");
131 fprintf(out,
"} LayerOffsets;\n\n");
151 "typedef LayerOffsets TrulyOptimalLayer;\n\n"
154 " uint32_t version;\n"
155 " uint32_t model_type;\n"
160 " int context_window;\n"
161 " int intermediate_size;\n"
163 " size_t aligned_embed_dim;\n"
164 " size_t aligned_head_dim;\n"
165 " size_t aligned_attn_context_window;\n"
168 " int tokens_per_core;\n"
169 " int num_attention_heads;\n"
170 " int num_kv_heads;\n"
172 " float rms_norm_eps;\n"
173 " float rope_theta;\n"
175 " uint8_t *memory_base;\n"
176 " size_t total_bytes;\n"
177 " size_t elem_bytes;\n"
178 " CKDataType weight_dtype;\n"
179 " CKDataType token_emb_dtype;\n"
180 " CKDataType pos_emb_dtype;\n"
181 " CKDataType lm_head_weight_dtype;\n"
182 " bool weights_mixed;\n"
183 " bool weights_quantized;\n"
184 " size_t layer_stride;\n"
186 " size_t layers_start_offset;\n");
192 " TrulyOptimalLayer *layers;\n"
194 " GradientStorage gradients;\n"
195 " bool training_enabled;\n"
196 " float learning_rate;\n"
197 " int lr_warmup_steps;\n"
198 " float lr_warmup_init;\n"
199 " float grad_clip;\n"
200 " size_t training_cache_samples;\n"
201 " int active_tokens;\n"
202 " TaskType task_type;\n"
203 " OptimizerType optimizer;\n"
204 " uint64_t optimizer_step;\n"
205 " float adam_beta1;\n"
206 " float adam_beta2;\n"
208 " float weight_decay;\n"
209 " bool ema_enabled;\n"
210 " float ema_decay;\n"
211 " bool optimizer_state_initialized;\n"
213 " bool seq_cls_enabled;\n"
214 " int seq_cls_num_classes;\n"
215 " int seq_cls_pooling;\n"
216 " size_t seq_cls_weight_offset;\n"
217 " size_t seq_cls_bias_offset;\n"
219 " bool kv_cache_enabled;\n"
220 " int kv_cache_capacity;\n"
221 " int kv_cache_tokens;\n"
223 " long *training_data_buffer;\n"
224 " long num_training_tokens;\n"
226 " uint8_t checksum[32];\n"
227 " uint8_t reserved[32];\n"
228 "} TransformerModel;\n\n");
234 case CK_DIM_TOKENS: fprintf(out,
"(size_t)m->context_window");
break;
235 case CK_DIM_EMBED: fprintf(out,
"(size_t)m->embed_dim");
break;
239 case CK_DIM_NUM_HEADS: fprintf(out,
"(size_t)m->num_attention_heads");
break;
244 case CK_DIM_VOCAB: fprintf(out,
"(size_t)m->vocab_size");
break;
252 for (
int i = 0; i < 4; ++i) {
261 if (shape[i].mult != 1) {
262 fprintf(out,
" * %d", shape[i].mult);
264 if (shape[i].div != 1) {
265 fprintf(out,
" / %d", shape[i].div);
277 const char *struct_prefix,
281 fprintf(out,
"%s%s%s_offset = bump_bytes(&off, (", indent, struct_prefix, name);
283 fprintf(out,
") * elem_bytes, CACHELINE_BYTES);\n");
288 const char *struct_prefix,
291 const char *dtype_expr)
293 fprintf(out,
"%s%s%s_offset = bump_bytes(&off, ck_dtype_row_bytes(%s, (",
294 indent, struct_prefix, name, dtype_expr);
296 fprintf(out,
")), CACHELINE_BYTES);\n");
301 const char *struct_prefix,
306 fprintf(out,
"%s%s%s_offset = m->training_enabled ? bump_bytes(&off, (", indent, struct_prefix, name);
308 fprintf(out,
") * elem_bytes, CACHELINE_BYTES) : 0;\n");
328 fprintf(out,
" m->%s_offset = m->%s_offset;\n", spec->
name, spec->
alias_of);
333 fprintf(out,
" if (m->rope_theta > 0.0f) {\n");
334 fprintf(out,
" m->%s_offset = bump_bytes(&off, (", spec->
name);
336 fprintf(out,
") * elem_bytes, CACHELINE_BYTES);\n");
337 fprintf(out,
" } else {\n");
338 fprintf(out,
" m->%s_offset = 0;\n", spec->
name);
339 fprintf(out,
" }\n");
343 fprintf(out,
" if (m->training_enabled) {\n");
344 fprintf(out,
" m->%s_offset = bump_bytes(&off, (", spec->
name);
346 fprintf(out,
") * elem_bytes, CACHELINE_BYTES);\n");
347 fprintf(out,
" } else {\n");
348 fprintf(out,
" m->%s_offset = 0;\n", spec->
name);
349 fprintf(out,
" }\n");
378 fprintf(out,
" if (m->training_enabled) {\n");
379 fprintf(out,
" L->%s_offset = bump_bytes(&off, (", spec->
name);
381 fprintf(out,
") * elem_bytes, CACHELINE_BYTES);\n");
382 fprintf(out,
" } else {\n");
383 fprintf(out,
" L->%s_offset = 0;\n", spec->
name);
384 fprintf(out,
" }\n");
410 " if (m->num_layers > 0) {\n"
411 " m->%s_offset = m->layers[m->num_layers - 1].%s_offset;\n"
413 " m->%s_offset = 0;\n"
422 "static void zero_grad(TransformerModel *m)\n"
424 " if (!m || !m->training_enabled) return;\n"
425 " uint8_t *base = m->memory_base;\n"
426 " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n");
433 fprintf(out,
" if (m->%s_offset) {\n", spec->
name);
434 fprintf(out,
" memset(base + m->%s_offset, 0, (", spec->
name);
436 fprintf(out,
") * m->elem_bytes);\n");
437 fprintf(out,
" }\n");
441 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
442 " TrulyOptimalLayer *L = &m->layers[layer];\n");
448 fprintf(out,
" if (L->%s_offset) {\n", spec->
name);
449 fprintf(out,
" memset(base + L->%s_offset, 0, (", spec->
name);
451 fprintf(out,
") * m->elem_bytes);\n");
452 fprintf(out,
" }\n");
462 "static void sgd_update(TransformerModel *m, float lr)\n"
464 " if (!m || !m->training_enabled || lr == 0.0f) return;\n"
465 " uint8_t *base = m->memory_base;\n"
466 " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n");
477 snprintf(grad_name,
sizeof(grad_name),
"d_%s", spec->
name);
483 " if (m->%s_offset && m->%s_offset) {\n"
484 " float *w = ptr_f32(base, m->%s_offset);\n"
485 " float *g = ptr_f32(base, m->%s_offset);\n"
487 spec->
name, grad_name, spec->
name, grad_name);
491 " for (size_t i = 0; i < count; ++i) {\n"
492 " w[i] -= lr * g[i];\n"
498 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
499 " TrulyOptimalLayer *L = &m->layers[layer];\n");
506 snprintf(grad_name,
sizeof(grad_name),
"d_%s", spec->
name);
512 " if (L->%s_offset && L->%s_offset) {\n"
513 " float *w = ptr_f32(base, L->%s_offset);\n"
514 " float *g = ptr_f32(base, L->%s_offset);\n"
516 spec->
name, grad_name, spec->
name, grad_name);
520 " for (size_t i = 0; i < count; ++i) {\n"
521 " w[i] -= lr * g[i];\n"
536 if (!path || !path[0]) {
539 for (
size_t i = 0; i < *seen_count; ++i) {
540 if (strcmp(seen[i], path) == 0) {
544 if (*seen_count >= seen_cap) {
549 seen[*seen_count] = path;
559 if (strcmp(step->
condition,
"rope_theta") == 0) {
562 if (strcmp(step->
condition,
"rope_theta>0") == 0) {
576 for (
size_t i = 0; i < plan_count; ++i) {
586 const char *src = spec->
sources[s];
617 if (!forward || !out) {
622 "/* Auto-generated skeleton from CKIRGraph.\n"
623 " * This file sketches the structure of the forward and backward\n"
624 " * execution for a decoder-only transformer. It is NOT yet a\n"
625 " * complete, runnable implementation. You can use it as a\n"
626 " * starting point to wire buffers, kernel calls, and memory layout.\n"
629 fprintf(out,
"#include \"ckernel_engine.h\"\n");
630 fprintf(out,
"#include \"ckernel_model.h\"\n");
631 fprintf(out,
"#include \"ckernel_alloc.h\"\n\n");
635 "void run_decoder_forward(TransformerModel *model /*, inputs, etc. */)\n"
637 " for (int layer = 0; layer < model->cfg.num_layers; ++layer) {\n"
638 " /* Forward pass for layer */\n");
640 int nodes_per_layer = 0;
643 for (
int i = 0; i < forward->
num_nodes; ++i) {
651 if (nodes_per_layer <= 0) {
655 fprintf(out,
" /* This layer has %d IR nodes */\n", nodes_per_layer);
657 for (
int i = 0; i < nodes_per_layer; ++i) {
659 fprintf(out,
" // L%%d: %s\n",
op_name(n->
op));
663 if (o > 0) fprintf(out,
", ");
664 fprintf(out,
"L%%d:N%d:%d", n->
id.
node, o);
667 fprintf(out,
" // inputs : [");
668 for (
int j = 0; j < n->
n_inputs; ++j) {
670 if (j > 0) fprintf(out,
", ");
674 fprintf(out,
"L%%d:N%u:%u",
681 " // TODO: bind buffers/weights and call %s kernel here\n\n",
686 " } /* end for layer */\n"
692 "void run_decoder_backward(TransformerModel *model /*, grads, etc. */)\n"
694 " for (int layer = model->cfg.num_layers - 1; layer >= 0; --layer) {\n"
695 " /* Backward pass for layer */\n");
697 int bwd_per_layer = 0;
699 for (
int i = 0; i < backward->
num_nodes; ++i) {
703 if (bwd_per_layer <= 0) bwd_per_layer = backward->
num_nodes;
705 fprintf(out,
" /* This layer has %d backward IR nodes */\n", bwd_per_layer);
707 for (
int i = 0; i < bwd_per_layer; ++i) {
709 fprintf(out,
" // L%%d: %s\n",
op_name(n->
op));
711 " // TODO: wire gradient tensors and call %s kernel here\n\n",
716 " } /* end for layer */\n"
721 "int main(int argc, char **argv)\n"
723 " (void)argc; (void)argv;\n"
724 " TransformerModel model = {0};\n"
725 " model.cfg.num_layers = %d;\n"
726 " model.cfg.hidden_size = %d;\n"
727 " model.cfg.intermediate_size = %d;\n"
728 " model.cfg.num_heads = %d;\n"
729 " model.cfg.num_kv_heads = %d;\n"
730 " model.cfg.vocab_size = %d;\n"
731 " model.cfg.context_window = %d;\n"
732 " model.cfg.rms_norm_eps = %.9g;\n"
733 " model.cfg.rope_theta = %.9g;\n"
734 " layout_transformer_from_ir(&model, NULL); /* TODO: pass IR if needed */\n"
735 " size_t bytes = model.total_bytes;\n"
736 " model.memory_base = (uint8_t *)ck_huge_alloc(bytes);\n"
737 " if (!model.memory_base) {\n"
738 " fprintf(stderr, \"Failed to allocate %%zu bytes for model\\n\", bytes);\n"
741 " // TODO: load weights into model.memory_base based on offsets\n"
742 " run_decoder_forward(&model);\n"
743 " // TODO: run_decoder_backward(&model) when training\n"
744 " ck_huge_free(model.memory_base, bytes);\n"
761 "/* Auto-generated runtime from CKIRGraph.\n"
762 " * This file wires the existing C-Kernel-Engine kernels into a\n"
763 " * decoder-only transformer forward pass.\n"
765 " * Compile (scalar): gcc -O2 generated_model.c $(cat generated_model.c.kernels) -Iinclude -lm -o generated_model\n"
766 " * Compile (AVX-512): gcc -O3 -mavx512f -mfma generated_model.c $(cat generated_model.c.kernels) -Iinclude -lm -o generated_model\n"
770 "#define _GNU_SOURCE\n"
771 "#include <stddef.h>\n"
772 "#include <stdint.h>\n"
773 "#include <stdbool.h>\n"
774 "#include <stdio.h>\n"
775 "#include <stdlib.h>\n"
776 "#include <string.h>\n"
777 "#include <math.h>\n"
778 "#include <errno.h>\n"
779 "#include <sys/types.h>\n"
780 "#include <unistd.h>\n"
781 "#include \"ckernel_engine.h\"\n"
782 "#include \"ckernel_dtype.h\"\n"
783 "#include \"ckernel_orchestration.h\"\n"
784 "#include \"ckernel_alloc.h\"\n\n");
787 "#define CACHELINE_BYTES 64\n"
788 "static size_t align_up_bytes(size_t n, size_t align) {\n"
789 " if (align == 0) return n;\n"
790 " return (n + align - 1) & ~(align - 1);\n"
792 "static size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align) {\n"
793 " size_t bytes = elems * elem_bytes;\n"
794 " bytes = align_up_bytes(bytes, align);\n"
795 " return bytes / elem_bytes;\n"
797 "static size_t bump_bytes(size_t *off, size_t bytes, size_t align) {\n"
798 " size_t start = align_up_bytes(*off, align);\n"
799 " *off = start + bytes;\n"
802 "static inline float *ptr_f32(uint8_t *base, size_t offset) {\n"
803 " return (float *)(base + offset);\n"
805 "static inline const float *cptr_f32(const uint8_t *base, size_t offset) {\n"
806 " return (const float *)(base + offset);\n"
810 "static inline uint8_t *ptr_u8(uint8_t *base, size_t offset) {\n"
811 " return base + offset;\n"
813 "static inline const void *cptr_void(const uint8_t *base, size_t offset) {\n"
814 " return (const void *)(base + offset);\n"
822 if (!forward || !runtime_path) {
826 const char *suffix =
".kernels";
827 size_t len = strlen(runtime_path) + strlen(suffix) + 1;
828 char *path = (
char *)malloc(len);
832 snprintf(path, len,
"%s%s", runtime_path, suffix);
834 FILE *f = fopen(path,
"wb");
836 fprintf(stderr,
"ck_codegen_emit_runtime: failed to open %s: %s\n",
837 path, strerror(errno));
843 const char **seen = (
const char **)calloc(seen_cap,
sizeof(
char *));
849 size_t seen_count = 0;
855 emit_unique_source(f,
"src/kernels/embedding_kernels.c", seen, &seen_count, seen_cap);
859 emit_unique_source(f,
"src/kernels/gemm_kernels_q4k_q8k.c", seen, &seen_count, seen_cap);
862 emit_unique_source(f,
"src/kernels/gemm_kernels_q4_0.c", seen, &seen_count, seen_cap);
863 emit_unique_source(f,
"src/kernels/gemm_kernels_q4_1.c", seen, &seen_count, seen_cap);
864 emit_unique_source(f,
"src/kernels/gemm_kernels_q5_0.c", seen, &seen_count, seen_cap);
865 emit_unique_source(f,
"src/kernels/gemm_kernels_q5_1.c", seen, &seen_count, seen_cap);
866 emit_unique_source(f,
"src/kernels/gemm_kernels_q8_0.c", seen, &seen_count, seen_cap);
868 emit_unique_source(f,
"src/kernels/gemm_kernels_q4k_sse.c", seen, &seen_count, seen_cap);
869 emit_unique_source(f,
"src/kernels/gemm_kernels_q4k_avx.c", seen, &seen_count, seen_cap);
870 emit_unique_source(f,
"src/kernels/gemm_kernels_q4k_q8k_avx2.c", seen, &seen_count, seen_cap);
871 emit_unique_source(f,
"src/kernels/gemm_kernels_q4k_q8k_vnni.c", seen, &seen_count, seen_cap);
872 emit_unique_source(f,
"src/kernels/gemm_kernels_q5_0_sse.c", seen, &seen_count, seen_cap);
873 emit_unique_source(f,
"src/kernels/gemm_kernels_q5_0_sse_v2.c", seen, &seen_count, seen_cap);
874 emit_unique_source(f,
"src/kernels/gemm_kernels_q6k_sse.c", seen, &seen_count, seen_cap);
875 emit_unique_source(f,
"src/kernels/quantize_row_q8_k_sse.c", seen, &seen_count, seen_cap);
880 emit_unique_source(f,
"src/kernels/gemm_fused_kernels.c", seen, &seen_count, seen_cap);
883 emit_unique_source(f,
"src/kernels/attention_decode_fused.c", seen, &seen_count, seen_cap);
884 emit_unique_source(f,
"src/kernels/attention_flash_true.c", seen, &seen_count, seen_cap);
912 fprintf(stderr,
"[ck_codegen] kernels manifest written to %s\n", path);
921 "\n/* ═══════════════════════════════════════════════════════════════\n"
922 " * C-Kernel-Engine Library API (for dlopen)\n"
923 " * ═══════════════════════════════════════════════════════════════ */\n\n"
925 "#define CK_EXPORT __declspec(dllexport)\n"
927 "#define CK_EXPORT __attribute__((visibility(\"default\")))\n"
931 " int hidden_size;\n"
932 " int intermediate_size;\n"
934 " int num_kv_heads;\n"
936 " int context_window;\n"
937 " float rms_norm_eps;\n"
938 " float rope_theta;\n"
940 "static TransformerModel g_model = {0};\n"
941 "static int g_initialized = 0;\n"
942 "static int g_fuse_swiglu_decode = -2;\n"
943 "static int g_fuse_attn_decode = -2;\n\n"
944 "static int ck_fuse_swiglu_decode_mode(void)\n"
946 " if (g_fuse_swiglu_decode != -2) return g_fuse_swiglu_decode;\n"
947 " const char *env = getenv(\"CK_FUSE_SWIGLU_DECODE\");\n"
948 " if (!env || !env[0]) {\n"
949 " g_fuse_swiglu_decode = -1; /* auto */\n"
950 " return g_fuse_swiglu_decode;\n"
952 " if (env[0] == '0' || env[0] == 'n' || env[0] == 'N' || env[0] == 'f' || env[0] == 'F') {\n"
953 " g_fuse_swiglu_decode = 0;\n"
955 " g_fuse_swiglu_decode = 1;\n"
957 " return g_fuse_swiglu_decode;\n"
961 "static int ck_fuse_attn_decode_mode(void)\n"
963 " if (g_fuse_attn_decode != -2) return g_fuse_attn_decode;\n"
964 " const char *env = getenv(\"CK_FUSE_ATTN_DECODE\");\n"
965 " if (!env || !env[0]) {\n"
966 " g_fuse_attn_decode = -1; /* auto */\n"
967 " return g_fuse_attn_decode;\n"
969 " if (env[0] == '0' || env[0] == 'n' || env[0] == 'N' || env[0] == 'f' || env[0] == 'F') {\n"
970 " g_fuse_attn_decode = 0;\n"
972 " g_fuse_attn_decode = 1;\n"
974 " return g_fuse_attn_decode;\n"
978 "static int run_model_decode(TransformerModel *m, int32_t token)\n"
980 " if (!m || !m->memory_base) return -1;\n"
981 " /* KV-cache decode is an inference-only fast path; training uses the full forward/backward graph. */\n"
982 " if (m->training_enabled) return -4;\n"
983 " if (!m->kv_cache_enabled) return -2;\n"
985 " int cache_cap = m->kv_cache_capacity > 0 ? m->kv_cache_capacity : m->context_window;\n"
986 " if (cache_cap > m->context_window) cache_cap = m->context_window;\n"
987 " int t = m->kv_cache_tokens;\n"
988 " if (t < 0) t = 0;\n"
989 " if (t >= cache_cap) return -3;\n"
991 " embed_token_at(m, token, t);\n"
993 " uint8_t *base = m->memory_base;\n"
994 " float *current = ptr_f32(base, m->embedded_input_offset);\n"
995 " int aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
996 " int fuse_swiglu_mode = ck_fuse_swiglu_decode_mode();\n"
997 " int use_fused_swiglu = 0;\n"
998 " if (fuse_swiglu_mode > 0) {\n"
999 " use_fused_swiglu = 1;\n"
1000 " } else if (fuse_swiglu_mode == 0) {\n"
1001 " use_fused_swiglu = 0;\n"
1002 " } else if (!ck_strict_parity_enabled()) {\n"
1003 " use_fused_swiglu = 1;\n"
1005 " int fuse_attn_mode = ck_fuse_attn_decode_mode();\n"
1006 " int use_fused_attn = 0;\n"
1007 " if (fuse_attn_mode > 0) {\n"
1008 " use_fused_attn = 1;\n"
1009 " } else if (fuse_attn_mode == 0) {\n"
1010 " use_fused_attn = 0;\n"
1011 " } else if (!ck_strict_parity_enabled()) {\n"
1012 " use_fused_attn = 1;\n"
1014 " if (m->weights_quantized) {\n"
1015 " use_fused_swiglu = 0;\n"
1016 " use_fused_attn = 0;\n"
1019 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1020 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1021 " if (!m->weights_mixed && m->weight_dtype == CK_DT_Q4_K) {\n"
1022 " CKLayerForwardParamsQ4K p = {0};\n"
1023 " p.tokens = cache_cap;\n"
1024 " p.embed_dim = m->embed_dim;\n"
1025 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1026 " p.num_heads = m->num_attention_heads;\n"
1027 " p.num_kv_heads = m->num_kv_heads;\n"
1028 " p.head_dim = m->head_dim;\n"
1029 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1030 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1031 " p.intermediate_dim = m->intermediate_size;\n"
1032 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1033 " p.eps = m->rms_norm_eps;\n"
1034 " p.rope_pos_offset = t;\n"
1035 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1036 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1037 " p.input = current;\n"
1038 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1039 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1040 " p.wq = cptr_void(base, L->wq_offset);\n"
1041 " p.bq = cptr_f32(base, L->bq_offset);\n"
1042 " p.wk = cptr_void(base, L->wk_offset);\n"
1043 " p.bk = cptr_f32(base, L->bk_offset);\n"
1044 " p.wv = cptr_void(base, L->wv_offset);\n"
1045 " p.bv = cptr_f32(base, L->bv_offset);\n"
1046 " p.wo = cptr_void(base, L->wo_offset);\n"
1047 " p.bo = cptr_f32(base, L->bo_offset);\n"
1048 " p.w1 = cptr_void(base, L->w1_offset);\n"
1049 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1050 " p.w2 = cptr_void(base, L->w2_offset);\n"
1051 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1052 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1053 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1054 " p.k = ptr_f32(base, L->k_offset);\n"
1055 " p.v = ptr_f32(base, L->v_offset);\n"
1056 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1057 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1058 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1059 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1060 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1061 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1062 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1063 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1064 " p.output = ptr_f32(base, L->output_offset);\n"
1066 " ck_layer_forward_rmsnorm_swiglu_decode_q4_k(&p, t, cache_cap);\n"
1067 " } else if (m->weights_quantized) {\n"
1068 " CKLayerForwardParamsQ4K p = {0};\n"
1069 " p.tokens = cache_cap;\n"
1070 " p.embed_dim = m->embed_dim;\n"
1071 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1072 " p.num_heads = m->num_attention_heads;\n"
1073 " p.num_kv_heads = m->num_kv_heads;\n"
1074 " p.head_dim = m->head_dim;\n"
1075 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1076 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1077 " p.intermediate_dim = m->intermediate_size;\n"
1078 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1079 " p.eps = m->rms_norm_eps;\n"
1080 " p.rope_pos_offset = t;\n"
1081 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1082 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1083 " p.input = current;\n"
1084 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1085 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1086 " p.wq = cptr_void(base, L->wq_offset);\n"
1087 " p.bq = cptr_f32(base, L->bq_offset);\n"
1088 " p.wk = cptr_void(base, L->wk_offset);\n"
1089 " p.bk = cptr_f32(base, L->bk_offset);\n"
1090 " p.wv = cptr_void(base, L->wv_offset);\n"
1091 " p.bv = cptr_f32(base, L->bv_offset);\n"
1092 " p.wo = cptr_void(base, L->wo_offset);\n"
1093 " p.bo = cptr_f32(base, L->bo_offset);\n"
1094 " p.w1 = cptr_void(base, L->w1_offset);\n"
1095 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1096 " p.w2 = cptr_void(base, L->w2_offset);\n"
1097 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1098 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1099 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1100 " p.k = ptr_f32(base, L->k_offset);\n"
1101 " p.v = ptr_f32(base, L->v_offset);\n"
1102 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1103 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1104 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1105 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1106 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1107 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1108 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1109 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1110 " p.output = ptr_f32(base, L->output_offset);\n"
1111 " p.wq_dtype = L->wq_dtype;\n"
1112 " p.wk_dtype = L->wk_dtype;\n"
1113 " p.wv_dtype = L->wv_dtype;\n"
1114 " p.wo_dtype = L->wo_dtype;\n"
1115 " p.w1_dtype = L->w1_dtype;\n"
1116 " p.w2_dtype = L->w2_dtype;\n"
1118 " ck_layer_forward_rmsnorm_swiglu_decode_quant(&p, t, cache_cap);\n"
1120 " CKLayerForwardParams p = {0};\n"
1121 " p.tokens = cache_cap;\n"
1122 " p.embed_dim = m->embed_dim;\n"
1123 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1124 " p.num_heads = m->num_attention_heads;\n"
1125 " p.num_kv_heads = m->num_kv_heads;\n"
1126 " p.head_dim = m->head_dim;\n"
1127 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1128 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1129 " p.intermediate_dim = m->intermediate_size;\n"
1130 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1131 " p.eps = m->rms_norm_eps;\n"
1132 " p.rope_pos_offset = t;\n"
1133 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1134 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1135 " p.input = current;\n"
1136 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1137 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1138 " p.wq = cptr_f32(base, L->wq_offset);\n"
1139 " p.bq = cptr_f32(base, L->bq_offset);\n"
1140 " p.wk = cptr_f32(base, L->wk_offset);\n"
1141 " p.bk = cptr_f32(base, L->bk_offset);\n"
1142 " p.wv = cptr_f32(base, L->wv_offset);\n"
1143 " p.bv = cptr_f32(base, L->bv_offset);\n"
1144 " p.wo = cptr_f32(base, L->wo_offset);\n"
1145 " p.bo = cptr_f32(base, L->bo_offset);\n"
1146 " p.w1 = cptr_f32(base, L->w1_offset);\n"
1147 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1148 " p.w2 = cptr_f32(base, L->w2_offset);\n"
1149 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1150 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1151 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1152 " p.k = ptr_f32(base, L->k_offset);\n"
1153 " p.v = ptr_f32(base, L->v_offset);\n"
1154 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1155 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1156 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1157 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1158 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1159 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1160 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1161 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1162 " p.output = ptr_f32(base, L->output_offset);\n"
1164 " if (use_fused_attn) {\n"
1165 " if (use_fused_swiglu) {\n"
1166 " ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp(&p, t, cache_cap);\n"
1168 " ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(&p, t, cache_cap);\n"
1170 " } else if (use_fused_swiglu) {\n"
1171 " ck_layer_forward_rmsnorm_swiglu_decode_fused(&p, t, cache_cap);\n"
1173 " ck_layer_forward_rmsnorm_swiglu_decode(&p, t, cache_cap);\n"
1176 " current = ptr_f32(base, L->output_offset);\n"
1179 " int V = m->vocab_size;\n"
1180 " int D = m->embed_dim;\n"
1181 " int aligned_D = (int)m->aligned_embed_dim;\n"
1182 " float *final_in = current + (size_t)t * aligned_D;\n"
1183 " float *final_out = ptr_f32(base, m->final_output_offset) + (size_t)t * aligned_D;\n"
1184 " float *final_rstd = ptr_f32(base, m->final_ln_rstd_offset) + (size_t)t;\n"
1186 " rmsnorm_forward(final_in,\n"
1187 " cptr_f32(base, m->final_ln_weight_offset),\n"
1193 " m->rms_norm_eps);\n"
1195 " float *logits_row = ptr_f32(base, m->logits_offset) + (size_t)t * (size_t)V;\n"
1196 " if (m->lm_head_weight_dtype == CK_DT_Q4_K) {\n"
1197 " gemm_nt_q4_k(final_out,\n"
1198 " cptr_void(base, m->lm_head_weight_offset),\n"
1204 " } else if (m->lm_head_weight_dtype == CK_DT_Q6_K) {\n"
1205 " gemm_nt_q6_k(final_out,\n"
1206 " cptr_void(base, m->lm_head_weight_offset),\n"
1213 " lm_head_forward(final_out,\n"
1214 " cptr_f32(base, m->lm_head_weight_offset),\n"
1223 " m->kv_cache_tokens = t + 1;\n"
1224 " m->active_tokens = m->kv_cache_tokens;\n"
1230 "CK_EXPORT int ck_model_init(const char *weights_path)\n"
1232 " if (g_initialized) return 0;\n"
1233 " memcpy(g_model.magic, \"BUMPWGT3\", 8);\n"
1234 " g_model.version = 3;\n"
1235 " g_model.model_type = 0;\n"
1236 " g_model.num_layers = %d;\n"
1237 " g_model.embed_dim = %d;\n"
1238 " g_model.intermediate_size = %d;\n"
1239 " g_model.num_attention_heads = %d;\n"
1240 " g_model.num_kv_heads = %d;\n"
1241 " g_model.vocab_size = %d;\n"
1242 " g_model.context_window = %d;\n"
1243 " g_model.rms_norm_eps = (float)%.9g;\n"
1244 " g_model.rope_theta = (float)%.9g;\n"
1245 " g_model.num_cores = 1;\n"
1246 " g_model.task_type = TASK_LM;\n"
1247 " g_model.weight_dtype = CK_DT_FP32;\n"
1248 " const char *wd = getenv(\"CK_WEIGHT_DTYPE\");\n"
1250 " if (strcmp(wd, \"q4_k\") == 0 || strcmp(wd, \"q4_k_m\") == 0 ||\n"
1251 " strcmp(wd, \"Q4_K\") == 0 || strcmp(wd, \"Q4_K_M\") == 0) {\n"
1252 " g_model.weight_dtype = CK_DT_Q4_K;\n"
1253 " } else if (strcmp(wd, \"q6_k\") == 0 || strcmp(wd, \"q6_k_l\") == 0 ||\n"
1254 " strcmp(wd, \"Q6_K\") == 0 || strcmp(wd, \"Q6_K_L\") == 0) {\n"
1255 " g_model.weight_dtype = CK_DT_Q6_K;\n"
1258 " init_weight_dtypes_uniform(&g_model, g_model.weight_dtype);\n"
1259 " refresh_weight_flags(&g_model);\n"
1260 " /* Check env var to pre-allocate gradient buffers for training */\n"
1261 " const char *train_env = getenv(\"CK_ENABLE_TRAINING\");\n"
1262 " if (train_env && (train_env[0] == '1' || train_env[0] == 'y' || train_env[0] == 'Y')) {\n"
1263 " g_model.training_enabled = true;\n"
1264 " g_model.learning_rate = 1e-4f;\n"
1266 " if (weights_path) {\n"
1267 " int dtype_rc = load_weight_dtypes(weights_path, &g_model);\n"
1268 " if (dtype_rc < 0) {\n"
1269 " fprintf(stderr, \"Failed to read weight dtype table from %%s\\n\", weights_path);\n"
1273 " if (g_model.training_enabled && g_model.weights_quantized) {\n"
1274 " fprintf(stderr, \"Quantized weights are inference-only; disable training for this model\\n\");\n"
1277 " g_model.kv_cache_enabled = false;\n"
1278 " g_model.kv_cache_capacity = g_model.context_window;\n"
1279 " g_model.kv_cache_tokens = 0;\n"
1280 " if (layout_model(&g_model) != 0) return -1;\n"
1281 " if (weights_path) {\n"
1282 " if (load_model_weights(weights_path, &g_model) != 0) return -2;\n"
1284 " g_initialized = 1;\n"
1299 "CK_EXPORT void ck_model_get_info(CKModelInfo *info)\n"
1301 " if (!info) return;\n"
1302 " info->num_layers = g_model.num_layers;\n"
1303 " info->hidden_size = g_model.embed_dim;\n"
1304 " info->intermediate_size = g_model.intermediate_size;\n"
1305 " info->num_heads = g_model.num_attention_heads;\n"
1306 " info->num_kv_heads = g_model.num_kv_heads;\n"
1307 " info->vocab_size = g_model.vocab_size;\n"
1308 " info->context_window = g_model.context_window;\n"
1309 " info->rms_norm_eps = g_model.rms_norm_eps;\n"
1310 " info->rope_theta = g_model.rope_theta;\n"
1315 "CK_EXPORT int ck_model_embed_tokens(const int32_t *tokens, int num_tokens)\n"
1317 " if (!g_initialized) return -1;\n"
1318 " int cap = g_model.context_window;\n"
1319 " if (g_model.kv_cache_enabled && g_model.kv_cache_capacity > 0 && g_model.kv_cache_capacity < cap) {\n"
1320 " cap = g_model.kv_cache_capacity;\n"
1322 " if (num_tokens > cap) num_tokens = cap;\n"
1323 " if (num_tokens < 1) num_tokens = 1;\n"
1324 " g_model.active_tokens = num_tokens;\n"
1325 " if (g_model.kv_cache_enabled && !g_model.training_enabled) {\n"
1326 " g_model.kv_cache_tokens = 0;\n"
1328 " embed_tokens(&g_model, tokens, num_tokens);\n"
1334 "CK_EXPORT int ck_model_forward(float *logits_out)\n"
1336 " if (!g_initialized) return -1;\n"
1337 " run_model_forward(&g_model);\n"
1338 " if (g_model.kv_cache_enabled && !g_model.training_enabled) {\n"
1339 " g_model.kv_cache_tokens = g_model.active_tokens;\n"
1341 " if (logits_out && g_model.vocab_size > 0) {\n"
1342 " size_t n = (size_t)g_model.active_tokens * (size_t)g_model.vocab_size;\n"
1343 " memcpy(logits_out, ptr_f32(g_model.memory_base, g_model.logits_offset), n * sizeof(float));\n"
1350 "CK_EXPORT int ck_model_kv_cache_enable(int capacity)\n"
1352 " if (!g_initialized) return -1;\n"
1353 " if (g_model.training_enabled) return -4;\n"
1354 " g_model.kv_cache_enabled = true;\n"
1355 " int cap = capacity;\n"
1356 " if (cap <= 0 || cap > g_model.context_window) cap = g_model.context_window;\n"
1357 " g_model.kv_cache_capacity = cap;\n"
1358 " g_model.kv_cache_tokens = 0;\n"
1359 " g_model.active_tokens = 0;\n"
1362 "CK_EXPORT void ck_model_kv_cache_reset(void)\n"
1364 " if (!g_initialized) return;\n"
1365 " g_model.kv_cache_tokens = 0;\n"
1366 " g_model.active_tokens = 0;\n"
1368 "CK_EXPORT int ck_model_kv_cache_get_tokens(void)\n"
1370 " return g_initialized ? g_model.kv_cache_tokens : 0;\n"
1372 "CK_EXPORT int ck_model_decode(int32_t token, float *logits_out)\n"
1374 " if (!g_initialized) return -1;\n"
1375 " if (g_model.training_enabled) return -4;\n"
1376 " int ret = run_model_decode(&g_model, token);\n"
1377 " if (ret != 0) return ret;\n"
1378 " if (logits_out && g_model.vocab_size > 0) {\n"
1379 " int t = g_model.active_tokens - 1;\n"
1380 " memcpy(logits_out,\n"
1381 " ptr_f32(g_model.memory_base, g_model.logits_offset) + (size_t)t * (size_t)g_model.vocab_size,\n"
1382 " (size_t)g_model.vocab_size * sizeof(float));\n"
1389 "CK_EXPORT float* ck_model_get_logits(void)\n"
1391 " if (!g_initialized) return NULL;\n"
1392 " return ptr_f32(g_model.memory_base, g_model.logits_offset);\n"
/* Emits ck_model_backward: returns -1 when uninitialized, otherwise returns
 * whatever run_model_backward returns for the global model. */
1397 "CK_EXPORT int ck_model_backward(const int32_t *tokens, const int32_t *targets, float *loss_out)\n"
1399 " if (!g_initialized) return -1;\n"
1400 " return run_model_backward(&g_model, tokens, targets, loss_out);\n"
/* Emits ck_model_free: a no-op when uninitialized. Otherwise releases the
 * model memory block through ck_huge_free (passing total_bytes), frees the
 * layers array, zeroes the whole g_model struct and clears g_initialized. */
1405 "CK_EXPORT void ck_model_free(void)\n"
1407 " if (!g_initialized) return;\n"
1408 " if (g_model.memory_base) ck_huge_free(g_model.memory_base, g_model.total_bytes);\n"
1409 " if (g_model.layers) free(g_model.layers);\n"
1410 " memset(&g_model, 0, sizeof(g_model));\n"
1411 " g_initialized = 0;\n"
/* Emits the one-line accessors of the exported API. Each getter returns the
 * matching g_model field, or 0 (0.0f for the learning rate) when the runtime
 * is not initialized; the learning-rate setter is ignored when uninitialized.
 * Note that ck_model_get_hidden_size reports embed_dim. */
1416 "CK_EXPORT int ck_model_get_context_window(void) { return g_initialized ? g_model.context_window : 0; }\n"
1417 "CK_EXPORT int ck_model_get_vocab_size(void) { return g_initialized ? g_model.vocab_size : 0; }\n"
1418 "CK_EXPORT int ck_model_get_hidden_size(void) { return g_initialized ? g_model.embed_dim : 0; }\n"
1419 "CK_EXPORT int ck_model_get_active_tokens(void) { return g_initialized ? g_model.active_tokens : 0; }\n"
1420 "CK_EXPORT int ck_model_is_training_enabled(void) { return g_initialized ? g_model.training_enabled : 0; }\n"
1421 "CK_EXPORT void ck_model_set_learning_rate(float lr) { if (g_initialized) g_model.learning_rate = lr; }\n"
1422 "CK_EXPORT float ck_model_get_learning_rate(void) { return g_initialized ? g_model.learning_rate : 0.0f; }\n\n"
/* Emits ck_model_enable_training: returns -1 when uninitialized; otherwise
 * sets training_enabled and stores the supplied learning rate. The success
 * return value is not visible in this listing. */
1423 "CK_EXPORT int ck_model_enable_training(float learning_rate)\n"
1425 " if (!g_initialized) return -1;\n"
1426 " g_model.training_enabled = true;\n"
1427 " g_model.learning_rate = learning_rate;\n"
/* Emits ck_model_disable_training: clears training_enabled when the runtime
 * is initialized; does nothing otherwise. */
1430 "CK_EXPORT void ck_model_disable_training(void)\n"
1432 " if (g_initialized) g_model.training_enabled = false;\n"
/* Emits ck_model_optimizer_step: applies sgd_update with the stored learning
 * rate, but only when the runtime is initialized and training is enabled. */
1434 "CK_EXPORT void ck_model_optimizer_step(void)\n"
1436 " if (!g_initialized || !g_model.training_enabled) return;\n"
1437 " sgd_update(&g_model, g_model.learning_rate);\n"
1443 if (!forward || !path) {
1450 FILE *out = fopen(path,
"wb");
1452 fprintf(stderr,
"ck_codegen_emit_runtime: failed to open %s: %s\n",
1453 path, strerror(errno));
/* Emitted type declarations. Visible here: an enumerator TASK_SEQ_CLS = 1
 * (its enclosing enum is not shown in this listing), the OptimizerType
 * enumerators OPTIMIZER_SGD = 0 and OPTIMIZER_ADAM = 1, and the
 * GradientStorage struct holding a single total_gradient_floats count. */
1465 " TASK_SEQ_CLS = 1\n"
1468 " OPTIMIZER_SGD = 0,\n"
1469 " OPTIMIZER_ADAM = 1\n"
1470 "} OptimizerType;\n\n"
1471 "typedef struct {\n"
1472 " size_t total_gradient_floats;\n"
1473 "} GradientStorage;\n\n");
/* Emits ensure_layers_allocated: returns -1 for a NULL model. When the layers
 * array is still NULL and num_layers > 0, it is allocated zero-filled with
 * calloc (num_layers entries of TrulyOptimalLayer); allocation failure
 * returns -1. An existing array is left alone. */
1479 "static int ensure_layers_allocated(TransformerModel *m)\n"
1481 " if (!m) return -1;\n"
1482 " if (!m->layers && m->num_layers > 0) {\n"
1483 " m->layers = (TrulyOptimalLayer *)calloc((size_t)m->num_layers, sizeof(TrulyOptimalLayer));\n"
1484 " if (!m->layers) return -1;\n"
/* Emits init_weight_dtypes_uniform: assigns dt to the token embedding, the
 * lm head and the six per-layer matrices (wq, wk, wv, wo, w1, w2) of every
 * layer. The position embedding dtype is always set to CK_DT_FP32. Returns
 * silently for a NULL model or when the layers array cannot be allocated
 * (the model-level dtypes have already been written at that point). */
1488 "static void init_weight_dtypes_uniform(TransformerModel *m, CKDataType dt)\n"
1490 " if (!m) return;\n"
1491 " m->token_emb_dtype = dt;\n"
1492 " m->lm_head_weight_dtype = dt;\n"
1493 " m->pos_emb_dtype = CK_DT_FP32;\n"
1494 " if (ensure_layers_allocated(m) != 0) return;\n"
1495 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1496 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1497 " L->wq_dtype = dt;\n"
1498 " L->wk_dtype = dt;\n"
1499 " L->wv_dtype = dt;\n"
1500 " L->wo_dtype = dt;\n"
1501 " L->w1_dtype = dt;\n"
1502 " L->w2_dtype = dt;\n"
/* Emits refresh_weight_flags: recomputes the summary flags from the
 * per-tensor dtypes. The token embedding dtype is the baseline; a differing
 * lm head dtype or per-layer matrix dtype is tested for the mixed flag, and
 * ck_dtype_is_quantized on any of those dtypes is tested for the quantized
 * flag. The results are stored in weights_mixed / weights_quantized and the
 * baseline is stored in weight_dtype. Some statements of the emitted body
 * (the declaration of mixed, the branch bodies) are not visible here. */
1505 "static void refresh_weight_flags(TransformerModel *m)\n"
1507 " if (!m) return;\n"
1508 " CKDataType base = m->token_emb_dtype;\n"
1510 " int quant = ck_dtype_is_quantized(base);\n"
1511 " if (m->lm_head_weight_dtype != base) mixed = 1;\n"
1512 " if (ck_dtype_is_quantized(m->lm_head_weight_dtype)) quant = 1;\n"
1513 " if (m->layers) {\n"
1514 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1515 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1516 " if (L->wq_dtype != base || L->wk_dtype != base || L->wv_dtype != base ||\n"
1517 " L->wo_dtype != base || L->w1_dtype != base || L->w2_dtype != base) {\n"
1520 " if (ck_dtype_is_quantized(L->wq_dtype) || ck_dtype_is_quantized(L->wk_dtype) ||\n"
1521 " ck_dtype_is_quantized(L->wv_dtype) || ck_dtype_is_quantized(L->wo_dtype) ||\n"
1522 " ck_dtype_is_quantized(L->w1_dtype) || ck_dtype_is_quantized(L->w2_dtype)) {\n"
1527 " m->weights_mixed = mixed ? true : false;\n"
1528 " m->weights_quantized = quant ? true : false;\n"
1530 " m->weight_dtype = base;\n"
/* Emits load_weight_dtypes: reads the per-tensor dtype table from a weights
 * file. The file is opened in binary mode, the first 8 bytes are compared
 * with the BUMPWGT3 magic, a uint32 version is read and tested against 3,
 * and the table is read from byte offset 128 as a uint32 length followed by
 * that many one-byte dtype codes. The length is compared with
 * num_layers * 14 + 4.
 * Table order as consumed below: token embedding, position embedding, then
 * per layer ln1, ln2, wq, bq, wk, bk, wv, bv, wo, bo, w1, b1, w2, b2, and
 * finally the final norm and final bias entries.
 * Dtype checks: the position embedding, norms and biases are tested against
 * CK_DT_FP32; the token embedding and the six per-layer matrices are tested
 * against FP32, Q4_K and Q6_K. The lm head dtype is copied from the token
 * embedding dtype. refresh_weight_flags is called at the end.
 * The bodies of the failure branches are mostly not visible in this listing.
 * NOTE(review): dtype_buf is allocated and filled from the file-supplied
 * length before that length is compared with the expected size -- consider
 * validating the length first. */
1533 "static int load_weight_dtypes(const char *path, TransformerModel *m)\n"
1535 " if (!path || !m) return -1;\n"
1536 " FILE *f = fopen(path, \"rb\");\n"
1537 " if (!f) return -1;\n"
1539 " if (fread(magic, 1, 8, f) != 8) {\n"
1543 " if (memcmp(magic, \"BUMPWGT3\", 8) != 0) {\n"
1547 " uint32_t version = 0;\n"
1548 " if (fread(&version, sizeof(uint32_t), 1, f) != 1) {\n"
1552 " if (version < 3) {\n"
1556 " if (fseek(f, 128, SEEK_SET) != 0) {\n"
1560 " uint32_t dtype_len = 0;\n"
1561 " if (fread(&dtype_len, sizeof(uint32_t), 1, f) != 1) {\n"
1565 " if (dtype_len == 0) {\n"
1569 " uint8_t *dtype_buf = (uint8_t *)malloc(dtype_len);\n"
1570 " if (!dtype_buf) {\n"
1574 " if (fread(dtype_buf, 1, dtype_len, f) != dtype_len) {\n"
1575 " free(dtype_buf);\n"
/* 14 entries per layer plus 4 model-level entries (token, pos, final norm,
 * final bias), matching the reads below. */
1581 " size_t expected = (size_t)m->num_layers * 14u + 4u;\n"
1582 " if (dtype_len != expected) {\n"
1583 " free(dtype_buf);\n"
1586 " if (ensure_layers_allocated(m) != 0) {\n"
1587 " free(dtype_buf);\n"
1591 " size_t idx = 0;\n"
1592 " CKDataType token_dt = (CKDataType)dtype_buf[idx++];\n"
1593 " CKDataType pos_dt = (CKDataType)dtype_buf[idx++];\n"
1594 " if (pos_dt != CK_DT_FP32) {\n"
1595 " free(dtype_buf);\n"
1598 " if (token_dt != CK_DT_FP32 && token_dt != CK_DT_Q4_K && token_dt != CK_DT_Q6_K) {\n"
1599 " free(dtype_buf);\n"
1602 " m->token_emb_dtype = token_dt;\n"
1603 " m->lm_head_weight_dtype = token_dt;\n"
1604 " m->pos_emb_dtype = pos_dt;\n"
1606 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1607 " CKDataType ln1_dt = (CKDataType)dtype_buf[idx++];\n"
1608 " CKDataType ln2_dt = (CKDataType)dtype_buf[idx++];\n"
1609 " CKDataType wq_dt = (CKDataType)dtype_buf[idx++];\n"
1610 " CKDataType bq_dt = (CKDataType)dtype_buf[idx++];\n"
1611 " CKDataType wk_dt = (CKDataType)dtype_buf[idx++];\n"
1612 " CKDataType bk_dt = (CKDataType)dtype_buf[idx++];\n"
1613 " CKDataType wv_dt = (CKDataType)dtype_buf[idx++];\n"
1614 " CKDataType bv_dt = (CKDataType)dtype_buf[idx++];\n"
1615 " CKDataType wo_dt = (CKDataType)dtype_buf[idx++];\n"
1616 " CKDataType bo_dt = (CKDataType)dtype_buf[idx++];\n"
1617 " CKDataType w1_dt = (CKDataType)dtype_buf[idx++];\n"
1618 " CKDataType b1_dt = (CKDataType)dtype_buf[idx++];\n"
1619 " CKDataType w2_dt = (CKDataType)dtype_buf[idx++];\n"
1620 " CKDataType b2_dt = (CKDataType)dtype_buf[idx++];\n"
/* Norm gains and biases are tested against FP32. */
1622 " if (ln1_dt != CK_DT_FP32 || ln2_dt != CK_DT_FP32 ||\n"
1623 " bq_dt != CK_DT_FP32 || bk_dt != CK_DT_FP32 ||\n"
1624 " bv_dt != CK_DT_FP32 || bo_dt != CK_DT_FP32 ||\n"
1625 " b1_dt != CK_DT_FP32 || b2_dt != CK_DT_FP32) {\n"
1626 " free(dtype_buf);\n"
/* Weight matrices are tested against FP32, Q4_K and Q6_K. */
1629 " if ((wq_dt != CK_DT_FP32 && wq_dt != CK_DT_Q4_K && wq_dt != CK_DT_Q6_K) ||\n"
1630 " (wk_dt != CK_DT_FP32 && wk_dt != CK_DT_Q4_K && wk_dt != CK_DT_Q6_K) ||\n"
1631 " (wv_dt != CK_DT_FP32 && wv_dt != CK_DT_Q4_K && wv_dt != CK_DT_Q6_K) ||\n"
1632 " (wo_dt != CK_DT_FP32 && wo_dt != CK_DT_Q4_K && wo_dt != CK_DT_Q6_K) ||\n"
1633 " (w1_dt != CK_DT_FP32 && w1_dt != CK_DT_Q4_K && w1_dt != CK_DT_Q6_K) ||\n"
1634 " (w2_dt != CK_DT_FP32 && w2_dt != CK_DT_Q4_K && w2_dt != CK_DT_Q6_K)) {\n"
1635 " free(dtype_buf);\n"
1639 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1640 " L->wq_dtype = wq_dt;\n"
1641 " L->wk_dtype = wk_dt;\n"
1642 " L->wv_dtype = wv_dt;\n"
1643 " L->wo_dtype = wo_dt;\n"
1644 " L->w1_dtype = w1_dt;\n"
1645 " L->w2_dtype = w2_dt;\n"
1648 " CKDataType final_norm_dt = (CKDataType)dtype_buf[idx++];\n"
1649 " CKDataType final_bias_dt = (CKDataType)dtype_buf[idx++];\n"
1650 " free(dtype_buf);\n"
1651 " if (final_norm_dt != CK_DT_FP32 || final_bias_dt != CK_DT_FP32) {\n"
1655 " refresh_weight_flags(m);\n"
/* Emits layout_model: validates and normalizes the model configuration,
 * derives the aligned dimensions, assigns buffer offsets, allocates the
 * single model memory block and precomputes the RoPE cache.
 * Visible rules: heads and embed_dim must be positive; num_kv_heads defaults
 * to num_attention_heads and must divide it; context_window and vocab_size
 * default to 1; intermediate_size must be positive; rms_norm_eps defaults to
 * 1e-5; a negative rope_theta is clamped to 0 and RoPE requires an even
 * head_dim; elem_bytes defaults to sizeof(float). Returns -1 on any failed
 * check or allocation.
 * The percent signs are doubled because this text is a printf format string.
 * The offset-assignment statements themselves are produced elsewhere by the
 * generator and are not visible in this listing.
 * NOTE(review): head_dim is embed_dim / num_attention_heads with no visible
 * check that the division is exact -- confirm whether that is validated. */
1659 "static int layout_model(TransformerModel *m)\n"
1661 " if (!m) return -1;\n"
1662 " if (m->num_attention_heads <= 0 || m->embed_dim <= 0) return -1;\n"
1663 " if (m->num_kv_heads <= 0) m->num_kv_heads = m->num_attention_heads;\n"
1664 " if (m->num_attention_heads %% m->num_kv_heads != 0) return -1;\n"
1665 " if (m->context_window <= 0) m->context_window = 1;\n"
1666 " if (m->vocab_size <= 0) m->vocab_size = 1;\n"
1667 " if (m->intermediate_size <= 0) return -1;\n"
1668 " m->head_dim = m->embed_dim / m->num_attention_heads;\n"
1669 " if (m->rms_norm_eps <= 0.0f) m->rms_norm_eps = 1e-5f;\n"
1670 " if (m->rope_theta < 0.0f) m->rope_theta = 0.0f;\n"
1671 " if (m->rope_theta > 0.0f && (m->head_dim %% 2 != 0)) return -1;\n"
1672 " if (m->elem_bytes == 0) m->elem_bytes = sizeof(float);\n"
1673 " size_t elem_bytes = m->elem_bytes;\n"
1674 " m->aligned_embed_dim = align_up_elems((size_t)m->embed_dim, elem_bytes, CACHELINE_BYTES);\n"
1675 " m->aligned_head_dim = align_up_elems((size_t)m->head_dim, elem_bytes, CACHELINE_BYTES);\n"
1676 " m->aligned_attn_context_window = align_up_elems((size_t)m->context_window, elem_bytes, CACHELINE_BYTES);\n"
1677 " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, elem_bytes, CACHELINE_BYTES);\n"
1678 " if (ensure_layers_allocated(m) != 0) return -1;\n"
/* Extra constraints that apply only when some weight is quantized. */
1679 " if (m->weights_quantized) {\n"
1680 " /* K-quant weights require K dimension to be a multiple of 256. */\n"
1681 " if ((m->aligned_embed_dim %% 256) != 0) return -1;\n"
1682 " if ((aligned_intermediate_dim %% 256) != 0) return -1;\n"
1683 " int wo_quant = 0;\n"
1684 " if (m->layers) {\n"
1685 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1686 " if (ck_dtype_is_quantized(m->layers[layer].wo_dtype)) {\n"
1692 " if (wo_quant && (size_t)m->num_attention_heads * m->aligned_head_dim != m->aligned_embed_dim) return -1;\n"
1695 " if (m->num_cores <= 0) m->num_cores = 1;\n"
/* Ceiling division of the context window across cores. */
1696 " m->tokens_per_core = (m->context_window + m->num_cores - 1) / m->num_cores;\n"
1698 " size_t off = 0;\n");
1701 " m->layers_start_offset = off;\n"
1703 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1704 " TrulyOptimalLayer *L = &m->layers[layer];\n");
/* layer_stride is the distance between the same offset field in layers 1 and
 * 0; the field name is substituted by the generator through stride_field. */
1712 " if (m->num_layers > 1) {\n"
1713 " m->layer_stride = m->layers[1].%s_offset - m->layers[0].%s_offset;\n"
1715 " m->layer_stride = 0;\n"
1717 stride_field, stride_field);
1721 " m->total_bytes = align_up_bytes(off, CACHELINE_BYTES);\n"
1722 " m->memory_base = (uint8_t *)ck_huge_alloc(m->total_bytes);\n"
1723 " if (!m->memory_base) return -1;\n"
1724 " if (m->rope_theta > 0.0f) {\n"
1725 " rope_precompute_cache(ptr_f32(m->memory_base, m->rope_cos_cache_offset),\n"
1726 " ptr_f32(m->memory_base, m->rope_sin_cache_offset),\n"
1727 " m->context_window,\n"
1729 " m->rope_theta);\n"
/* Emits forward declarations for lm_head_forward, lm_head_backward and
 * softmax_cross_entropy so the emitted forward/backward drivers can call
 * them. Some parameter lines are not visible in this listing. */
1735 "static void lm_head_forward(const float *hidden,\n"
1736 " const float *weights,\n"
1738 " int T, int V, int D, int aligned_D);\n"
1739 "static void lm_head_backward(const float *hidden,\n"
1740 " const float *weights,\n"
1741 " const float *d_logits,\n"
1742 " float *d_hidden,\n"
1743 " float *d_weights,\n"
1744 " int T, int V, int D, int aligned_D);\n"
1745 "static void softmax_cross_entropy(const float *logits,\n"
1746 " const int32_t *targets,\n"
1748 " float *d_logits,\n"
1749 " float *loss_out);\n\n");
/* Emits run_model_forward: runs every transformer layer in order, feeding
 * each layer's output to the next, then applies the final RMSNorm and the
 * lm head to produce logits. T is active_tokens when positive, otherwise
 * context_window. Per layer one of three kernels is selected:
 *   - all weights share CK_DT_Q4_K: ck_layer_forward_rmsnorm_swiglu_q4_k
 *   - otherwise, any quantized weight: ck_layer_forward_rmsnorm_swiglu_quant,
 *     which additionally receives the per-matrix dtypes
 *   - otherwise: ck_layer_forward_rmsnorm_swiglu with float weight pointers
 * After each layer, when the KV cache is enabled and training is disabled,
 * the K and V buffers are passed to kv_cache_repack_head_major_inplace.
 * All pointers are derived from memory_base plus the stored offsets. */
1752 "static void run_model_forward(TransformerModel *m)\n"
1754 " uint8_t *base = m->memory_base;\n"
1755 " float *current = ptr_f32(base, m->embedded_input_offset);\n"
1756 " int aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
1757 " int T = m->active_tokens > 0 ? m->active_tokens : m->context_window;\n"
1758 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1759 " TrulyOptimalLayer *L = &m->layers[layer];\n"
/* Path 1: uniform Q4_K weights. */
1760 " if (!m->weights_mixed && m->weight_dtype == CK_DT_Q4_K) {\n"
1761 " CKLayerForwardParamsQ4K p = {0};\n"
1763 " p.embed_dim = m->embed_dim;\n"
1764 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1765 " p.num_heads = m->num_attention_heads;\n"
1766 " p.num_kv_heads = m->num_kv_heads;\n"
1767 " p.head_dim = m->head_dim;\n"
1768 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1769 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1770 " p.intermediate_dim = m->intermediate_size;\n"
1771 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1772 " p.eps = m->rms_norm_eps;\n"
1773 " p.rope_pos_offset = 0;\n"
1774 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1775 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1776 " p.input = current;\n"
1777 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1778 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1779 " p.wq = cptr_void(base, L->wq_offset);\n"
1780 " p.bq = cptr_f32(base, L->bq_offset);\n"
1781 " p.wk = cptr_void(base, L->wk_offset);\n"
1782 " p.bk = cptr_f32(base, L->bk_offset);\n"
1783 " p.wv = cptr_void(base, L->wv_offset);\n"
1784 " p.bv = cptr_f32(base, L->bv_offset);\n"
1785 " p.wo = cptr_void(base, L->wo_offset);\n"
1786 " p.bo = cptr_f32(base, L->bo_offset);\n"
1787 " p.w1 = cptr_void(base, L->w1_offset);\n"
1788 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1789 " p.w2 = cptr_void(base, L->w2_offset);\n"
1790 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1791 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1792 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1793 " p.q = ptr_f32(base, L->q_offset);\n"
1794 " p.k = ptr_f32(base, L->k_offset);\n"
1795 " p.v = ptr_f32(base, L->v_offset);\n"
1796 " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1797 " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1798 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1799 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1800 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1801 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1802 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1803 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1804 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1805 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1806 " p.output = ptr_f32(base, L->output_offset);\n"
1807 " ck_layer_forward_rmsnorm_swiglu_q4_k(&p);\n"
1808 " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1809 " kv_cache_repack_head_major_inplace(p.k,\n"
1810 " p.num_kv_heads,\n"
1812 " m->kv_cache_capacity,\n"
1813 " p.aligned_head_dim);\n"
1814 " kv_cache_repack_head_major_inplace(p.v,\n"
1815 " p.num_kv_heads,\n"
1817 " m->kv_cache_capacity,\n"
1818 " p.aligned_head_dim);\n"
1820 " current = p.output;\n"
/* Path 2: quantized weights with per-matrix dtypes. */
1821 " } else if (m->weights_quantized) {\n"
1822 " CKLayerForwardParamsQ4K p = {0};\n"
1824 " p.embed_dim = m->embed_dim;\n"
1825 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1826 " p.num_heads = m->num_attention_heads;\n"
1827 " p.num_kv_heads = m->num_kv_heads;\n"
1828 " p.head_dim = m->head_dim;\n"
1829 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1830 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1831 " p.intermediate_dim = m->intermediate_size;\n"
1832 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1833 " p.eps = m->rms_norm_eps;\n"
1834 " p.rope_pos_offset = 0;\n"
1835 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1836 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1837 " p.input = current;\n"
1838 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1839 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1840 " p.wq = cptr_void(base, L->wq_offset);\n"
1841 " p.bq = cptr_f32(base, L->bq_offset);\n"
1842 " p.wk = cptr_void(base, L->wk_offset);\n"
1843 " p.bk = cptr_f32(base, L->bk_offset);\n"
1844 " p.wv = cptr_void(base, L->wv_offset);\n"
1845 " p.bv = cptr_f32(base, L->bv_offset);\n"
1846 " p.wo = cptr_void(base, L->wo_offset);\n"
1847 " p.bo = cptr_f32(base, L->bo_offset);\n"
1848 " p.w1 = cptr_void(base, L->w1_offset);\n"
1849 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1850 " p.w2 = cptr_void(base, L->w2_offset);\n"
1851 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1852 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1853 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1854 " p.q = ptr_f32(base, L->q_offset);\n"
1855 " p.k = ptr_f32(base, L->k_offset);\n"
1856 " p.v = ptr_f32(base, L->v_offset);\n"
1857 " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1858 " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1859 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1860 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1861 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1862 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1863 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1864 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1865 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1866 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1867 " p.output = ptr_f32(base, L->output_offset);\n"
1868 " p.wq_dtype = L->wq_dtype;\n"
1869 " p.wk_dtype = L->wk_dtype;\n"
1870 " p.wv_dtype = L->wv_dtype;\n"
1871 " p.wo_dtype = L->wo_dtype;\n"
1872 " p.w1_dtype = L->w1_dtype;\n"
1873 " p.w2_dtype = L->w2_dtype;\n"
1874 " ck_layer_forward_rmsnorm_swiglu_quant(&p);\n"
1875 " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1876 " kv_cache_repack_head_major_inplace(p.k,\n"
1877 " p.num_kv_heads,\n"
1879 " m->kv_cache_capacity,\n"
1880 " p.aligned_head_dim);\n"
1881 " kv_cache_repack_head_major_inplace(p.v,\n"
1882 " p.num_kv_heads,\n"
1884 " m->kv_cache_capacity,\n"
1885 " p.aligned_head_dim);\n"
1887 " current = p.output;\n"
/* Path 3: float weights, passed as float pointers. */
1889 " CKLayerForwardParams p = {0};\n"
1891 " p.embed_dim = m->embed_dim;\n"
1892 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1893 " p.num_heads = m->num_attention_heads;\n"
1894 " p.num_kv_heads = m->num_kv_heads;\n"
1895 " p.head_dim = m->head_dim;\n"
1896 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1897 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1898 " p.intermediate_dim = m->intermediate_size;\n"
1899 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1900 " p.eps = m->rms_norm_eps;\n"
1901 " p.rope_pos_offset = 0;\n"
1902 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1903 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1904 " p.input = current;\n"
1905 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1906 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1907 " p.wq = cptr_f32(base, L->wq_offset);\n"
1908 " p.bq = cptr_f32(base, L->bq_offset);\n"
1909 " p.wk = cptr_f32(base, L->wk_offset);\n"
1910 " p.bk = cptr_f32(base, L->bk_offset);\n"
1911 " p.wv = cptr_f32(base, L->wv_offset);\n"
1912 " p.bv = cptr_f32(base, L->bv_offset);\n"
1913 " p.wo = cptr_f32(base, L->wo_offset);\n"
1914 " p.bo = cptr_f32(base, L->bo_offset);\n"
1915 " p.w1 = cptr_f32(base, L->w1_offset);\n"
1916 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1917 " p.w2 = cptr_f32(base, L->w2_offset);\n"
1918 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1919 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1920 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1921 " p.q = ptr_f32(base, L->q_offset);\n"
1922 " p.k = ptr_f32(base, L->k_offset);\n"
1923 " p.v = ptr_f32(base, L->v_offset);\n"
1924 " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1925 " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1926 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1927 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1928 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1929 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1930 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1931 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1932 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1933 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1934 " p.output = ptr_f32(base, L->output_offset);\n"
1935 " ck_layer_forward_rmsnorm_swiglu(&p);\n"
1936 " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1937 " kv_cache_repack_head_major_inplace(p.k,\n"
1938 " p.num_kv_heads,\n"
1940 " m->kv_cache_capacity,\n"
1941 " p.aligned_head_dim);\n"
1942 " kv_cache_repack_head_major_inplace(p.v,\n"
1943 " p.num_kv_heads,\n"
1945 " m->kv_cache_capacity,\n"
1946 " p.aligned_head_dim);\n"
1948 " current = p.output;\n"
/* Final RMSNorm over the last layer output, then the lm head. The lm head
 * kernel follows lm_head_weight_dtype: gemm_nt_q4_k, gemm_nt_q6_k, or the
 * float lm_head_forward. */
1951 " float *final_out = ptr_f32(base, m->final_output_offset);\n"
1952 " rmsnorm_forward(current,\n"
1953 " cptr_f32(base, m->final_ln_weight_offset),\n"
1955 " ptr_f32(base, m->final_ln_rstd_offset),\n"
1958 " (int)m->aligned_embed_dim,\n"
1959 " m->rms_norm_eps);\n"
1960 " if (m->vocab_size > 0) {\n"
1961 " if (m->lm_head_weight_dtype == CK_DT_Q4_K) {\n"
1962 " gemm_nt_q4_k(final_out,\n"
1963 " cptr_void(base, m->lm_head_weight_offset),\n"
1965 " ptr_f32(base, m->logits_offset),\n"
1968 " (int)m->aligned_embed_dim);\n"
1969 " } else if (m->lm_head_weight_dtype == CK_DT_Q6_K) {\n"
1970 " gemm_nt_q6_k(final_out,\n"
1971 " cptr_void(base, m->lm_head_weight_offset),\n"
1973 " ptr_f32(base, m->logits_offset),\n"
1976 " (int)m->aligned_embed_dim);\n"
1978 " lm_head_forward(final_out,\n"
1979 " cptr_f32(base, m->lm_head_weight_offset),\n"
1980 " ptr_f32(base, m->logits_offset),\n"
1984 " (int)m->aligned_embed_dim);\n"
/* Emits run_model_backward: the training backward pass over activations
 * left by the forward pass. Returns 0 without doing anything when m is NULL
 * or training is disabled, and -1 when tokens or targets is NULL or the model
 * has no layers. Sequence of the visible calls: softmax_cross_entropy
 * (produces d_logits and the loss, which is stored through loss_out when
 * non-NULL), lm_head_backward, rmsnorm_backward for the final norm, then
 * ck_layer_backward_rmsnorm_swiglu for each layer from last to first, and
 * finally embedding_backward on layer 0's d_input.
 * NOTE(review): the d_token_emb buffer is passed to lm_head_backward and
 * again to embedding_backward -- presumably tied embeddings; confirm.
 * NOTE(review): weights are always read through float pointers here, so this
 * path presumably expects FP32 weights -- confirm that training is rejected
 * for quantized models. */
1993 "static int run_model_backward(TransformerModel *m,\n"
1994 " const int32_t *tokens,\n"
1995 " const int32_t *targets,\n"
1996 " float *loss_out)\n"
1998 " if (!m || !m->training_enabled) return 0;\n"
1999 " if (!tokens || !targets) return -1;\n"
2000 " if (m->num_layers <= 0) return -1;\n"
2001 " int T = m->active_tokens > 0 ? m->active_tokens : m->context_window;\n"
2002 " int V = m->vocab_size;\n"
2003 " int D = m->embed_dim;\n"
2004 " int aligned_D = (int)m->aligned_embed_dim;\n"
2005 " uint8_t *base = m->memory_base;\n"
2009 " float *final_out = ptr_f32(base, m->final_output_offset);\n"
2010 " float *logits = ptr_f32(base, m->logits_offset);\n"
2011 " float *d_logits = ptr_f32(base, m->d_logits_offset);\n"
2012 " float *d_final_out = ptr_f32(base, m->d_final_output_offset);\n"
2013 " float *d_final_in = ptr_f32(base, m->d_final_input_offset);\n"
2015 " float loss = 0.0f;\n"
2016 " softmax_cross_entropy(logits, targets, T, V, d_logits, &loss);\n"
2017 " if (loss_out) {\n"
2018 " *loss_out = loss;\n"
2020 " lm_head_backward(final_out,\n"
2021 " cptr_f32(base, m->lm_head_weight_offset),\n"
2024 " ptr_f32(base, m->d_token_emb_offset),\n"
2025 " T, V, D, aligned_D);\n"
2026 " rmsnorm_backward(d_final_out,\n"
2027 " ptr_f32(base, m->layers[m->num_layers - 1].output_offset),\n"
2028 " cptr_f32(base, m->final_ln_weight_offset),\n"
2029 " ptr_f32(base, m->final_ln_rstd_offset),\n"
2031 " ptr_f32(base, m->d_final_ln_weight_offset),\n"
2032 " T, D, aligned_D);\n"
/* Layers are visited in reverse; each layer's input is the previous layer's
 * output, or the embedded input for layer 0. */
2034 " for (int layer = m->num_layers - 1; layer >= 0; --layer) {\n"
2035 " TrulyOptimalLayer *L = &m->layers[layer];\n"
2036 " CKLayerBackwardParams p = {0};\n"
2038 " p.embed_dim = m->embed_dim;\n"
2039 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
2040 " p.num_heads = m->num_attention_heads;\n"
2041 " p.num_kv_heads = m->num_kv_heads;\n"
2042 " p.head_dim = m->head_dim;\n"
2043 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
2044 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
2045 " p.intermediate_dim = m->intermediate_size;\n"
2046 " p.aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2047 " p.eps = m->rms_norm_eps;\n"
2048 " p.rope_pos_offset = 0;\n"
2049 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
2050 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
2051 " p.input = (layer == 0) ? ptr_f32(base, m->embedded_input_offset)\n"
2052 " : ptr_f32(base, m->layers[layer - 1].output_offset);\n"
2053 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
2054 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
2055 " p.ln1_out = cptr_f32(base, L->ln1_out_offset);\n"
2056 " p.ln1_rstd = cptr_f32(base, L->ln1_rstd_offset);\n"
2057 " p.ln2_out = cptr_f32(base, L->ln2_out_offset);\n"
2058 " p.ln2_rstd = cptr_f32(base, L->ln2_rstd_offset);\n"
2059 " p.wq = cptr_f32(base, L->wq_offset);\n"
2060 " p.bq = cptr_f32(base, L->bq_offset);\n"
2061 " p.wk = cptr_f32(base, L->wk_offset);\n"
2062 " p.bk = cptr_f32(base, L->bk_offset);\n"
2063 " p.wv = cptr_f32(base, L->wv_offset);\n"
2064 " p.bv = cptr_f32(base, L->bv_offset);\n"
2065 " p.wo = cptr_f32(base, L->wo_offset);\n"
2066 " p.bo = cptr_f32(base, L->bo_offset);\n"
2067 " p.w1 = cptr_f32(base, L->w1_offset);\n"
2068 " p.b1 = cptr_f32(base, L->b1_offset);\n"
2069 " p.w2 = cptr_f32(base, L->w2_offset);\n"
2070 " p.b2 = cptr_f32(base, L->b2_offset);\n"
2071 " p.q = cptr_f32(base, L->q_offset);\n"
2072 " p.k = cptr_f32(base, L->k_offset);\n"
2073 " p.v = cptr_f32(base, L->v_offset);\n"
2074 " p.scores = L->scores_offset ? cptr_f32(base, L->scores_offset) : NULL;\n"
2075 " p.attn_out = cptr_f32(base, L->attn_out_offset);\n"
2076 " p.residual1 = cptr_f32(base, L->residual1_offset);\n"
2077 " p.fc1_out = cptr_f32(base, L->fc1_out_offset);\n"
2078 " p.swiglu_out = cptr_f32(base, L->swiglu_out_offset);\n"
2079 " p.d_output = ptr_f32(base, L->d_output_offset);\n"
2080 " p.d_input = ptr_f32(base, L->d_input_offset);\n"
2081 " p.d_ln1_gamma = ptr_f32(base, L->d_ln1_gamma_offset);\n"
2082 " p.d_ln2_gamma = ptr_f32(base, L->d_ln2_gamma_offset);\n"
2083 " p.d_wq = ptr_f32(base, L->d_wq_offset);\n"
2084 " p.d_bq = ptr_f32(base, L->d_bq_offset);\n"
2085 " p.d_wk = ptr_f32(base, L->d_wk_offset);\n"
2086 " p.d_bk = ptr_f32(base, L->d_bk_offset);\n"
2087 " p.d_wv = ptr_f32(base, L->d_wv_offset);\n"
2088 " p.d_bv = ptr_f32(base, L->d_bv_offset);\n"
2089 " p.d_wo = ptr_f32(base, L->d_wo_offset);\n"
2090 " p.d_bo = ptr_f32(base, L->d_bo_offset);\n"
2091 " p.d_w1 = ptr_f32(base, L->d_w1_offset);\n"
2092 " p.d_b1 = ptr_f32(base, L->d_b1_offset);\n"
2093 " p.d_w2 = ptr_f32(base, L->d_w2_offset);\n"
2094 " p.d_b2 = ptr_f32(base, L->d_b2_offset);\n"
2095 " p.d_ln1_out = ptr_f32(base, L->d_ln1_out_offset);\n"
2096 " p.d_q = ptr_f32(base, L->d_q_offset);\n"
2097 " p.d_k = ptr_f32(base, L->d_k_offset);\n"
2098 " p.d_v = ptr_f32(base, L->d_v_offset);\n"
2099 " p.d_scores = ptr_f32(base, L->d_scores_offset);\n"
2100 " p.d_attn_out = ptr_f32(base, L->d_attn_out_offset);\n"
2101 " p.d_proj_tmp = ptr_f32(base, L->d_proj_tmp_offset);\n"
2102 " p.d_residual1 = ptr_f32(base, L->d_residual1_offset);\n"
2103 " p.d_ln2_out = ptr_f32(base, L->d_ln2_out_offset);\n"
2104 " p.d_fc1_out = ptr_f32(base, L->d_fc1_out_offset);\n"
2105 " p.d_swiglu_out = ptr_f32(base, L->d_swiglu_out_offset);\n"
2106 " p.d_mlp_out = ptr_f32(base, L->d_mlp_out_offset);\n"
/* The incoming gradient is copied into this layer's d_output; for layers
 * below the last it comes from the next layer's d_input. */
2108 " const float *src = (layer == m->num_layers - 1)\n"
2110 " : ptr_f32(base, m->layers[layer + 1].d_input_offset);\n"
2111 " memcpy(p.d_output, src, (size_t)T * (size_t)aligned_D * sizeof(float));\n"
2113 " ck_layer_backward_rmsnorm_swiglu(&p);\n"
2117 " TrulyOptimalLayer *L0 = &m->layers[0];\n"
2118 " embedding_backward(tokens,\n"
2120 " ptr_f32(base, L0->d_input_offset),\n"
2121 " ptr_f32(base, m->d_token_emb_offset),\n"
2122 " ptr_f32(base, m->d_pos_emb_offset),\n"
2126 " m->context_window,\n"
2127 " m->rope_theta <= 0.0f);\n"
2130 " /* SGD update is now called separately via optimizer_step() */\n"
/* Emits parse_int_arg: parses a base-10 integer command-line value with
 * strtol and rejects NULL arguments and trailing characters.
 * Fix: the emitted check now also rejects input where strtol consumed no
 * characters (end == s). Previously an empty string passed the check,
 * because end pointed at the terminating NUL, and was accepted as 0. */
2135 "static int parse_int_arg(const char *s, int *out)\n"
2137 " if (!s || !out) return 0;\n"
2138 " char *end = NULL;\n"
2139 " long v = strtol(s, &end, 10);\n"
2140 " if (!end || end == s || *end != '\\0') return 0;\n"
/* Emits parse_float_arg: parses a floating-point command-line value with
 * strtod, rejects NULL arguments and trailing characters, and stores the
 * result narrowed to float.
 * Fix: the emitted check now also rejects input where strtod consumed no
 * characters (end == s). Previously an empty string passed the check,
 * because end pointed at the terminating NUL, and was accepted as 0.0. */
2144 "static int parse_float_arg(const char *s, float *out)\n"
2146 " if (!s || !out) return 0;\n"
2147 " char *end = NULL;\n"
2148 " double v = strtod(s, &end);\n"
2149 " if (!end || end == s || *end != '\\0') return 0;\n"
2150 " *out = (float)v;\n"
/* Emits print_usage: prints the command-line help of the generated program,
 * one printf per option. The percent sign of the program-name placeholder is
 * doubled because this text is itself a printf format string in the
 * generator. */
2153 "static void print_usage(const char *prog)\n"
2155 " printf(\"Usage: %%s [options]\\n\", prog);\n"
2156 " printf(\" --dump Print layout summary (layer 0 only)\\n\");\n"
2157 " printf(\" --dump-all Print layout summary for all layers\\n\");\n"
2158 " printf(\" --no-forward Skip forward pass (layout + alloc only)\\n\");\n"
2159 " printf(\" --layers N Override num_layers\\n\");\n"
2160 " printf(\" --embed N Override embed_dim\\n\");\n"
2161 " printf(\" --intermediate N Override intermediate_size\\n\");\n"
2162 " printf(\" --heads N Override num_attention_heads\\n\");\n"
2163 " printf(\" --kv-heads N Override num_kv_heads\\n\");\n"
2164 " printf(\" --vocab N Override vocab_size\\n\");\n"
2165 " printf(\" --ctx N Override context_window\\n\");\n"
2166 " printf(\" --cores N Override num_cores\\n\");\n"
2167 " printf(\" --litmus Run LM head + CE + backward litmus\\n\");\n"
2168 " printf(\" --backward Run backward pass + SGD update (requires --tokens/--targets)\\n\");\n"
2169 " printf(\" --lr F SGD learning rate (default: 1e-3 when --backward)\\n\");\n"
2170 " printf(\" --steps N Training steps (default: 1)\\n\");\n"
2171 " printf(\" --log-steps Print loss per step during training\\n\");\n"
2172 " printf(\" --strict Enable strict parity mode (single-thread + double GEMM)\\n\");\n"
2173 " printf(\" --hidden PATH Load hidden activations [T x aligned_D] f32\\n\");\n"
2174 " printf(\" --weights PATH Load LM head weights [V x aligned_D] f32 (litmus)\\n\");\n"
2175 " printf(\" --targets PATH Load target tokens [T] int32\\n\");\n"
2176 " printf(\" --model-weights PATH Load full model weights (bump format)\\n\");\n"
2177 " printf(\" --tokens PATH Load token IDs [T] int32 and build embeddings\\n\");\n"
2178 " printf(\" --out-logits PATH Write logits [T x V] f32\\n\");\n"
2179 " printf(\" --out-dlogits PATH Write d_logits [T x V] f32\\n\");\n"
2180 " printf(\" --out-dhidden PATH Write d_hidden [T x aligned_D] f32\\n\");\n"
2181 " printf(\" --out-dweights PATH Write d_weights [V x aligned_D] f32\\n\");\n"
2182 " printf(\" --out-loss PATH Write loss (single f32)\\n\");\n"
2183 " printf(\" --out-weights PATH Write model weights (flat, no header)\\n\");\n"
2184 " printf(\" --help Show this help\\n\");\n"
/* Emits read_floats: opens path in binary mode and reads exactly count floats
 * into dst. Returns 0 only when the full count was read; -1 for NULL
 * arguments or a short read. An fopen failure is reported with perror. */
2186 "static int read_floats(const char *path, float *dst, size_t count)\n"
2188 " if (!path || !dst) return -1;\n"
2189 " FILE *f = fopen(path, \"rb\");\n"
2191 " perror(\"fopen\");\n"
2194 " size_t got = fread(dst, sizeof(float), count, f);\n"
2196 " return got == count ? 0 : -1;\n"
/* Emits read_ints: the int32_t counterpart of read_floats. Opens path in
 * binary mode, reads exactly count values into dst and returns 0 only when
 * the full count was read; -1 for NULL arguments or a short read. */
2198 "static int read_ints(const char *path, int32_t *dst, size_t count)\n"
2200 " if (!path || !dst) return -1;\n"
2201 " FILE *f = fopen(path, \"rb\");\n"
2203 " perror(\"fopen\");\n"
2206 " size_t got = fread(dst, sizeof(int32_t), count, f);\n"
2208 " return got == count ? 0 : -1;\n"
/* Emits read_floats_file: reads count floats from an already open stream at
 * its current position; 0 on a complete read, -1 otherwise. */
2210 "static int read_floats_file(FILE *f, float *dst, size_t count)\n"
2212 " if (!f || !dst) return -1;\n"
2213 " size_t got = fread(dst, sizeof(float), count, f);\n"
2214 " return got == count ? 0 : -1;\n"
/* Emits read_bytes_file: reads the given number of raw bytes from an open
 * stream; 0 on a complete read, -1 otherwise. */
2216 "static int read_bytes_file(FILE *f, void *dst, size_t bytes)\n"
2218 " if (!f || !dst) return -1;\n"
2219 " size_t got = fread(dst, 1, bytes, f);\n"
2220 " return got == bytes ? 0 : -1;\n"
/* Emits write_floats_file: writes count floats to an open stream; 0 when all
 * items were written, -1 otherwise. */
2222 "static int write_floats_file(FILE *f, const float *src, size_t count)\n"
2224 " if (!f || !src) return -1;\n"
2225 " size_t wrote = fwrite(src, sizeof(float), count, f);\n"
2226 " return wrote == count ? 0 : -1;\n"
/* Emits write_bytes_file: writes the given number of raw bytes to an open
 * stream; 0 when all bytes were written, -1 otherwise. */
2228 "static int write_bytes_file(FILE *f, const void *src, size_t bytes)\n"
2230 " if (!f || !src) return -1;\n"
2231 " size_t wrote = fwrite(src, 1, bytes, f);\n"
2232 " return wrote == bytes ? 0 : -1;\n"
/* Emits read_weight_file: dtype-aware tensor read. FP32 tensors are read as
 * n_elements floats; any other dtype is read as raw bytes whose size comes
 * from ck_dtype_row_bytes(dtype, n_elements). */
2234 "static int read_weight_file(FILE *f, CKDataType dtype, void *dst, size_t n_elements)\n"
2236 " if (!f || !dst) return -1;\n"
2237 " if (dtype == CK_DT_FP32) {\n"
2238 " return read_floats_file(f, (float *)dst, n_elements);\n"
2240 " return read_bytes_file(f, dst, ck_dtype_row_bytes(dtype, n_elements));\n"
/* Emits write_weight_file: dtype-aware tensor write, mirroring
 * read_weight_file. FP32 tensors are written as n_elements floats; any other
 * dtype is written as ck_dtype_row_bytes(dtype, n_elements) raw bytes. */
2242 "static int write_weight_file(FILE *f, CKDataType dtype, const void *src, size_t n_elements)\n"
2244 " if (!f || !src) return -1;\n"
2245 " if (dtype == CK_DT_FP32) {\n"
2246 " return write_floats_file(f, (const float *)src, n_elements);\n"
2248 " return write_bytes_file(f, src, ck_dtype_row_bytes(dtype, n_elements));\n"
2250 "static int skip_bump_header(FILE *f)\n"
2252 " if (!f) return -1;\n"
2254 " if (fread(magic, 1, 8, f) != 8) return -1;\n"
2255 " if (memcmp(magic, \"BUMPWGT3\", 8) == 0) {\n"
2256 " if (fseek(f, 128, SEEK_SET) != 0) return -1;\n"
2257 " uint32_t dtype_len = 0;\n"
2258 " if (fread(&dtype_len, sizeof(uint32_t), 1, f) != 1) return -1;\n"
2259 " if (fseek(f, (long)dtype_len, SEEK_CUR) != 0) return -1;\n"
2262 " if (memcmp(magic, \"BUMPWGT2\", 8) == 0) {\n"
2263 " if (fseek(f, 128, SEEK_SET) != 0) return -1;\n"
2266 " if (fseek(f, 0, SEEK_SET) != 0) return -1;\n"
2269 "static int load_model_weights(const char *path, TransformerModel *m)\n"
2271 " if (!path || !m || !m->memory_base) return -1;\n"
2272 " FILE *f = fopen(path, \"rb\");\n"
2274 " perror(\"fopen\");\n"
2277 " if (skip_bump_header(f) < 0) {\n"
2281 " uint8_t *base = m->memory_base;\n"
2282 " size_t aligned_intermediate = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2283 " size_t tok_elems = (size_t)m->vocab_size * m->aligned_embed_dim;\n"
2284 " if (read_weight_file(f, m->token_emb_dtype, ptr_u8(base, m->token_emb_offset), tok_elems) != 0) goto fail;\n"
2285 " if (read_floats_file(f, ptr_f32(base, m->pos_emb_offset),\n"
2286 " (size_t)m->context_window * m->aligned_embed_dim) != 0) goto fail;\n"
2288 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
2289 " TrulyOptimalLayer *L = &m->layers[layer];\n"
2290 " size_t head_w_stride = m->aligned_head_dim * m->aligned_embed_dim;\n"
2291 " size_t q_w = (size_t)m->num_attention_heads * head_w_stride;\n"
2292 " size_t kv_w = (size_t)m->num_kv_heads * head_w_stride;\n"
2293 " size_t q_b = (size_t)m->num_attention_heads * m->aligned_head_dim;\n"
2294 " size_t kv_b = (size_t)m->num_kv_heads * m->aligned_head_dim;\n"
2295 " size_t wo_w = (size_t)m->num_attention_heads * m->aligned_embed_dim * m->aligned_head_dim;\n"
2296 " size_t w1_w = (size_t)(2 * aligned_intermediate) * m->aligned_embed_dim;\n"
2297 " size_t w2_w = m->aligned_embed_dim * aligned_intermediate;\n"
2299 " if (read_floats_file(f, ptr_f32(base, L->ln1_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2300 " if (read_floats_file(f, ptr_f32(base, L->ln2_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2301 " if (read_weight_file(f, L->wq_dtype, ptr_u8(base, L->wq_offset), q_w) != 0) goto fail;\n"
2302 " if (read_floats_file(f, ptr_f32(base, L->bq_offset), q_b) != 0) goto fail;\n"
2303 " if (read_weight_file(f, L->wk_dtype, ptr_u8(base, L->wk_offset), kv_w) != 0) goto fail;\n"
2304 " if (read_floats_file(f, ptr_f32(base, L->bk_offset), kv_b) != 0) goto fail;\n"
2305 " if (read_weight_file(f, L->wv_dtype, ptr_u8(base, L->wv_offset), kv_w) != 0) goto fail;\n"
2306 " if (read_floats_file(f, ptr_f32(base, L->bv_offset), kv_b) != 0) goto fail;\n"
2307 " if (read_weight_file(f, L->wo_dtype, ptr_u8(base, L->wo_offset), wo_w) != 0) goto fail;\n"
2308 " if (read_floats_file(f, ptr_f32(base, L->bo_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2309 " if (read_weight_file(f, L->w1_dtype, ptr_u8(base, L->w1_offset), w1_w) != 0) goto fail;\n"
2310 " if (read_floats_file(f, ptr_f32(base, L->b1_offset), (size_t)(2 * aligned_intermediate)) != 0) goto fail;\n"
2311 " if (read_weight_file(f, L->w2_dtype, ptr_u8(base, L->w2_offset), w2_w) != 0) goto fail;\n"
2312 " if (read_floats_file(f, ptr_f32(base, L->b2_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2315 " if (read_floats_file(f, ptr_f32(base, m->final_ln_weight_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2316 " if (read_floats_file(f, ptr_f32(base, m->final_ln_bias_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2324 "static int save_model_weights(const char *path, const TransformerModel *m)\n"
2326 " if (!path || !m || !m->memory_base) return -1;\n"
2327 " FILE *f = fopen(path, \"wb\");\n"
2329 " perror(\"fopen\");\n"
2332 " uint8_t *base = m->memory_base;\n"
2333 " size_t aligned_intermediate = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2334 " size_t tok_elems = (size_t)m->vocab_size * m->aligned_embed_dim;\n"
2335 " if (write_weight_file(f, m->token_emb_dtype, cptr_void(base, m->token_emb_offset), tok_elems) != 0) goto fail;\n"
2336 " if (write_floats_file(f, ptr_f32(base, m->pos_emb_offset),\n"
2337 " (size_t)m->context_window * m->aligned_embed_dim) != 0) goto fail;\n"
2339 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
2340 " const TrulyOptimalLayer *L = &m->layers[layer];\n"
2341 " size_t head_w_stride = m->aligned_head_dim * m->aligned_embed_dim;\n"
2342 " size_t q_w = (size_t)m->num_attention_heads * head_w_stride;\n"
2343 " size_t kv_w = (size_t)m->num_kv_heads * head_w_stride;\n"
2344 " size_t q_b = (size_t)m->num_attention_heads * m->aligned_head_dim;\n"
2345 " size_t kv_b = (size_t)m->num_kv_heads * m->aligned_head_dim;\n"
2346 " size_t wo_w = (size_t)m->num_attention_heads * m->aligned_embed_dim * m->aligned_head_dim;\n"
2347 " size_t w1_w = (size_t)(2 * aligned_intermediate) * m->aligned_embed_dim;\n"
2348 " size_t w2_w = m->aligned_embed_dim * aligned_intermediate;\n"
2350 " if (write_floats_file(f, cptr_f32(base, L->ln1_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2351 " if (write_floats_file(f, cptr_f32(base, L->ln2_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2352 " if (write_weight_file(f, L->wq_dtype, cptr_void(base, L->wq_offset), q_w) != 0) goto fail;\n"
2353 " if (write_floats_file(f, cptr_f32(base, L->bq_offset), q_b) != 0) goto fail;\n"
2354 " if (write_weight_file(f, L->wk_dtype, cptr_void(base, L->wk_offset), kv_w) != 0) goto fail;\n"
2355 " if (write_floats_file(f, cptr_f32(base, L->bk_offset), kv_b) != 0) goto fail;\n"
2356 " if (write_weight_file(f, L->wv_dtype, cptr_void(base, L->wv_offset), kv_w) != 0) goto fail;\n"
2357 " if (write_floats_file(f, cptr_f32(base, L->bv_offset), kv_b) != 0) goto fail;\n"
2358 " if (write_weight_file(f, L->wo_dtype, cptr_void(base, L->wo_offset), wo_w) != 0) goto fail;\n"
2359 " if (write_floats_file(f, cptr_f32(base, L->bo_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2360 " if (write_weight_file(f, L->w1_dtype, cptr_void(base, L->w1_offset), w1_w) != 0) goto fail;\n"
2361 " if (write_floats_file(f, cptr_f32(base, L->b1_offset), (size_t)(2 * aligned_intermediate)) != 0) goto fail;\n"
2362 " if (write_weight_file(f, L->w2_dtype, cptr_void(base, L->w2_offset), w2_w) != 0) goto fail;\n"
2363 " if (write_floats_file(f, cptr_f32(base, L->b2_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2366 " if (write_floats_file(f, cptr_f32(base, m->final_ln_weight_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2367 " if (write_floats_file(f, cptr_f32(base, m->final_ln_bias_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2375 "static void embed_tokens(const TransformerModel *m, const int32_t *tokens, int token_count)\n"
2377 " if (!m || !m->memory_base || !tokens) return;\n"
2378 " const uint8_t *base = m->memory_base;\n"
2379 " float *out = ptr_f32((uint8_t *)base, m->embedded_input_offset);\n"
2380 " const float *tok_f32 = cptr_f32(base, m->token_emb_offset);\n"
2381 " const uint8_t *tok_q = (const uint8_t *)cptr_void(base, m->token_emb_offset);\n"
2382 " const float *pos = cptr_f32(base, m->pos_emb_offset);\n"
2383 " int T = m->context_window;\n"
2384 " int D = m->embed_dim;\n"
2385 " int aligned_D = (int)m->aligned_embed_dim;\n"
2386 " for (int t = 0; t < T; ++t) {\n"
2387 " float *dst = out + (size_t)t * aligned_D;\n"
2388 " if (t < token_count) {\n"
2389 " int id = tokens[t];\n"
2390 " if (id < 0 || id >= m->vocab_size) id = 0;\n"
2391 " if (m->token_emb_dtype == CK_DT_Q4_K) {\n"
2392 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_D);\n"
2393 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2394 " dequant_q4_k_row(row, dst, (size_t)aligned_D);\n"
2395 " } else if (m->token_emb_dtype == CK_DT_Q6_K) {\n"
2396 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_D);\n"
2397 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2398 " dequant_q6_k_row(row, dst, (size_t)aligned_D);\n"
2400 " const float *src = tok_f32 + (size_t)id * aligned_D;\n"
2401 " memcpy(dst, src, (size_t)D * sizeof(float));\n"
2403 " if (aligned_D > D) {\n"
2404 " memset(dst + D, 0, (size_t)(aligned_D - D) * sizeof(float));\n"
2406 " if (m->rope_theta <= 0.0f) {\n"
2407 " const float *p = pos + (size_t)t * aligned_D;\n"
2408 " for (int d = 0; d < D; ++d) {\n"
2409 " dst[d] += p[d];\n"
2413 " memset(dst, 0, (size_t)aligned_D * sizeof(float));\n"
2417 "static void embed_token_at(const TransformerModel *m, int32_t token, int t)\n"
2419 " if (!m || !m->memory_base) return;\n"
2420 " if (t < 0 || t >= m->context_window) return;\n"
2421 " const uint8_t *base = m->memory_base;\n"
2422 " float *out = ptr_f32((uint8_t *)base, m->embedded_input_offset);\n"
2423 " const float *tok_f32 = cptr_f32(base, m->token_emb_offset);\n"
2424 " const uint8_t *tok_q = (const uint8_t *)cptr_void(base, m->token_emb_offset);\n"
2425 " const float *pos = cptr_f32(base, m->pos_emb_offset);\n"
2426 " int D = m->embed_dim;\n"
2427 " int aligned_D = (int)m->aligned_embed_dim;\n"
2428 " int id = (int)token;\n"
2429 " if (id < 0 || id >= m->vocab_size) id = 0;\n"
2430 " float *dst = out + (size_t)t * aligned_D;\n"
2431 " if (m->token_emb_dtype == CK_DT_Q4_K) {\n"
2432 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_D);\n"
2433 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2434 " dequant_q4_k_row(row, dst, (size_t)aligned_D);\n"
2435 " } else if (m->token_emb_dtype == CK_DT_Q6_K) {\n"
2436 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_D);\n"
2437 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2438 " dequant_q6_k_row(row, dst, (size_t)aligned_D);\n"
2440 " const float *src = tok_f32 + (size_t)id * aligned_D;\n"
2441 " memcpy(dst, src, (size_t)D * sizeof(float));\n"
2443 " if (aligned_D > D) {\n"
2444 " memset(dst + D, 0, (size_t)(aligned_D - D) * sizeof(float));\n"
2446 " if (m->rope_theta <= 0.0f) {\n"
2447 " const float *p = pos + (size_t)t * aligned_D;\n"
2448 " for (int d = 0; d < D; ++d) {\n"
2449 " dst[d] += p[d];\n"
2453 "static int write_floats(const char *path, const float *src, size_t count)\n"
2455 " if (!path || !src) return -1;\n"
2456 " FILE *f = fopen(path, \"wb\");\n"
2458 " perror(\"fopen\");\n"
2461 " size_t wrote = fwrite(src, sizeof(float), count, f);\n"
2463 " return wrote == count ? 0 : -1;\n"
2465 "static int write_float_scalar(const char *path, float v)\n"
2467 " if (!path) return -1;\n"
2468 " FILE *f = fopen(path, \"wb\");\n"
2470 " perror(\"fopen\");\n"
2473 " size_t wrote = fwrite(&v, sizeof(float), 1, f);\n"
2475 " return wrote == 1 ? 0 : -1;\n"
2477 "static void lm_head_forward(const float *hidden,\n"
2478 " const float *weights,\n"
2480 " int T, int V, int D, int aligned_D)\n"
2482 " for (int t = 0; t < T; ++t) {\n"
2483 " const float *h = hidden + (size_t)t * aligned_D;\n"
2484 " float *out = logits + (size_t)t * V;\n"
2485 " for (int v = 0; v < V; ++v) {\n"
2486 " const float *w = weights + (size_t)v * aligned_D;\n"
2487 " float sum = 0.0f;\n"
2488 " for (int d = 0; d < D; ++d) {\n"
2489 " sum += h[d] * w[d];\n"
2495 "static void softmax_cross_entropy(const float *logits,\n"
2496 " const int32_t *targets,\n"
2498 " float *d_logits,\n"
2499 " float *loss_out)\n"
2501 " double total = 0.0;\n"
2502 " for (int t = 0; t < T; ++t) {\n"
2503 " const float *row = logits + (size_t)t * V;\n"
2504 " float *drow = d_logits + (size_t)t * V;\n"
2505 " int target = targets[t];\n"
2506 " float max_logit = row[0];\n"
2507 " for (int v = 1; v < V; ++v) {\n"
2508 " if (row[v] > max_logit) max_logit = row[v];\n"
2510 " double sum_exp = 0.0;\n"
2511 " for (int v = 0; v < V; ++v) {\n"
2512 " drow[v] = expf(row[v] - max_logit);\n"
2513 " sum_exp += drow[v];\n"
2515 " float inv_sum = 1.0f / (float)sum_exp;\n"
2516 " for (int v = 0; v < V; ++v) {\n"
2517 " drow[v] *= inv_sum;\n"
2519 " double logsum = (double)max_logit + log(sum_exp);\n"
2520 " total += logsum - (double)row[target];\n"
2521 " drow[target] -= 1.0f;\n"
2522 " float scale = 1.0f / (float)T;\n"
2523 " for (int v = 0; v < V; ++v) {\n"
2524 " drow[v] *= scale;\n"
2527 " if (loss_out) {\n"
2528 " *loss_out = (float)(total / (double)T);\n"
2531 "static void lm_head_backward(const float *hidden,\n"
2532 " const float *weights,\n"
2533 " const float *d_logits,\n"
2534 " float *d_hidden,\n"
2535 " float *d_weights,\n"
2536 " int T, int V, int D, int aligned_D)\n"
2538 " size_t dh_count = (size_t)T * aligned_D;\n"
2539 " size_t dw_count = (size_t)V * aligned_D;\n"
2540 " for (size_t i = 0; i < dh_count; ++i) d_hidden[i] = 0.0f;\n"
2541 " for (size_t i = 0; i < dw_count; ++i) d_weights[i] = 0.0f;\n"
2542 " for (int t = 0; t < T; ++t) {\n"
2543 " const float *dlog = d_logits + (size_t)t * V;\n"
2544 " for (int d = 0; d < D; ++d) {\n"
2545 " double sum = 0.0;\n"
2546 " for (int v = 0; v < V; ++v) {\n"
2547 " sum += (double)dlog[v] * (double)weights[(size_t)v * aligned_D + d];\n"
2549 " d_hidden[(size_t)t * aligned_D + d] = (float)sum;\n"
2552 " for (int v = 0; v < V; ++v) {\n"
2553 " float *dw = d_weights + (size_t)v * aligned_D;\n"
2554 " for (int d = 0; d < D; ++d) {\n"
2555 " double sum = 0.0;\n"
2556 " for (int t = 0; t < T; ++t) {\n"
2557 " sum += (double)d_logits[(size_t)t * V + v] * (double)hidden[(size_t)t * aligned_D + d];\n"
2559 " dw[d] = (float)sum;\n"
2565 "static void dump_layer_offsets(const TransformerModel *m, int layer)\n"
2567 " const TrulyOptimalLayer *L = &m->layers[layer];\n"
2568 " printf(\"Layer %%d offsets (bytes):\\n\", layer);\n"
2569 " printf(\" ln1_gamma=%%zu ln2_gamma=%%zu wq=%%zu wk=%%zu wv=%%zu wo=%%zu w1=%%zu w2=%%zu\\n\",\n"
2570 " L->ln1_gamma_offset, L->ln2_gamma_offset, L->wq_offset, L->wk_offset,\n"
2571 " L->wv_offset, L->wo_offset, L->w1_offset, L->w2_offset);\n"
2572 " printf(\" ln1_out=%%zu q=%%zu k=%%zu v=%%zu scores=%%zu attn_out=%%zu\\n\",\n"
2573 " L->ln1_out_offset, L->q_offset, L->k_offset, L->v_offset,\n"
2574 " L->scores_offset, L->attn_out_offset);\n"
2575 " printf(\" proj_tmp=%%zu residual1=%%zu ln2_out=%%zu fc1_out=%%zu swiglu_out=%%zu mlp_out=%%zu output=%%zu\\n\",\n"
2576 " L->proj_tmp_offset, L->residual1_offset, L->ln2_out_offset,\n"
2577 " L->fc1_out_offset, L->swiglu_out_offset, L->mlp_out_offset, L->output_offset);\n"
2579 "static void dump_layout(const TransformerModel *m, int dump_all)\n"
2581 " size_t bytes = m->total_bytes;\n"
2582 " printf(\"Model config:\\n\");\n"
2583 " printf(\" layers=%%d embed=%%d intermediate=%%d heads=%%d kv_heads=%%d\\n\",\n"
2584 " m->num_layers, m->embed_dim, m->intermediate_size, m->num_attention_heads, m->num_kv_heads);\n"
2585 " printf(\" head_dim=%%d vocab=%%d ctx=%%d cores=%%d\\n\",\n"
2586 " m->head_dim, m->vocab_size, m->context_window, m->num_cores);\n"
2587 " printf(\" eps=%%.6g rope_theta=%%.6g\\n\", m->rms_norm_eps, m->rope_theta);\n"
2588 " printf(\"Aligned dims (elements): embed=%%zu head=%%zu ctx=%%zu\\n\",\n"
2589 " m->aligned_embed_dim, m->aligned_head_dim, m->aligned_attn_context_window);\n"
2590 " printf(\"Memory: total_bytes=%%zu\\n\", bytes);\n"
2591 " printf(\"Global offsets (bytes): token=%%zu pos=%%zu embedded=%%zu layers_start=%%zu\\n\",\n"
2592 " m->token_emb_offset, m->pos_emb_offset, m->embedded_input_offset, m->layers_start_offset);\n"
2593 " printf(\"Final offsets (bytes): final_ln_w=%%zu final_ln_b=%%zu final_ln_mean=%%zu final_ln_rstd=%%zu\\n\",\n"
2594 " m->final_ln_weight_offset, m->final_ln_bias_offset,\n"
2595 " m->final_ln_mean_offset, m->final_ln_rstd_offset);\n"
2596 " printf(\"LM/logits offsets (bytes): lm_head=%%zu logits=%%zu\\n\",\n"
2597 " m->lm_head_weight_offset, m->logits_offset);\n"
2598 " if (m->num_layers > 0) {\n"
2599 " dump_layer_offsets(m, 0);\n"
2600 " if (dump_all) {\n"
2601 " for (int i = 1; i < m->num_layers; ++i) {\n"
2602 " dump_layer_offsets(m, i);\n"
2611 "int main(int argc, char **argv)\n"
2614 " int dump_all = 0;\n"
2615 " int no_forward = 0;\n"
2616 " int run_litmus = 0;\n"
2617 " int run_backward = 0;\n"
2618 " const char *litmus_hidden = NULL;\n"
2619 " const char *litmus_weights = NULL;\n"
2620 " const char *litmus_targets = NULL;\n"
2621 " const char *model_weights = NULL;\n"
2622 " const char *tokens_path = NULL;\n"
2623 " const char *out_logits = NULL;\n"
2624 " const char *out_dlogits = NULL;\n"
2625 " const char *out_dhidden = NULL;\n"
2626 " const char *out_dweights = NULL;\n"
2627 " const char *out_loss = NULL;\n"
2628 " const char *out_weights = NULL;\n"
2630 " int log_steps = 0;\n"
2631 " int strict = 0;\n"
2632 " int32_t *tokens = NULL;\n"
2633 " int32_t *targets = NULL;\n"
2634 " TransformerModel m = {0};\n"
2635 " memcpy(m.magic, \"BUMPWGT3\", 8);\n"
2637 " m.model_type = 0;\n"
2638 " m.num_layers = %d;\n"
2639 " m.embed_dim = %d;\n"
2640 " m.intermediate_size = %d;\n"
2641 " m.num_attention_heads = %d;\n"
2642 " m.num_kv_heads = %d;\n"
2643 " m.vocab_size = %d;\n"
2644 " m.context_window = %d;\n"
2645 " m.rms_norm_eps = %.9g;\n"
2646 " m.rope_theta = %.9g;\n"
2647 " m.num_cores = 1;\n"
2648 " m.task_type = TASK_LM;\n"
2649 " m.optimizer = OPTIMIZER_SGD;\n"
2650 " m.learning_rate = 0.0f;\n"
2651 " for (int i = 1; i < argc; ++i) {\n"
2652 " if (strcmp(argv[i], \"--dump\") == 0) {\n"
2656 " if (strcmp(argv[i], \"--dump-all\") == 0) {\n"
2661 " if (strcmp(argv[i], \"--no-forward\") == 0) {\n"
2662 " no_forward = 1;\n"
2665 " if (strcmp(argv[i], \"--strict\") == 0) {\n"
2669 " if (strcmp(argv[i], \"--litmus\") == 0) {\n"
2670 " run_litmus = 1;\n"
2673 " if (strcmp(argv[i], \"--backward\") == 0) {\n"
2674 " run_backward = 1;\n"
2677 " if (strcmp(argv[i], \"--lr\") == 0 && i + 1 < argc) {\n"
2678 " parse_float_arg(argv[++i], &m.learning_rate);\n"
2681 " if (strcmp(argv[i], \"--help\") == 0) {\n"
2682 " print_usage(argv[0]);\n"
2685 " if (strcmp(argv[i], \"--hidden\") == 0 && i + 1 < argc) {\n"
2686 " litmus_hidden = argv[++i];\n"
2689 " if (strcmp(argv[i], \"--weights\") == 0 && i + 1 < argc) {\n"
2690 " litmus_weights = argv[++i];\n"
2693 " if (strcmp(argv[i], \"--targets\") == 0 && i + 1 < argc) {\n"
2694 " litmus_targets = argv[++i];\n"
2697 " if (strcmp(argv[i], \"--model-weights\") == 0 && i + 1 < argc) {\n"
2698 " model_weights = argv[++i];\n"
2701 " if (strcmp(argv[i], \"--tokens\") == 0 && i + 1 < argc) {\n"
2702 " tokens_path = argv[++i];\n"
2705 " if (strcmp(argv[i], \"--out-logits\") == 0 && i + 1 < argc) {\n"
2706 " out_logits = argv[++i];\n"
2709 " if (strcmp(argv[i], \"--out-dlogits\") == 0 && i + 1 < argc) {\n"
2710 " out_dlogits = argv[++i];\n"
2713 " if (strcmp(argv[i], \"--out-dhidden\") == 0 && i + 1 < argc) {\n"
2714 " out_dhidden = argv[++i];\n"
2717 " if (strcmp(argv[i], \"--out-dweights\") == 0 && i + 1 < argc) {\n"
2718 " out_dweights = argv[++i];\n"
2721 " if (strcmp(argv[i], \"--out-loss\") == 0 && i + 1 < argc) {\n"
2722 " out_loss = argv[++i];\n"
2725 " if (strcmp(argv[i], \"--out-weights\") == 0 && i + 1 < argc) {\n"
2726 " out_weights = argv[++i];\n"
2729 " if (strcmp(argv[i], \"--steps\") == 0 && i + 1 < argc) {\n"
2730 " parse_int_arg(argv[++i], &steps);\n"
2733 " if (strcmp(argv[i], \"--log-steps\") == 0) {\n"
2737 " if (strcmp(argv[i], \"--layers\") == 0 && i + 1 < argc) {\n"
2738 " parse_int_arg(argv[++i], &m.num_layers);\n"
2741 " if (strcmp(argv[i], \"--embed\") == 0 && i + 1 < argc) {\n"
2742 " parse_int_arg(argv[++i], &m.embed_dim);\n"
2745 " if (strcmp(argv[i], \"--intermediate\") == 0 && i + 1 < argc) {\n"
2746 " parse_int_arg(argv[++i], &m.intermediate_size);\n"
2749 " if (strcmp(argv[i], \"--heads\") == 0 && i + 1 < argc) {\n"
2750 " parse_int_arg(argv[++i], &m.num_attention_heads);\n"
2753 " if (strcmp(argv[i], \"--kv-heads\") == 0 && i + 1 < argc) {\n"
2754 " parse_int_arg(argv[++i], &m.num_kv_heads);\n"
2757 " if (strcmp(argv[i], \"--vocab\") == 0 && i + 1 < argc) {\n"
2758 " parse_int_arg(argv[++i], &m.vocab_size);\n"
2761 " if (strcmp(argv[i], \"--ctx\") == 0 && i + 1 < argc) {\n"
2762 " parse_int_arg(argv[++i], &m.context_window);\n"
2765 " if (strcmp(argv[i], \"--cores\") == 0 && i + 1 < argc) {\n"
2766 " parse_int_arg(argv[++i], &m.num_cores);\n"
2769 " fprintf(stderr, \"Unknown or invalid arg: %%s\\n\", argv[i]);\n"
2770 " print_usage(argv[0]);\n"
2774 " ck_set_strict_parity(1);\n"
2776 " if (run_backward && m.learning_rate == 0.0f) {\n"
2777 " m.learning_rate = 1e-3f;\n"
2779 " m.training_enabled = run_backward;\n"
2780 " m.weight_dtype = CK_DT_FP32;\n"
2782 " const char *wd = getenv(\"CK_WEIGHT_DTYPE\");\n"
2784 " if (strcmp(wd, \"q4_k\") == 0 || strcmp(wd, \"q4_k_m\") == 0 ||\n"
2785 " strcmp(wd, \"Q4_K\") == 0 || strcmp(wd, \"Q4_K_M\") == 0) {\n"
2786 " m.weight_dtype = CK_DT_Q4_K;\n"
2787 " } else if (strcmp(wd, \"q6_k\") == 0 || strcmp(wd, \"q6_k_l\") == 0 ||\n"
2788 " strcmp(wd, \"Q6_K\") == 0 || strcmp(wd, \"Q6_K_L\") == 0) {\n"
2789 " m.weight_dtype = CK_DT_Q6_K;\n"
2793 " init_weight_dtypes_uniform(&m, m.weight_dtype);\n"
2794 " refresh_weight_flags(&m);\n"
2795 " if (model_weights) {\n"
2796 " int dtype_rc = load_weight_dtypes(model_weights, &m);\n"
2797 " if (dtype_rc < 0) {\n"
2798 " fprintf(stderr, \"failed to read weight dtype table\\n\");\n"
2802 " if (m.training_enabled && m.weights_quantized) {\n"
2803 " fprintf(stderr, \"Quantized weights are inference-only; disable training\\n\");\n"
2806 " if (layout_model(&m) != 0) {\n"
2807 " fprintf(stderr, \"layout_model failed\\n\");\n"
2810 " if (model_weights) {\n"
2811 " if (load_model_weights(model_weights, &m) != 0) {\n"
2812 " fprintf(stderr, \"failed to load model weights\\n\");\n"
2816 " if (tokens_path) {\n"
2817 " int T = m.context_window;\n"
2818 " tokens = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2820 " fprintf(stderr, \"failed to alloc tokens\\n\");\n"
2823 " if (read_ints(tokens_path, tokens, (size_t)T) != 0) {\n"
2824 " fprintf(stderr, \"failed to read tokens\\n\");\n"
2829 " if (!run_backward) {\n"
2830 " embed_tokens(&m, tokens, T);\n"
2835 " if (run_backward) {\n"
2836 " if (!litmus_targets) {\n"
2837 " fprintf(stderr, \"backward requires --targets\\n\");\n"
2840 " int T = m.context_window;\n"
2841 " targets = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2842 " if (!targets) {\n"
2843 " fprintf(stderr, \"failed to alloc targets\\n\");\n"
2846 " if (read_ints(litmus_targets, targets, (size_t)T) != 0) {\n"
2847 " fprintf(stderr, \"failed to read targets\\n\");\n"
2849 " targets = NULL;\n"
2854 " dump_layout(&m, dump_all);\n"
2856 " if (run_litmus) {\n"
2857 " if (!litmus_hidden || !litmus_weights || !litmus_targets) {\n"
2858 " fprintf(stderr, \"litmus requires --hidden, --weights, and --targets\\n\");\n"
2861 " int T = m.context_window;\n"
2862 " int V = m.vocab_size;\n"
2863 " int D = m.embed_dim;\n"
2864 " int aligned_D = (int)m.aligned_embed_dim;\n"
2865 " float *hidden = ptr_f32(m.memory_base, m.final_output_offset);\n"
2866 " float *weights = ptr_f32(m.memory_base, m.lm_head_weight_offset);\n"
2867 " float *logits = ptr_f32(m.memory_base, m.logits_offset);\n"
2868 " if (read_floats(litmus_hidden, hidden, (size_t)T * aligned_D) != 0) {\n"
2869 " fprintf(stderr, \"failed to read hidden\\n\");\n"
2872 " if (read_floats(litmus_weights, weights, (size_t)V * aligned_D) != 0) {\n"
2873 " fprintf(stderr, \"failed to read weights\\n\");\n"
2876 " int32_t *targets = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2877 " if (!targets) {\n"
2878 " fprintf(stderr, \"failed to alloc targets\\n\");\n"
2881 " if (read_ints(litmus_targets, targets, (size_t)T) != 0) {\n"
2882 " fprintf(stderr, \"failed to read targets\\n\");\n"
2886 " float *d_logits = (float *)calloc((size_t)T * V, sizeof(float));\n"
2887 " float *d_hidden = (float *)calloc((size_t)T * aligned_D, sizeof(float));\n"
2888 " float *d_weights = (float *)calloc((size_t)V * aligned_D, sizeof(float));\n"
2889 " if (!d_logits || !d_hidden || !d_weights) {\n"
2890 " fprintf(stderr, \"failed to alloc grads\\n\");\n"
2892 " free(d_logits);\n"
2893 " free(d_hidden);\n"
2894 " free(d_weights);\n"
2897 " lm_head_forward(hidden, weights, logits, T, V, D, aligned_D);\n"
2898 " float loss = 0.0f;\n"
2899 " softmax_cross_entropy(logits, targets, T, V, d_logits, &loss);\n"
2900 " lm_head_backward(hidden, weights, d_logits, d_hidden, d_weights, T, V, D, aligned_D);\n"
2901 " if (out_logits) write_floats(out_logits, logits, (size_t)T * V);\n"
2902 " if (out_dlogits) write_floats(out_dlogits, d_logits, (size_t)T * V);\n"
2903 " if (out_dhidden) write_floats(out_dhidden, d_hidden, (size_t)T * aligned_D);\n"
2904 " if (out_dweights) write_floats(out_dweights, d_weights, (size_t)V * aligned_D);\n"
2905 " if (out_loss) write_float_scalar(out_loss, loss);\n"
2906 " if (!out_loss) printf(\"loss=%%.6f\\n\", loss);\n"
2908 " free(d_logits);\n"
2909 " free(d_hidden);\n"
2910 " free(d_weights);\n"
2911 " ck_huge_free(m.memory_base, m.total_bytes);\n"
2912 " free(m.layers);\n"
2915 " // TODO: load weights into m.memory_base using the offsets above.\n"
2916 " // TODO: write token/pos embeddings into embedded_input_offset.\n"
2917 " if (!run_backward) {\n"
2918 " if (!no_forward) {\n"
2919 " run_model_forward(&m);\n"
2922 " if (!tokens || !targets) {\n"
2923 " fprintf(stderr, \"backward requires --tokens and --targets\\n\");\n"
2926 " if (steps < 1) steps = 1;\n"
2927 " float loss = 0.0f;\n"
2928 " for (int step = 0; step < steps; ++step) {\n"
2929 " embed_tokens(&m, tokens, m.context_window);\n"
2930 " run_model_forward(&m);\n"
2931 " if (run_model_backward(&m, tokens, targets, &loss) != 0) {\n"
2932 " fprintf(stderr, \"backward failed\\n\");\n"
2935 " if (log_steps) {\n"
2936 " printf(\"step %%d loss=%%.6f\\n\", step, loss);\n"
2939 " if (out_loss) {\n"
2940 " write_float_scalar(out_loss, loss);\n"
2943 " if (out_logits) {\n"
2944 " write_floats(out_logits, ptr_f32(m.memory_base, m.logits_offset),\n"
2945 " (size_t)m.context_window * (size_t)m.vocab_size);\n"
2947 " if (out_weights) {\n"
2948 " if (save_model_weights(out_weights, &m) != 0) {\n"
2949 " fprintf(stderr, \"failed to save model weights\\n\");\n"
2953 " ck_huge_free(m.memory_base, m.total_bytes);\n"
2954 " free(m.layers);\n"
static void emit_shape_expr(FILE *out, const CKDimToken *shape)
static int ck_buffer_uses_weight_dtype(const CKBufferSpec *spec)
static int emit_runtime_preamble(FILE *out)
int ck_codegen_emit_runtime(const CKIRGraph *forward, const char *path, CKEmitMode mode)
static const CKKernelSpec * ck_find_kernel_spec(const char *name)
static const CKBufferSpec * ck_find_buffer_spec(const char *name)
static void emit_global_offset_fields(FILE *out)
static const char * ck_weight_dtype_expr(const CKBufferSpec *spec)
static const char * op_name(CKOpType op)
static void emit_offset_field(FILE *out, const char *name)
static int emit_kernel_manifest(const CKIRGraph *forward, const char *runtime_path)
static const char * ck_first_layer_buffer_name(void)
static int ck_plan_step_enabled(const CKPlanStep *step, const CKIRGraph *cfg)
static void emit_library_api(FILE *out, const CKIRGraph *forward)
static void emit_layer_allocations(FILE *out)
static void emit_sgd_update(FILE *out)
static void emit_dim_expr(FILE *out, CKDimKind dim)
static int emit_plan_sources(FILE *f, const CKPlanStep *plan, size_t plan_count, const CKIRGraph *cfg, const char **seen, size_t *seen_count, size_t seen_cap)
static void emit_model_struct(FILE *out)
static void emit_global_allocations(FILE *out)
static int ck_buffer_should_alloc(const CKBufferSpec *spec)
static void emit_zero_grad(FILE *out)
static int emit_unique_source(FILE *f, const char *path, const char **seen, size_t *seen_count, size_t seen_cap)
void ck_codegen_c_skeleton(const CKIRGraph *forward, const CKIRGraph *backward, FILE *out)
static void emit_bump_bytes_assignment(FILE *out, const char *indent, const char *struct_prefix, const char *name, const CKDimToken *shape)
static void emit_layer_offsets_struct(FILE *out)
static void emit_training_conditional_assignment(FILE *out, const char *indent, const char *struct_prefix, const char *name, const CKDimToken *shape)
static void emit_bump_bytes_assignment_weight_dtype(FILE *out, const char *indent, const char *struct_prefix, const char *name, const CKDimToken *shape, const char *dtype_expr)
static void emit_global_aliases_to_layer(FILE *out)
#define CKERNEL_MAX_KERNEL_SOURCES
const CKPlanStep ck_decoder_forward_plan[]
@ CK_DIM_ALIGNED_INTERMEDIATE
const size_t ck_decoder_backward_plan_count
const size_t ck_decoder_forward_plan_count
const CKPlanStep ck_decoder_backward_plan[]
const CKKernelSpec ck_kernel_specs[]
const CKBufferSpec ck_decoder_buffers[]
const size_t ck_kernel_spec_count
const size_t ck_decoder_buffer_count
int ck_ir_validate_supported(const CKIRGraph *graph)