    default:
        return "UNKNOWN";
    return (strcmp(spec->name, "token_emb") == 0) ||
           (strcmp(spec->name, "wq") == 0) ||
           (strcmp(spec->name, "wk") == 0) ||
           (strcmp(spec->name, "wv") == 0) ||
           (strcmp(spec->name, "wo") == 0) ||
           (strcmp(spec->name, "w1") == 0) ||
           (strcmp(spec->name, "w2") == 0) ||
           (strcmp(spec->name, "lm_head_weight") == 0);
    if (strcmp(spec->name, "token_emb") == 0) {
        return "m->token_emb_dtype";
    }

    if (strcmp(spec->name, "lm_head_weight") == 0) {
        return "m->lm_head_weight_dtype";
    }
    if (strcmp(spec->name, "wq") == 0) return "L->wq_dtype";
    if (strcmp(spec->name, "wk") == 0) return "L->wk_dtype";
    if (strcmp(spec->name, "wv") == 0) return "L->wv_dtype";
    if (strcmp(spec->name, "wo") == 0) return "L->wo_dtype";
    if (strcmp(spec->name, "w1") == 0) return "L->w1_dtype";
    if (strcmp(spec->name, "w2") == 0) return "L->w2_dtype";
    fprintf(out, "    size_t %s_offset;\n", name);

    fprintf(out, "typedef struct {\n");
125 " CKDataType wq_dtype;\n"
126 " CKDataType wk_dtype;\n"
127 " CKDataType wv_dtype;\n"
128 " CKDataType wo_dtype;\n"
129 " CKDataType w1_dtype;\n"
130 " CKDataType w2_dtype;\n");
131 fprintf(out,
"} LayerOffsets;\n\n");
151 "typedef LayerOffsets TrulyOptimalLayer;\n\n"
154 " uint32_t version;\n"
155 " uint32_t model_type;\n"
160 " int context_window;\n"
161 " int intermediate_size;\n"
163 " size_t aligned_embed_dim;\n"
164 " size_t aligned_head_dim;\n"
165 " size_t aligned_attn_context_window;\n"
168 " int tokens_per_core;\n"
169 " int num_attention_heads;\n"
170 " int num_kv_heads;\n"
172 " float rms_norm_eps;\n"
173 " float rope_theta;\n"
175 " uint8_t *memory_base;\n"
176 " size_t total_bytes;\n"
177 " size_t elem_bytes;\n"
178 " CKDataType weight_dtype;\n"
179 " CKDataType token_emb_dtype;\n"
180 " CKDataType pos_emb_dtype;\n"
181 " CKDataType lm_head_weight_dtype;\n"
182 " bool weights_mixed;\n"
183 " bool weights_quantized;\n"
184 " size_t layer_stride;\n"
186 " size_t layers_start_offset;\n");
192 " TrulyOptimalLayer *layers;\n"
194 " GradientStorage gradients;\n"
195 " bool training_enabled;\n"
196 " float learning_rate;\n"
197 " int lr_warmup_steps;\n"
198 " float lr_warmup_init;\n"
199 " float grad_clip;\n"
200 " size_t training_cache_samples;\n"
201 " int active_tokens;\n"
202 " TaskType task_type;\n"
203 " OptimizerType optimizer;\n"
204 " uint64_t optimizer_step;\n"
205 " float adam_beta1;\n"
206 " float adam_beta2;\n"
208 " float weight_decay;\n"
209 " bool ema_enabled;\n"
210 " float ema_decay;\n"
211 " bool optimizer_state_initialized;\n"
213 " bool seq_cls_enabled;\n"
214 " int seq_cls_num_classes;\n"
215 " int seq_cls_pooling;\n"
216 " size_t seq_cls_weight_offset;\n"
217 " size_t seq_cls_bias_offset;\n"
219 " bool kv_cache_enabled;\n"
220 " int kv_cache_capacity;\n"
221 " int kv_cache_tokens;\n"
223 " long *training_data_buffer;\n"
224 " long num_training_tokens;\n"
226 " uint8_t checksum[32];\n"
227 " uint8_t reserved[32];\n"
228 "} TransformerModel;\n\n");
    case CK_DIM_TOKENS:    fprintf(out, "(size_t)m->context_window");      break;
    case CK_DIM_EMBED:     fprintf(out, "(size_t)m->embed_dim");           break;

    case CK_DIM_NUM_HEADS: fprintf(out, "(size_t)m->num_attention_heads"); break;

    case CK_DIM_VOCAB:     fprintf(out, "(size_t)m->vocab_size");          break;
    for (int i = 0; i < 4; ++i) {

        if (shape[i].mult != 1) {
            fprintf(out, " * %d", shape[i].mult);
        }
        if (shape[i].div != 1) {
            fprintf(out, " / %d", shape[i].div);
        }
                        const char *struct_prefix,

    fprintf(out, "%s%s%s_offset = bump_bytes(&off, (", indent, struct_prefix, name);

    fprintf(out, ") * elem_bytes, CACHELINE_BYTES);\n");

                        const char *struct_prefix,

                        const char *dtype_expr)

    fprintf(out, "%s%s%s_offset = bump_bytes(&off, ck_dtype_row_bytes(%s, (",
            indent, struct_prefix, name, dtype_expr);

    fprintf(out, ")), CACHELINE_BYTES);\n");

                        const char *struct_prefix,

    fprintf(out, "%s%s%s_offset = m->training_enabled ? bump_bytes(&off, (",
            indent, struct_prefix, name);

    fprintf(out, ") * elem_bytes, CACHELINE_BYTES) : 0;\n");
    fprintf(out, "    m->%s_offset = m->%s_offset;\n", spec->name, spec->alias_of);
    fprintf(out, "    if (m->rope_theta > 0.0f) {\n");
    fprintf(out, "        m->%s_offset = bump_bytes(&off, (", spec->name);

    fprintf(out, ") * elem_bytes, CACHELINE_BYTES);\n");
    fprintf(out, "    } else {\n");
    fprintf(out, "        m->%s_offset = 0;\n", spec->name);
    fprintf(out, "    }\n");
    fprintf(out, "    if (m->training_enabled) {\n");
    fprintf(out, "        m->%s_offset = bump_bytes(&off, (", spec->name);

    fprintf(out, ") * elem_bytes, CACHELINE_BYTES);\n");
    fprintf(out, "    } else {\n");
    fprintf(out, "        m->%s_offset = 0;\n", spec->name);
    fprintf(out, "    }\n");
    fprintf(out, "        if (m->training_enabled) {\n");
    fprintf(out, "            L->%s_offset = bump_bytes(&off, (", spec->name);

    fprintf(out, ") * elem_bytes, CACHELINE_BYTES);\n");
    fprintf(out, "        } else {\n");
    fprintf(out, "            L->%s_offset = 0;\n", spec->name);
    fprintf(out, "        }\n");
410 " if (m->num_layers > 0) {\n"
411 " m->%s_offset = m->layers[m->num_layers - 1].%s_offset;\n"
413 " m->%s_offset = 0;\n"
422 "static void zero_grad(TransformerModel *m)\n"
424 " if (!m || !m->training_enabled) return;\n"
425 " uint8_t *base = m->memory_base;\n"
426 " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n");
    fprintf(out, "    if (m->%s_offset) {\n", spec->name);
    fprintf(out, "        memset(base + m->%s_offset, 0, (", spec->name);

    fprintf(out, ") * m->elem_bytes);\n");
    fprintf(out, "    }\n");

        "    for (int layer = 0; layer < m->num_layers; ++layer) {\n"
        "        TrulyOptimalLayer *L = &m->layers[layer];\n");

    fprintf(out, "        if (L->%s_offset) {\n", spec->name);
    fprintf(out, "            memset(base + L->%s_offset, 0, (", spec->name);

    fprintf(out, ") * m->elem_bytes);\n");
    fprintf(out, "        }\n");
462 "static void sgd_update(TransformerModel *m, float lr)\n"
464 " if (!m || !m->training_enabled || lr == 0.0f) return;\n"
465 " uint8_t *base = m->memory_base;\n"
466 " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n");
477 snprintf(grad_name,
sizeof(grad_name),
"d_%s", spec->
name);
483 " if (m->%s_offset && m->%s_offset) {\n"
484 " float *w = ptr_f32(base, m->%s_offset);\n"
485 " float *g = ptr_f32(base, m->%s_offset);\n"
487 spec->
name, grad_name, spec->
name, grad_name);
491 " for (size_t i = 0; i < count; ++i) {\n"
492 " w[i] -= lr * g[i];\n"
498 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
499 " TrulyOptimalLayer *L = &m->layers[layer];\n");
506 snprintf(grad_name,
sizeof(grad_name),
"d_%s", spec->
name);
512 " if (L->%s_offset && L->%s_offset) {\n"
513 " float *w = ptr_f32(base, L->%s_offset);\n"
514 " float *g = ptr_f32(base, L->%s_offset);\n"
516 spec->
name, grad_name, spec->
name, grad_name);
520 " for (size_t i = 0; i < count; ++i) {\n"
521 " w[i] -= lr * g[i];\n"
    if (!path || !path[0]) {

    for (size_t i = 0; i < *seen_count; ++i) {
        if (strcmp(seen[i], path) == 0) {

    if (*seen_count >= seen_cap) {

    seen[*seen_count] = path;
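/* The dedupe above, restated as a self-contained helper (name and return
 * codes are hypothetical): a path is recorded in `seen` only on its first
 * emission, so each kernel source lands in the manifest exactly once. */
#include <stddef.h>
#include <string.h>

static int demo_emit_unique(const char *path, const char **seen,
                            size_t *seen_count, size_t seen_cap)
{
    if (!path || !path[0]) return 0;
    for (size_t i = 0; i < *seen_count; ++i) {
        if (strcmp(seen[i], path) == 0) return 0; /* already emitted */
    }
    if (*seen_count >= seen_cap) return -1;       /* table full */
    seen[(*seen_count)++] = path;
    return 1;                                     /* first sighting */
}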
    if (!forward || !out) {

        "/* Auto-generated skeleton from CKIRGraph.\n"
        " * This file sketches the structure of the forward and backward\n"
        " * execution for a decoder-only transformer. It is NOT yet a\n"
        " * complete, runnable implementation. You can use it as a\n"
        " * starting point to wire buffers, kernel calls, and memory layout.\n"

    fprintf(out, "#include \"ckernel_engine.h\"\n");
    fprintf(out, "#include \"ckernel_model.h\"\n");
    fprintf(out, "#include \"ckernel_alloc.h\"\n\n");

        "void run_decoder_forward(TransformerModel *model /*, inputs, etc. */)\n"
        "{\n"
        "    for (int layer = 0; layer < model->cfg.num_layers; ++layer) {\n"
        "        /* Forward pass for layer */\n");
    int nodes_per_layer = 0;

    for (int i = 0; i < forward->num_nodes; ++i) {

    if (nodes_per_layer <= 0) {

    fprintf(out, "        /* This layer has %d IR nodes */\n", nodes_per_layer);

    for (int i = 0; i < nodes_per_layer; ++i) {

        fprintf(out, "        // L%%d: %s\n", op_name(n->op));

            if (o > 0) fprintf(out, ", ");
            fprintf(out, "L%%d:N%d:%d", n->id.node, o);

        fprintf(out, "        // inputs : [");
        for (int j = 0; j < n->n_inputs; ++j) {
            if (j > 0) fprintf(out, ", ");

            fprintf(out, "L%%d:N%u:%u",

        "        // TODO: bind buffers/weights and call %s kernel here\n\n",

        "    } /* end for layer */\n"
692 "void run_decoder_backward(TransformerModel *model /*, grads, etc. */)\n"
694 " for (int layer = model->cfg.num_layers - 1; layer >= 0; --layer) {\n"
695 " /* Backward pass for layer */\n");
697 int bwd_per_layer = 0;
699 for (
int i = 0; i < backward->
num_nodes; ++i) {
703 if (bwd_per_layer <= 0) bwd_per_layer = backward->
num_nodes;
705 fprintf(out,
" /* This layer has %d backward IR nodes */\n", bwd_per_layer);
707 for (
int i = 0; i < bwd_per_layer; ++i) {
709 fprintf(out,
" // L%%d: %s\n",
op_name(n->
op));
711 " // TODO: wire gradient tensors and call %s kernel here\n\n",
716 " } /* end for layer */\n"
721 "int main(int argc, char **argv)\n"
723 " (void)argc; (void)argv;\n"
724 " TransformerModel model = {0};\n"
725 " model.cfg.num_layers = %d;\n"
726 " model.cfg.hidden_size = %d;\n"
727 " model.cfg.intermediate_size = %d;\n"
728 " model.cfg.num_heads = %d;\n"
729 " model.cfg.num_kv_heads = %d;\n"
730 " model.cfg.vocab_size = %d;\n"
731 " model.cfg.context_window = %d;\n"
732 " model.cfg.rms_norm_eps = %.9g;\n"
733 " model.cfg.rope_theta = %.9g;\n"
734 " layout_transformer_from_ir(&model, NULL); /* TODO: pass IR if needed */\n"
735 " size_t bytes = model.total_bytes;\n"
736 " model.memory_base = (uint8_t *)ck_huge_alloc(bytes);\n"
737 " if (!model.memory_base) {\n"
738 " fprintf(stderr, \"Failed to allocate %%zu bytes for model\\n\", bytes);\n"
741 " // TODO: load weights into model.memory_base based on offsets\n"
742 " run_decoder_forward(&model);\n"
743 " // TODO: run_decoder_backward(&model) when training\n"
744 " ck_huge_free(model.memory_base, bytes);\n"
761 "/* Auto-generated runtime from CKIRGraph.\n"
762 " * This file wires the existing C-Kernel-Engine kernels into a\n"
763 " * decoder-only transformer forward pass.\n"
765 " * Compile (scalar): gcc -O2 generated_model.c $(cat generated_model.c.kernels) -Iinclude -lm -o generated_model\n"
766 " * Compile (AVX-512): gcc -O3 -mavx512f -mfma generated_model.c $(cat generated_model.c.kernels) -Iinclude -lm -o generated_model\n"
770 "#define _GNU_SOURCE\n"
771 "#include <stddef.h>\n"
772 "#include <stdint.h>\n"
773 "#include <stdbool.h>\n"
774 "#include <stdio.h>\n"
775 "#include <stdlib.h>\n"
776 "#include <string.h>\n"
777 "#include <math.h>\n"
778 "#include <errno.h>\n"
779 "#include <sys/types.h>\n"
780 "#include <unistd.h>\n"
781 "#include \"ckernel_engine.h\"\n"
782 "#include \"ckernel_dtype.h\"\n"
783 "#include \"ckernel_orchestration.h\"\n"
784 "#include \"ckernel_alloc.h\"\n\n");
787 "#define CACHELINE_BYTES 64\n"
788 "static size_t align_up_bytes(size_t n, size_t align) {\n"
789 " if (align == 0) return n;\n"
790 " return (n + align - 1) & ~(align - 1);\n"
792 "static size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align) {\n"
793 " size_t bytes = elems * elem_bytes;\n"
794 " bytes = align_up_bytes(bytes, align);\n"
795 " return bytes / elem_bytes;\n"
797 "static size_t bump_bytes(size_t *off, size_t bytes, size_t align) {\n"
798 " size_t start = align_up_bytes(*off, align);\n"
799 " *off = start + bytes;\n"
802 "static inline float *ptr_f32(uint8_t *base, size_t offset) {\n"
803 " return (float *)(base + offset);\n"
805 "static inline const float *cptr_f32(const uint8_t *base, size_t offset) {\n"
806 " return (const float *)(base + offset);\n"
810 "static inline uint8_t *ptr_u8(uint8_t *base, size_t offset) {\n"
811 " return base + offset;\n"
813 "static inline const void *cptr_void(const uint8_t *base, size_t offset) {\n"
814 " return (const void *)(base + offset);\n"
    if (!forward || !runtime_path) {

    const char *suffix = ".kernels";
    size_t len = strlen(runtime_path) + strlen(suffix) + 1;
    char *path = (char *)malloc(len);

    snprintf(path, len, "%s%s", runtime_path, suffix);

    FILE *f = fopen(path, "wb");

        fprintf(stderr, "ck_codegen_emit_runtime: failed to open %s: %s\n",
                path, strerror(errno));

    const char **seen = (const char **)calloc(seen_cap, sizeof(char *));

    size_t seen_count = 0;
    emit_unique_source(f, "src/kernels/embedding_kernels.c", seen, &seen_count, seen_cap);

    emit_unique_source(f, "src/kernels/gemm_kernels_q4k_q8k.c", seen, &seen_count, seen_cap);

    emit_unique_source(f, "src/kernels/gemm_kernels_q4_0.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/gemm_kernels_q4_1.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/gemm_kernels_q5_0.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/gemm_kernels_q5_1.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/gemm_kernels_q8_0.c", seen, &seen_count, seen_cap);

    emit_unique_source(f, "src/kernels/gemm_kernels_q4k_sse.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/gemm_kernels_q4k_q8k_avx2.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/gemm_kernels_q4k_q8k_vnni.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/gemm_kernels_q5_0_sse.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/gemm_kernels_q5_0_sse_v2.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/gemm_kernels_q6k_sse.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/quantize_row_q8_k_sse.c", seen, &seen_count, seen_cap);

    emit_unique_source(f, "src/kernels/gemm_fused_kernels.c", seen, &seen_count, seen_cap);

    emit_unique_source(f, "src/kernels/attention_decode_fused.c", seen, &seen_count, seen_cap);
    emit_unique_source(f, "src/kernels/attention_flash_true.c", seen, &seen_count, seen_cap);

    fprintf(stderr, "[ck_codegen] kernels manifest written to %s\n", path);
919 "\n/* ═══════════════════════════════════════════════════════════════\n"
920 " * C-Kernel-Engine Library API (for dlopen)\n"
921 " * ═══════════════════════════════════════════════════════════════ */\n\n"
923 "#define CK_EXPORT __declspec(dllexport)\n"
925 "#define CK_EXPORT __attribute__((visibility(\"default\")))\n"
929 " int hidden_size;\n"
930 " int intermediate_size;\n"
932 " int num_kv_heads;\n"
934 " int context_window;\n"
935 " float rms_norm_eps;\n"
936 " float rope_theta;\n"
938 "static TransformerModel g_model = {0};\n"
939 "static int g_initialized = 0;\n"
940 "static int g_fuse_swiglu_decode = -2;\n"
941 "static int g_fuse_attn_decode = -2;\n\n"
942 "static int ck_fuse_swiglu_decode_mode(void)\n"
944 " if (g_fuse_swiglu_decode != -2) return g_fuse_swiglu_decode;\n"
945 " const char *env = getenv(\"CK_FUSE_SWIGLU_DECODE\");\n"
946 " if (!env || !env[0]) {\n"
947 " g_fuse_swiglu_decode = -1; /* auto */\n"
948 " return g_fuse_swiglu_decode;\n"
950 " if (env[0] == '0' || env[0] == 'n' || env[0] == 'N' || env[0] == 'f' || env[0] == 'F') {\n"
951 " g_fuse_swiglu_decode = 0;\n"
953 " g_fuse_swiglu_decode = 1;\n"
955 " return g_fuse_swiglu_decode;\n"
959 "static int ck_fuse_attn_decode_mode(void)\n"
961 " if (g_fuse_attn_decode != -2) return g_fuse_attn_decode;\n"
962 " const char *env = getenv(\"CK_FUSE_ATTN_DECODE\");\n"
963 " if (!env || !env[0]) {\n"
964 " g_fuse_attn_decode = -1; /* auto */\n"
965 " return g_fuse_attn_decode;\n"
967 " if (env[0] == '0' || env[0] == 'n' || env[0] == 'N' || env[0] == 'f' || env[0] == 'F') {\n"
968 " g_fuse_attn_decode = 0;\n"
970 " g_fuse_attn_decode = 1;\n"
972 " return g_fuse_attn_decode;\n"
976 "static int run_model_decode(TransformerModel *m, int32_t token)\n"
978 " if (!m || !m->memory_base) return -1;\n"
979 " /* KV-cache decode is an inference-only fast path; training uses the full forward/backward graph. */\n"
980 " if (m->training_enabled) return -4;\n"
981 " if (!m->kv_cache_enabled) return -2;\n"
983 " int cache_cap = m->kv_cache_capacity > 0 ? m->kv_cache_capacity : m->context_window;\n"
984 " if (cache_cap > m->context_window) cache_cap = m->context_window;\n"
985 " int t = m->kv_cache_tokens;\n"
986 " if (t < 0) t = 0;\n"
987 " if (t >= cache_cap) return -3;\n"
989 " embed_token_at(m, token, t);\n"
991 " uint8_t *base = m->memory_base;\n"
992 " float *current = ptr_f32(base, m->embedded_input_offset);\n"
993 " int aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
994 " int fuse_swiglu_mode = ck_fuse_swiglu_decode_mode();\n"
995 " int use_fused_swiglu = 0;\n"
996 " if (fuse_swiglu_mode > 0) {\n"
997 " use_fused_swiglu = 1;\n"
998 " } else if (fuse_swiglu_mode == 0) {\n"
999 " use_fused_swiglu = 0;\n"
1000 " } else if (!ck_strict_parity_enabled()) {\n"
1001 " use_fused_swiglu = 1;\n"
1003 " int fuse_attn_mode = ck_fuse_attn_decode_mode();\n"
1004 " int use_fused_attn = 0;\n"
1005 " if (fuse_attn_mode > 0) {\n"
1006 " use_fused_attn = 1;\n"
1007 " } else if (fuse_attn_mode == 0) {\n"
1008 " use_fused_attn = 0;\n"
1009 " } else if (!ck_strict_parity_enabled()) {\n"
1010 " use_fused_attn = 1;\n"
1012 " if (m->weights_quantized) {\n"
1013 " use_fused_swiglu = 0;\n"
1014 " use_fused_attn = 0;\n"
1017 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1018 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1019 " if (!m->weights_mixed && m->weight_dtype == CK_DT_Q4_K) {\n"
1020 " CKLayerForwardParamsQ4K p = {0};\n"
1021 " p.tokens = cache_cap;\n"
1022 " p.embed_dim = m->embed_dim;\n"
1023 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1024 " p.num_heads = m->num_attention_heads;\n"
1025 " p.num_kv_heads = m->num_kv_heads;\n"
1026 " p.head_dim = m->head_dim;\n"
1027 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1028 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1029 " p.intermediate_dim = m->intermediate_size;\n"
1030 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1031 " p.eps = m->rms_norm_eps;\n"
1032 " p.rope_pos_offset = t;\n"
1033 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1034 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1035 " p.input = current;\n"
1036 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1037 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1038 " p.wq = cptr_void(base, L->wq_offset);\n"
1039 " p.bq = cptr_f32(base, L->bq_offset);\n"
1040 " p.wk = cptr_void(base, L->wk_offset);\n"
1041 " p.bk = cptr_f32(base, L->bk_offset);\n"
1042 " p.wv = cptr_void(base, L->wv_offset);\n"
1043 " p.bv = cptr_f32(base, L->bv_offset);\n"
1044 " p.wo = cptr_void(base, L->wo_offset);\n"
1045 " p.bo = cptr_f32(base, L->bo_offset);\n"
1046 " p.w1 = cptr_void(base, L->w1_offset);\n"
1047 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1048 " p.w2 = cptr_void(base, L->w2_offset);\n"
1049 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1050 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1051 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1052 " p.k = ptr_f32(base, L->k_offset);\n"
1053 " p.v = ptr_f32(base, L->v_offset);\n"
1054 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1055 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1056 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1057 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1058 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1059 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1060 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1061 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1062 " p.output = ptr_f32(base, L->output_offset);\n"
1064 " ck_layer_forward_rmsnorm_swiglu_decode_q4_k(&p, t, cache_cap);\n"
1065 " } else if (m->weights_quantized) {\n"
1066 " CKLayerForwardParamsQ4K p = {0};\n"
1067 " p.tokens = cache_cap;\n"
1068 " p.embed_dim = m->embed_dim;\n"
1069 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1070 " p.num_heads = m->num_attention_heads;\n"
1071 " p.num_kv_heads = m->num_kv_heads;\n"
1072 " p.head_dim = m->head_dim;\n"
1073 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1074 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1075 " p.intermediate_dim = m->intermediate_size;\n"
1076 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1077 " p.eps = m->rms_norm_eps;\n"
1078 " p.rope_pos_offset = t;\n"
1079 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1080 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1081 " p.input = current;\n"
1082 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1083 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1084 " p.wq = cptr_void(base, L->wq_offset);\n"
1085 " p.bq = cptr_f32(base, L->bq_offset);\n"
1086 " p.wk = cptr_void(base, L->wk_offset);\n"
1087 " p.bk = cptr_f32(base, L->bk_offset);\n"
1088 " p.wv = cptr_void(base, L->wv_offset);\n"
1089 " p.bv = cptr_f32(base, L->bv_offset);\n"
1090 " p.wo = cptr_void(base, L->wo_offset);\n"
1091 " p.bo = cptr_f32(base, L->bo_offset);\n"
1092 " p.w1 = cptr_void(base, L->w1_offset);\n"
1093 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1094 " p.w2 = cptr_void(base, L->w2_offset);\n"
1095 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1096 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1097 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1098 " p.k = ptr_f32(base, L->k_offset);\n"
1099 " p.v = ptr_f32(base, L->v_offset);\n"
1100 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1101 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1102 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1103 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1104 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1105 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1106 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1107 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1108 " p.output = ptr_f32(base, L->output_offset);\n"
1109 " p.wq_dtype = L->wq_dtype;\n"
1110 " p.wk_dtype = L->wk_dtype;\n"
1111 " p.wv_dtype = L->wv_dtype;\n"
1112 " p.wo_dtype = L->wo_dtype;\n"
1113 " p.w1_dtype = L->w1_dtype;\n"
1114 " p.w2_dtype = L->w2_dtype;\n"
1116 " ck_layer_forward_rmsnorm_swiglu_decode_quant(&p, t, cache_cap);\n"
1118 " CKLayerForwardParams p = {0};\n"
1119 " p.tokens = cache_cap;\n"
1120 " p.embed_dim = m->embed_dim;\n"
1121 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1122 " p.num_heads = m->num_attention_heads;\n"
1123 " p.num_kv_heads = m->num_kv_heads;\n"
1124 " p.head_dim = m->head_dim;\n"
1125 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1126 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1127 " p.intermediate_dim = m->intermediate_size;\n"
1128 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1129 " p.eps = m->rms_norm_eps;\n"
1130 " p.rope_pos_offset = t;\n"
1131 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1132 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1133 " p.input = current;\n"
1134 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1135 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1136 " p.wq = cptr_f32(base, L->wq_offset);\n"
1137 " p.bq = cptr_f32(base, L->bq_offset);\n"
1138 " p.wk = cptr_f32(base, L->wk_offset);\n"
1139 " p.bk = cptr_f32(base, L->bk_offset);\n"
1140 " p.wv = cptr_f32(base, L->wv_offset);\n"
1141 " p.bv = cptr_f32(base, L->bv_offset);\n"
1142 " p.wo = cptr_f32(base, L->wo_offset);\n"
1143 " p.bo = cptr_f32(base, L->bo_offset);\n"
1144 " p.w1 = cptr_f32(base, L->w1_offset);\n"
1145 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1146 " p.w2 = cptr_f32(base, L->w2_offset);\n"
1147 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1148 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1149 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1150 " p.k = ptr_f32(base, L->k_offset);\n"
1151 " p.v = ptr_f32(base, L->v_offset);\n"
1152 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1153 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1154 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1155 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1156 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1157 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1158 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1159 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1160 " p.output = ptr_f32(base, L->output_offset);\n"
1162 " if (use_fused_attn) {\n"
1163 " if (use_fused_swiglu) {\n"
1164 " ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp(&p, t, cache_cap);\n"
1166 " ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(&p, t, cache_cap);\n"
1168 " } else if (use_fused_swiglu) {\n"
1169 " ck_layer_forward_rmsnorm_swiglu_decode_fused(&p, t, cache_cap);\n"
1171 " ck_layer_forward_rmsnorm_swiglu_decode(&p, t, cache_cap);\n"
1174 " current = ptr_f32(base, L->output_offset);\n"
1177 " int V = m->vocab_size;\n"
1178 " int D = m->embed_dim;\n"
1179 " int aligned_D = (int)m->aligned_embed_dim;\n"
1180 " float *final_in = current + (size_t)t * aligned_D;\n"
1181 " float *final_out = ptr_f32(base, m->final_output_offset) + (size_t)t * aligned_D;\n"
1182 " float *final_rstd = ptr_f32(base, m->final_ln_rstd_offset) + (size_t)t;\n"
1184 " rmsnorm_forward(final_in,\n"
1185 " cptr_f32(base, m->final_ln_weight_offset),\n"
1191 " m->rms_norm_eps);\n"
1193 " float *logits_row = ptr_f32(base, m->logits_offset) + (size_t)t * (size_t)V;\n"
1194 " if (m->lm_head_weight_dtype == CK_DT_Q4_K) {\n"
1195 " gemm_nt_q4_k(final_out,\n"
1196 " cptr_void(base, m->lm_head_weight_offset),\n"
1202 " } else if (m->lm_head_weight_dtype == CK_DT_Q6_K) {\n"
1203 " gemm_nt_q6_k(final_out,\n"
1204 " cptr_void(base, m->lm_head_weight_offset),\n"
1211 " lm_head_forward(final_out,\n"
1212 " cptr_f32(base, m->lm_head_weight_offset),\n"
1221 " m->kv_cache_tokens = t + 1;\n"
1222 " m->active_tokens = m->kv_cache_tokens;\n"
1228 "CK_EXPORT int ck_model_init(const char *weights_path)\n"
1230 " if (g_initialized) return 0;\n"
1231 " memcpy(g_model.magic, \"BUMPWGT3\", 8);\n"
1232 " g_model.version = 3;\n"
1233 " g_model.model_type = 0;\n"
1234 " g_model.num_layers = %d;\n"
1235 " g_model.embed_dim = %d;\n"
1236 " g_model.intermediate_size = %d;\n"
1237 " g_model.num_attention_heads = %d;\n"
1238 " g_model.num_kv_heads = %d;\n"
1239 " g_model.vocab_size = %d;\n"
1240 " g_model.context_window = %d;\n"
1241 " g_model.rms_norm_eps = (float)%.9g;\n"
1242 " g_model.rope_theta = (float)%.9g;\n"
1243 " g_model.num_cores = 1;\n"
1244 " g_model.task_type = TASK_LM;\n"
1245 " g_model.weight_dtype = CK_DT_FP32;\n"
1246 " const char *wd = getenv(\"CK_WEIGHT_DTYPE\");\n"
1248 " if (strcmp(wd, \"q4_k\") == 0 || strcmp(wd, \"q4_k_m\") == 0 ||\n"
1249 " strcmp(wd, \"Q4_K\") == 0 || strcmp(wd, \"Q4_K_M\") == 0) {\n"
1250 " g_model.weight_dtype = CK_DT_Q4_K;\n"
1251 " } else if (strcmp(wd, \"q6_k\") == 0 || strcmp(wd, \"q6_k_l\") == 0 ||\n"
1252 " strcmp(wd, \"Q6_K\") == 0 || strcmp(wd, \"Q6_K_L\") == 0) {\n"
1253 " g_model.weight_dtype = CK_DT_Q6_K;\n"
1256 " init_weight_dtypes_uniform(&g_model, g_model.weight_dtype);\n"
1257 " refresh_weight_flags(&g_model);\n"
1258 " /* Check env var to pre-allocate gradient buffers for training */\n"
1259 " const char *train_env = getenv(\"CK_ENABLE_TRAINING\");\n"
1260 " if (train_env && (train_env[0] == '1' || train_env[0] == 'y' || train_env[0] == 'Y')) {\n"
1261 " g_model.training_enabled = true;\n"
1262 " g_model.learning_rate = 1e-4f;\n"
1264 " if (weights_path) {\n"
1265 " int dtype_rc = load_weight_dtypes(weights_path, &g_model);\n"
1266 " if (dtype_rc < 0) {\n"
1267 " fprintf(stderr, \"Failed to read weight dtype table from %%s\\n\", weights_path);\n"
1271 " if (g_model.training_enabled && g_model.weights_quantized) {\n"
1272 " fprintf(stderr, \"Quantized weights are inference-only; disable training for this model\\n\");\n"
1275 " g_model.kv_cache_enabled = false;\n"
1276 " g_model.kv_cache_capacity = g_model.context_window;\n"
1277 " g_model.kv_cache_tokens = 0;\n"
1278 " if (layout_model(&g_model) != 0) return -1;\n"
1279 " if (weights_path) {\n"
1280 " if (load_model_weights(weights_path, &g_model) != 0) return -2;\n"
1282 " g_initialized = 1;\n"
1297 "CK_EXPORT void ck_model_get_info(CKModelInfo *info)\n"
1299 " if (!info) return;\n"
1300 " info->num_layers = g_model.num_layers;\n"
1301 " info->hidden_size = g_model.embed_dim;\n"
1302 " info->intermediate_size = g_model.intermediate_size;\n"
1303 " info->num_heads = g_model.num_attention_heads;\n"
1304 " info->num_kv_heads = g_model.num_kv_heads;\n"
1305 " info->vocab_size = g_model.vocab_size;\n"
1306 " info->context_window = g_model.context_window;\n"
1307 " info->rms_norm_eps = g_model.rms_norm_eps;\n"
1308 " info->rope_theta = g_model.rope_theta;\n"
1313 "CK_EXPORT int ck_model_embed_tokens(const int32_t *tokens, int num_tokens)\n"
1315 " if (!g_initialized) return -1;\n"
1316 " int cap = g_model.context_window;\n"
1317 " if (g_model.kv_cache_enabled && g_model.kv_cache_capacity > 0 && g_model.kv_cache_capacity < cap) {\n"
1318 " cap = g_model.kv_cache_capacity;\n"
1320 " if (num_tokens > cap) num_tokens = cap;\n"
1321 " if (num_tokens < 1) num_tokens = 1;\n"
1322 " g_model.active_tokens = num_tokens;\n"
1323 " if (g_model.kv_cache_enabled && !g_model.training_enabled) {\n"
1324 " g_model.kv_cache_tokens = 0;\n"
1326 " embed_tokens(&g_model, tokens, num_tokens);\n"
1332 "CK_EXPORT int ck_model_forward(float *logits_out)\n"
1334 " if (!g_initialized) return -1;\n"
1335 " run_model_forward(&g_model);\n"
1336 " if (g_model.kv_cache_enabled && !g_model.training_enabled) {\n"
1337 " g_model.kv_cache_tokens = g_model.active_tokens;\n"
1339 " if (logits_out && g_model.vocab_size > 0) {\n"
1340 " size_t n = (size_t)g_model.active_tokens * (size_t)g_model.vocab_size;\n"
1341 " memcpy(logits_out, ptr_f32(g_model.memory_base, g_model.logits_offset), n * sizeof(float));\n"
1348 "CK_EXPORT int ck_model_kv_cache_enable(int capacity)\n"
1350 " if (!g_initialized) return -1;\n"
1351 " if (g_model.training_enabled) return -4;\n"
1352 " g_model.kv_cache_enabled = true;\n"
1353 " int cap = capacity;\n"
1354 " if (cap <= 0 || cap > g_model.context_window) cap = g_model.context_window;\n"
1355 " g_model.kv_cache_capacity = cap;\n"
1356 " g_model.kv_cache_tokens = 0;\n"
1357 " g_model.active_tokens = 0;\n"
1360 "CK_EXPORT void ck_model_kv_cache_reset(void)\n"
1362 " if (!g_initialized) return;\n"
1363 " g_model.kv_cache_tokens = 0;\n"
1364 " g_model.active_tokens = 0;\n"
1366 "CK_EXPORT int ck_model_kv_cache_get_tokens(void)\n"
1368 " return g_initialized ? g_model.kv_cache_tokens : 0;\n"
1370 "CK_EXPORT int ck_model_decode(int32_t token, float *logits_out)\n"
1372 " if (!g_initialized) return -1;\n"
1373 " if (g_model.training_enabled) return -4;\n"
1374 " int ret = run_model_decode(&g_model, token);\n"
1375 " if (ret != 0) return ret;\n"
1376 " if (logits_out && g_model.vocab_size > 0) {\n"
1377 " int t = g_model.active_tokens - 1;\n"
1378 " memcpy(logits_out,\n"
1379 " ptr_f32(g_model.memory_base, g_model.logits_offset) + (size_t)t * (size_t)g_model.vocab_size,\n"
1380 " (size_t)g_model.vocab_size * sizeof(float));\n"
1387 "CK_EXPORT float* ck_model_get_logits(void)\n"
1389 " if (!g_initialized) return NULL;\n"
1390 " return ptr_f32(g_model.memory_base, g_model.logits_offset);\n"
1395 "CK_EXPORT int ck_model_backward(const int32_t *tokens, const int32_t *targets, float *loss_out)\n"
1397 " if (!g_initialized) return -1;\n"
1398 " return run_model_backward(&g_model, tokens, targets, loss_out);\n"
1403 "CK_EXPORT void ck_model_free(void)\n"
1405 " if (!g_initialized) return;\n"
1406 " if (g_model.memory_base) ck_huge_free(g_model.memory_base, g_model.total_bytes);\n"
1407 " if (g_model.layers) free(g_model.layers);\n"
1408 " memset(&g_model, 0, sizeof(g_model));\n"
1409 " g_initialized = 0;\n"
1414 "CK_EXPORT int ck_model_get_context_window(void) { return g_initialized ? g_model.context_window : 0; }\n"
1415 "CK_EXPORT int ck_model_get_vocab_size(void) { return g_initialized ? g_model.vocab_size : 0; }\n"
1416 "CK_EXPORT int ck_model_get_hidden_size(void) { return g_initialized ? g_model.embed_dim : 0; }\n"
1417 "CK_EXPORT int ck_model_get_active_tokens(void) { return g_initialized ? g_model.active_tokens : 0; }\n"
1418 "CK_EXPORT int ck_model_is_training_enabled(void) { return g_initialized ? g_model.training_enabled : 0; }\n"
1419 "CK_EXPORT void ck_model_set_learning_rate(float lr) { if (g_initialized) g_model.learning_rate = lr; }\n"
1420 "CK_EXPORT float ck_model_get_learning_rate(void) { return g_initialized ? g_model.learning_rate : 0.0f; }\n\n"
1421 "CK_EXPORT int ck_model_enable_training(float learning_rate)\n"
1423 " if (!g_initialized) return -1;\n"
1424 " g_model.training_enabled = true;\n"
1425 " g_model.learning_rate = learning_rate;\n"
1428 "CK_EXPORT void ck_model_disable_training(void)\n"
1430 " if (g_initialized) g_model.training_enabled = false;\n"
1432 "CK_EXPORT void ck_model_optimizer_step(void)\n"
1434 " if (!g_initialized || !g_model.training_enabled) return;\n"
1435 " sgd_update(&g_model, g_model.learning_rate);\n"
    if (!forward || !path) {

    FILE *out = fopen(path, "wb");

        fprintf(stderr, "ck_codegen_emit_runtime: failed to open %s: %s\n",
                path, strerror(errno));

        "    TASK_SEQ_CLS = 1\n"
        "} TaskType;\n\n"
        "typedef enum {\n"
        "    OPTIMIZER_SGD = 0,\n"
        "    OPTIMIZER_ADAM = 1\n"
        "} OptimizerType;\n\n"
        "typedef struct {\n"
        "    size_t total_gradient_floats;\n"
        "} GradientStorage;\n\n");
1477 "static int ensure_layers_allocated(TransformerModel *m)\n"
1479 " if (!m) return -1;\n"
1480 " if (!m->layers && m->num_layers > 0) {\n"
1481 " m->layers = (TrulyOptimalLayer *)calloc((size_t)m->num_layers, sizeof(TrulyOptimalLayer));\n"
1482 " if (!m->layers) return -1;\n"
1486 "static void init_weight_dtypes_uniform(TransformerModel *m, CKDataType dt)\n"
1488 " if (!m) return;\n"
1489 " m->token_emb_dtype = dt;\n"
1490 " m->lm_head_weight_dtype = dt;\n"
1491 " m->pos_emb_dtype = CK_DT_FP32;\n"
1492 " if (ensure_layers_allocated(m) != 0) return;\n"
1493 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1494 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1495 " L->wq_dtype = dt;\n"
1496 " L->wk_dtype = dt;\n"
1497 " L->wv_dtype = dt;\n"
1498 " L->wo_dtype = dt;\n"
1499 " L->w1_dtype = dt;\n"
1500 " L->w2_dtype = dt;\n"
1503 "static void refresh_weight_flags(TransformerModel *m)\n"
1505 " if (!m) return;\n"
1506 " CKDataType base = m->token_emb_dtype;\n"
1508 " int quant = ck_dtype_is_quantized(base);\n"
1509 " if (m->lm_head_weight_dtype != base) mixed = 1;\n"
1510 " if (ck_dtype_is_quantized(m->lm_head_weight_dtype)) quant = 1;\n"
1511 " if (m->layers) {\n"
1512 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1513 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1514 " if (L->wq_dtype != base || L->wk_dtype != base || L->wv_dtype != base ||\n"
1515 " L->wo_dtype != base || L->w1_dtype != base || L->w2_dtype != base) {\n"
1518 " if (ck_dtype_is_quantized(L->wq_dtype) || ck_dtype_is_quantized(L->wk_dtype) ||\n"
1519 " ck_dtype_is_quantized(L->wv_dtype) || ck_dtype_is_quantized(L->wo_dtype) ||\n"
1520 " ck_dtype_is_quantized(L->w1_dtype) || ck_dtype_is_quantized(L->w2_dtype)) {\n"
1525 " m->weights_mixed = mixed ? true : false;\n"
1526 " m->weights_quantized = quant ? true : false;\n"
1528 " m->weight_dtype = base;\n"
1531 "static int load_weight_dtypes(const char *path, TransformerModel *m)\n"
1533 " if (!path || !m) return -1;\n"
1534 " FILE *f = fopen(path, \"rb\");\n"
1535 " if (!f) return -1;\n"
1537 " if (fread(magic, 1, 8, f) != 8) {\n"
1541 " if (memcmp(magic, \"BUMPWGT3\", 8) != 0) {\n"
1545 " uint32_t version = 0;\n"
1546 " if (fread(&version, sizeof(uint32_t), 1, f) != 1) {\n"
1550 " if (version < 3) {\n"
1554 " if (fseek(f, 128, SEEK_SET) != 0) {\n"
1558 " uint32_t dtype_len = 0;\n"
1559 " if (fread(&dtype_len, sizeof(uint32_t), 1, f) != 1) {\n"
1563 " if (dtype_len == 0) {\n"
1567 " uint8_t *dtype_buf = (uint8_t *)malloc(dtype_len);\n"
1568 " if (!dtype_buf) {\n"
1572 " if (fread(dtype_buf, 1, dtype_len, f) != dtype_len) {\n"
1573 " free(dtype_buf);\n"
1579 " size_t expected = (size_t)m->num_layers * 14u + 4u;\n"
1580 " if (dtype_len != expected) {\n"
1581 " free(dtype_buf);\n"
1584 " if (ensure_layers_allocated(m) != 0) {\n"
1585 " free(dtype_buf);\n"
1589 " size_t idx = 0;\n"
1590 " CKDataType token_dt = (CKDataType)dtype_buf[idx++];\n"
1591 " CKDataType pos_dt = (CKDataType)dtype_buf[idx++];\n"
1592 " if (pos_dt != CK_DT_FP32) {\n"
1593 " free(dtype_buf);\n"
1596 " if (token_dt != CK_DT_FP32 && token_dt != CK_DT_Q4_K && token_dt != CK_DT_Q6_K) {\n"
1597 " free(dtype_buf);\n"
1600 " m->token_emb_dtype = token_dt;\n"
1601 " m->lm_head_weight_dtype = token_dt;\n"
1602 " m->pos_emb_dtype = pos_dt;\n"
1604 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1605 " CKDataType ln1_dt = (CKDataType)dtype_buf[idx++];\n"
1606 " CKDataType ln2_dt = (CKDataType)dtype_buf[idx++];\n"
1607 " CKDataType wq_dt = (CKDataType)dtype_buf[idx++];\n"
1608 " CKDataType bq_dt = (CKDataType)dtype_buf[idx++];\n"
1609 " CKDataType wk_dt = (CKDataType)dtype_buf[idx++];\n"
1610 " CKDataType bk_dt = (CKDataType)dtype_buf[idx++];\n"
1611 " CKDataType wv_dt = (CKDataType)dtype_buf[idx++];\n"
1612 " CKDataType bv_dt = (CKDataType)dtype_buf[idx++];\n"
1613 " CKDataType wo_dt = (CKDataType)dtype_buf[idx++];\n"
1614 " CKDataType bo_dt = (CKDataType)dtype_buf[idx++];\n"
1615 " CKDataType w1_dt = (CKDataType)dtype_buf[idx++];\n"
1616 " CKDataType b1_dt = (CKDataType)dtype_buf[idx++];\n"
1617 " CKDataType w2_dt = (CKDataType)dtype_buf[idx++];\n"
1618 " CKDataType b2_dt = (CKDataType)dtype_buf[idx++];\n"
1620 " if (ln1_dt != CK_DT_FP32 || ln2_dt != CK_DT_FP32 ||\n"
1621 " bq_dt != CK_DT_FP32 || bk_dt != CK_DT_FP32 ||\n"
1622 " bv_dt != CK_DT_FP32 || bo_dt != CK_DT_FP32 ||\n"
1623 " b1_dt != CK_DT_FP32 || b2_dt != CK_DT_FP32) {\n"
1624 " free(dtype_buf);\n"
1627 " if ((wq_dt != CK_DT_FP32 && wq_dt != CK_DT_Q4_K && wq_dt != CK_DT_Q6_K) ||\n"
1628 " (wk_dt != CK_DT_FP32 && wk_dt != CK_DT_Q4_K && wk_dt != CK_DT_Q6_K) ||\n"
1629 " (wv_dt != CK_DT_FP32 && wv_dt != CK_DT_Q4_K && wv_dt != CK_DT_Q6_K) ||\n"
1630 " (wo_dt != CK_DT_FP32 && wo_dt != CK_DT_Q4_K && wo_dt != CK_DT_Q6_K) ||\n"
1631 " (w1_dt != CK_DT_FP32 && w1_dt != CK_DT_Q4_K && w1_dt != CK_DT_Q6_K) ||\n"
1632 " (w2_dt != CK_DT_FP32 && w2_dt != CK_DT_Q4_K && w2_dt != CK_DT_Q6_K)) {\n"
1633 " free(dtype_buf);\n"
1637 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1638 " L->wq_dtype = wq_dt;\n"
1639 " L->wk_dtype = wk_dt;\n"
1640 " L->wv_dtype = wv_dt;\n"
1641 " L->wo_dtype = wo_dt;\n"
1642 " L->w1_dtype = w1_dt;\n"
1643 " L->w2_dtype = w2_dt;\n"
1646 " CKDataType final_norm_dt = (CKDataType)dtype_buf[idx++];\n"
1647 " CKDataType final_bias_dt = (CKDataType)dtype_buf[idx++];\n"
1648 " free(dtype_buf);\n"
1649 " if (final_norm_dt != CK_DT_FP32 || final_bias_dt != CK_DT_FP32) {\n"
1653 " refresh_weight_flags(m);\n"
1657 "static int layout_model(TransformerModel *m)\n"
1659 " if (!m) return -1;\n"
1660 " if (m->num_attention_heads <= 0 || m->embed_dim <= 0) return -1;\n"
1661 " if (m->num_kv_heads <= 0) m->num_kv_heads = m->num_attention_heads;\n"
1662 " if (m->num_attention_heads %% m->num_kv_heads != 0) return -1;\n"
1663 " if (m->context_window <= 0) m->context_window = 1;\n"
1664 " if (m->vocab_size <= 0) m->vocab_size = 1;\n"
1665 " if (m->intermediate_size <= 0) return -1;\n"
1666 " m->head_dim = m->embed_dim / m->num_attention_heads;\n"
1667 " if (m->rms_norm_eps <= 0.0f) m->rms_norm_eps = 1e-5f;\n"
1668 " if (m->rope_theta < 0.0f) m->rope_theta = 0.0f;\n"
1669 " if (m->rope_theta > 0.0f && (m->head_dim %% 2 != 0)) return -1;\n"
1670 " if (m->elem_bytes == 0) m->elem_bytes = sizeof(float);\n"
1671 " size_t elem_bytes = m->elem_bytes;\n"
1672 " m->aligned_embed_dim = align_up_elems((size_t)m->embed_dim, elem_bytes, CACHELINE_BYTES);\n"
1673 " m->aligned_head_dim = align_up_elems((size_t)m->head_dim, elem_bytes, CACHELINE_BYTES);\n"
1674 " m->aligned_attn_context_window = align_up_elems((size_t)m->context_window, elem_bytes, CACHELINE_BYTES);\n"
1675 " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, elem_bytes, CACHELINE_BYTES);\n"
1676 " if (ensure_layers_allocated(m) != 0) return -1;\n"
1677 " if (m->weights_quantized) {\n"
1678 " /* K-quant weights require K dimension to be a multiple of 256. */\n"
1679 " if ((m->aligned_embed_dim %% 256) != 0) return -1;\n"
1680 " if ((aligned_intermediate_dim %% 256) != 0) return -1;\n"
1681 " int wo_quant = 0;\n"
1682 " if (m->layers) {\n"
1683 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1684 " if (ck_dtype_is_quantized(m->layers[layer].wo_dtype)) {\n"
1690 " if (wo_quant && (size_t)m->num_attention_heads * m->aligned_head_dim != m->aligned_embed_dim) return -1;\n"
1693 " if (m->num_cores <= 0) m->num_cores = 1;\n"
1694 " m->tokens_per_core = (m->context_window + m->num_cores - 1) / m->num_cores;\n"
1696 " size_t off = 0;\n");
1699 " m->layers_start_offset = off;\n"
1701 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1702 " TrulyOptimalLayer *L = &m->layers[layer];\n");
1710 " if (m->num_layers > 1) {\n"
1711 " m->layer_stride = m->layers[1].%s_offset - m->layers[0].%s_offset;\n"
1713 " m->layer_stride = 0;\n"
1715 stride_field, stride_field);
1719 " m->total_bytes = align_up_bytes(off, CACHELINE_BYTES);\n"
1720 " m->memory_base = (uint8_t *)ck_huge_alloc(m->total_bytes);\n"
1721 " if (!m->memory_base) return -1;\n"
1722 " if (m->rope_theta > 0.0f) {\n"
1723 " rope_precompute_cache(ptr_f32(m->memory_base, m->rope_cos_cache_offset),\n"
1724 " ptr_f32(m->memory_base, m->rope_sin_cache_offset),\n"
1725 " m->context_window,\n"
1727 " m->rope_theta);\n"
1733 "static void lm_head_forward(const float *hidden,\n"
1734 " const float *weights,\n"
1736 " int T, int V, int D, int aligned_D);\n"
1737 "static void lm_head_backward(const float *hidden,\n"
1738 " const float *weights,\n"
1739 " const float *d_logits,\n"
1740 " float *d_hidden,\n"
1741 " float *d_weights,\n"
1742 " int T, int V, int D, int aligned_D);\n"
1743 "static void softmax_cross_entropy(const float *logits,\n"
1744 " const int32_t *targets,\n"
1746 " float *d_logits,\n"
1747 " float *loss_out);\n\n");
1750 "static void run_model_forward(TransformerModel *m)\n"
1752 " uint8_t *base = m->memory_base;\n"
1753 " float *current = ptr_f32(base, m->embedded_input_offset);\n"
1754 " int aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
1755 " int T = m->active_tokens > 0 ? m->active_tokens : m->context_window;\n"
1756 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1757 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1758 " if (!m->weights_mixed && m->weight_dtype == CK_DT_Q4_K) {\n"
1759 " CKLayerForwardParamsQ4K p = {0};\n"
1761 " p.embed_dim = m->embed_dim;\n"
1762 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1763 " p.num_heads = m->num_attention_heads;\n"
1764 " p.num_kv_heads = m->num_kv_heads;\n"
1765 " p.head_dim = m->head_dim;\n"
1766 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1767 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1768 " p.intermediate_dim = m->intermediate_size;\n"
1769 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1770 " p.eps = m->rms_norm_eps;\n"
1771 " p.rope_pos_offset = 0;\n"
1772 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1773 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1774 " p.input = current;\n"
1775 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1776 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1777 " p.wq = cptr_void(base, L->wq_offset);\n"
1778 " p.bq = cptr_f32(base, L->bq_offset);\n"
1779 " p.wk = cptr_void(base, L->wk_offset);\n"
1780 " p.bk = cptr_f32(base, L->bk_offset);\n"
1781 " p.wv = cptr_void(base, L->wv_offset);\n"
1782 " p.bv = cptr_f32(base, L->bv_offset);\n"
1783 " p.wo = cptr_void(base, L->wo_offset);\n"
1784 " p.bo = cptr_f32(base, L->bo_offset);\n"
1785 " p.w1 = cptr_void(base, L->w1_offset);\n"
1786 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1787 " p.w2 = cptr_void(base, L->w2_offset);\n"
1788 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1789 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1790 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1791 " p.q = ptr_f32(base, L->q_offset);\n"
1792 " p.k = ptr_f32(base, L->k_offset);\n"
1793 " p.v = ptr_f32(base, L->v_offset);\n"
1794 " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1795 " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1796 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1797 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1798 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1799 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1800 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1801 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1802 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1803 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1804 " p.output = ptr_f32(base, L->output_offset);\n"
1805 " ck_layer_forward_rmsnorm_swiglu_q4_k(&p);\n"
1806 " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1807 " kv_cache_repack_head_major_inplace(p.k,\n"
1808 " p.num_kv_heads,\n"
1810 " m->kv_cache_capacity,\n"
1811 " p.aligned_head_dim);\n"
1812 " kv_cache_repack_head_major_inplace(p.v,\n"
1813 " p.num_kv_heads,\n"
1815 " m->kv_cache_capacity,\n"
1816 " p.aligned_head_dim);\n"
1818 " current = p.output;\n"
1819 " } else if (m->weights_quantized) {\n"
1820 " CKLayerForwardParamsQ4K p = {0};\n"
1822 " p.embed_dim = m->embed_dim;\n"
1823 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1824 " p.num_heads = m->num_attention_heads;\n"
1825 " p.num_kv_heads = m->num_kv_heads;\n"
1826 " p.head_dim = m->head_dim;\n"
1827 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1828 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1829 " p.intermediate_dim = m->intermediate_size;\n"
1830 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1831 " p.eps = m->rms_norm_eps;\n"
1832 " p.rope_pos_offset = 0;\n"
1833 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1834 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1835 " p.input = current;\n"
1836 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1837 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1838 " p.wq = cptr_void(base, L->wq_offset);\n"
1839 " p.bq = cptr_f32(base, L->bq_offset);\n"
1840 " p.wk = cptr_void(base, L->wk_offset);\n"
1841 " p.bk = cptr_f32(base, L->bk_offset);\n"
1842 " p.wv = cptr_void(base, L->wv_offset);\n"
1843 " p.bv = cptr_f32(base, L->bv_offset);\n"
1844 " p.wo = cptr_void(base, L->wo_offset);\n"
1845 " p.bo = cptr_f32(base, L->bo_offset);\n"
1846 " p.w1 = cptr_void(base, L->w1_offset);\n"
1847 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1848 " p.w2 = cptr_void(base, L->w2_offset);\n"
1849 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1850 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1851 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1852 " p.q = ptr_f32(base, L->q_offset);\n"
1853 " p.k = ptr_f32(base, L->k_offset);\n"
1854 " p.v = ptr_f32(base, L->v_offset);\n"
1855 " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1856 " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1857 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1858 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1859 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1860 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1861 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1862 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1863 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1864 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1865 " p.output = ptr_f32(base, L->output_offset);\n"
1866 " p.wq_dtype = L->wq_dtype;\n"
1867 " p.wk_dtype = L->wk_dtype;\n"
1868 " p.wv_dtype = L->wv_dtype;\n"
1869 " p.wo_dtype = L->wo_dtype;\n"
1870 " p.w1_dtype = L->w1_dtype;\n"
1871 " p.w2_dtype = L->w2_dtype;\n"
1872 " ck_layer_forward_rmsnorm_swiglu_quant(&p);\n"
1873 " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1874 " kv_cache_repack_head_major_inplace(p.k,\n"
1875 " p.num_kv_heads,\n"
1877 " m->kv_cache_capacity,\n"
1878 " p.aligned_head_dim);\n"
1879 " kv_cache_repack_head_major_inplace(p.v,\n"
1880 " p.num_kv_heads,\n"
1882 " m->kv_cache_capacity,\n"
1883 " p.aligned_head_dim);\n"
1885 " current = p.output;\n"
1887 " CKLayerForwardParams p = {0};\n"
1889 " p.embed_dim = m->embed_dim;\n"
1890 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1891 " p.num_heads = m->num_attention_heads;\n"
1892 " p.num_kv_heads = m->num_kv_heads;\n"
1893 " p.head_dim = m->head_dim;\n"
1894 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1895 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1896 " p.intermediate_dim = m->intermediate_size;\n"
1897 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1898 " p.eps = m->rms_norm_eps;\n"
1899 " p.rope_pos_offset = 0;\n"
1900 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1901 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1902 " p.input = current;\n"
1903 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1904 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1905 " p.wq = cptr_f32(base, L->wq_offset);\n"
1906 " p.bq = cptr_f32(base, L->bq_offset);\n"
1907 " p.wk = cptr_f32(base, L->wk_offset);\n"
1908 " p.bk = cptr_f32(base, L->bk_offset);\n"
1909 " p.wv = cptr_f32(base, L->wv_offset);\n"
1910 " p.bv = cptr_f32(base, L->bv_offset);\n"
1911 " p.wo = cptr_f32(base, L->wo_offset);\n"
1912 " p.bo = cptr_f32(base, L->bo_offset);\n"
1913 " p.w1 = cptr_f32(base, L->w1_offset);\n"
1914 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1915 " p.w2 = cptr_f32(base, L->w2_offset);\n"
1916 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1917 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1918 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1919 " p.q = ptr_f32(base, L->q_offset);\n"
1920 " p.k = ptr_f32(base, L->k_offset);\n"
1921 " p.v = ptr_f32(base, L->v_offset);\n"
1922 " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1923 " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1924 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1925 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1926 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1927 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1928 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1929 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1930 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1931 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1932 " p.output = ptr_f32(base, L->output_offset);\n"
1933 " ck_layer_forward_rmsnorm_swiglu(&p);\n"
1934 " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1935 " kv_cache_repack_head_major_inplace(p.k,\n"
1936 " p.num_kv_heads,\n"
1938 " m->kv_cache_capacity,\n"
1939 " p.aligned_head_dim);\n"
1940 " kv_cache_repack_head_major_inplace(p.v,\n"
1941 " p.num_kv_heads,\n"
1943 " m->kv_cache_capacity,\n"
1944 " p.aligned_head_dim);\n"
1946 " current = p.output;\n"
1949 " float *final_out = ptr_f32(base, m->final_output_offset);\n"
1950 " rmsnorm_forward(current,\n"
1951 " cptr_f32(base, m->final_ln_weight_offset),\n"
1953 " ptr_f32(base, m->final_ln_rstd_offset),\n"
1956 " (int)m->aligned_embed_dim,\n"
1957 " m->rms_norm_eps);\n"
1958 " if (m->vocab_size > 0) {\n"
1959 " if (m->lm_head_weight_dtype == CK_DT_Q4_K) {\n"
1960 " gemm_nt_q4_k(final_out,\n"
1961 " cptr_void(base, m->lm_head_weight_offset),\n"
1963 " ptr_f32(base, m->logits_offset),\n"
1966 " (int)m->aligned_embed_dim);\n"
1967 " } else if (m->lm_head_weight_dtype == CK_DT_Q6_K) {\n"
1968 " gemm_nt_q6_k(final_out,\n"
1969 " cptr_void(base, m->lm_head_weight_offset),\n"
1971 " ptr_f32(base, m->logits_offset),\n"
1974 " (int)m->aligned_embed_dim);\n"
1976 " lm_head_forward(final_out,\n"
1977 " cptr_f32(base, m->lm_head_weight_offset),\n"
1978 " ptr_f32(base, m->logits_offset),\n"
1982 " (int)m->aligned_embed_dim);\n"
1991 "static int run_model_backward(TransformerModel *m,\n"
1992 " const int32_t *tokens,\n"
1993 " const int32_t *targets,\n"
1994 " float *loss_out)\n"
1996 " if (!m || !m->training_enabled) return 0;\n"
1997 " if (!tokens || !targets) return -1;\n"
1998 " if (m->num_layers <= 0) return -1;\n"
1999 " int T = m->active_tokens > 0 ? m->active_tokens : m->context_window;\n"
2000 " int V = m->vocab_size;\n"
2001 " int D = m->embed_dim;\n"
2002 " int aligned_D = (int)m->aligned_embed_dim;\n"
2003 " uint8_t *base = m->memory_base;\n"
2007 " float *final_out = ptr_f32(base, m->final_output_offset);\n"
2008 " float *logits = ptr_f32(base, m->logits_offset);\n"
2009 " float *d_logits = ptr_f32(base, m->d_logits_offset);\n"
2010 " float *d_final_out = ptr_f32(base, m->d_final_output_offset);\n"
2011 " float *d_final_in = ptr_f32(base, m->d_final_input_offset);\n"
2013 " float loss = 0.0f;\n"
2014 " softmax_cross_entropy(logits, targets, T, V, d_logits, &loss);\n"
2015 " if (loss_out) {\n"
2016 " *loss_out = loss;\n"
2018 " lm_head_backward(final_out,\n"
2019 " cptr_f32(base, m->lm_head_weight_offset),\n"
2022 " ptr_f32(base, m->d_token_emb_offset),\n"
2023 " T, V, D, aligned_D);\n"
2024 " rmsnorm_backward(d_final_out,\n"
2025 " ptr_f32(base, m->layers[m->num_layers - 1].output_offset),\n"
2026 " cptr_f32(base, m->final_ln_weight_offset),\n"
2027 " ptr_f32(base, m->final_ln_rstd_offset),\n"
2029 " ptr_f32(base, m->d_final_ln_weight_offset),\n"
2030 " T, D, aligned_D);\n"
2032 " for (int layer = m->num_layers - 1; layer >= 0; --layer) {\n"
2033 " TrulyOptimalLayer *L = &m->layers[layer];\n"
2034 " CKLayerBackwardParams p = {0};\n"
2036 " p.embed_dim = m->embed_dim;\n"
2037 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
2038 " p.num_heads = m->num_attention_heads;\n"
2039 " p.num_kv_heads = m->num_kv_heads;\n"
2040 " p.head_dim = m->head_dim;\n"
2041 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
2042 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
2043 " p.intermediate_dim = m->intermediate_size;\n"
2044 " p.aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2045 " p.eps = m->rms_norm_eps;\n"
2046 " p.rope_pos_offset = 0;\n"
2047 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
2048 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
2049 " p.input = (layer == 0) ? ptr_f32(base, m->embedded_input_offset)\n"
2050 " : ptr_f32(base, m->layers[layer - 1].output_offset);\n"
2051 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
2052 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
2053 " p.ln1_out = cptr_f32(base, L->ln1_out_offset);\n"
2054 " p.ln1_rstd = cptr_f32(base, L->ln1_rstd_offset);\n"
2055 " p.ln2_out = cptr_f32(base, L->ln2_out_offset);\n"
2056 " p.ln2_rstd = cptr_f32(base, L->ln2_rstd_offset);\n"
2057 " p.wq = cptr_f32(base, L->wq_offset);\n"
2058 " p.bq = cptr_f32(base, L->bq_offset);\n"
2059 " p.wk = cptr_f32(base, L->wk_offset);\n"
2060 " p.bk = cptr_f32(base, L->bk_offset);\n"
2061 " p.wv = cptr_f32(base, L->wv_offset);\n"
2062 " p.bv = cptr_f32(base, L->bv_offset);\n"
2063 " p.wo = cptr_f32(base, L->wo_offset);\n"
2064 " p.bo = cptr_f32(base, L->bo_offset);\n"
2065 " p.w1 = cptr_f32(base, L->w1_offset);\n"
2066 " p.b1 = cptr_f32(base, L->b1_offset);\n"
2067 " p.w2 = cptr_f32(base, L->w2_offset);\n"
2068 " p.b2 = cptr_f32(base, L->b2_offset);\n"
2069 " p.q = cptr_f32(base, L->q_offset);\n"
2070 " p.k = cptr_f32(base, L->k_offset);\n"
2071 " p.v = cptr_f32(base, L->v_offset);\n"
2072 " p.scores = L->scores_offset ? cptr_f32(base, L->scores_offset) : NULL;\n"
2073 " p.attn_out = cptr_f32(base, L->attn_out_offset);\n"
2074 " p.residual1 = cptr_f32(base, L->residual1_offset);\n"
2075 " p.fc1_out = cptr_f32(base, L->fc1_out_offset);\n"
2076 " p.swiglu_out = cptr_f32(base, L->swiglu_out_offset);\n"
2077 " p.d_output = ptr_f32(base, L->d_output_offset);\n"
2078 " p.d_input = ptr_f32(base, L->d_input_offset);\n"
2079 " p.d_ln1_gamma = ptr_f32(base, L->d_ln1_gamma_offset);\n"
2080 " p.d_ln2_gamma = ptr_f32(base, L->d_ln2_gamma_offset);\n"
2081 " p.d_wq = ptr_f32(base, L->d_wq_offset);\n"
2082 " p.d_bq = ptr_f32(base, L->d_bq_offset);\n"
2083 " p.d_wk = ptr_f32(base, L->d_wk_offset);\n"
2084 " p.d_bk = ptr_f32(base, L->d_bk_offset);\n"
2085 " p.d_wv = ptr_f32(base, L->d_wv_offset);\n"
2086 " p.d_bv = ptr_f32(base, L->d_bv_offset);\n"
2087 " p.d_wo = ptr_f32(base, L->d_wo_offset);\n"
2088 " p.d_bo = ptr_f32(base, L->d_bo_offset);\n"
2089 " p.d_w1 = ptr_f32(base, L->d_w1_offset);\n"
2090 " p.d_b1 = ptr_f32(base, L->d_b1_offset);\n"
2091 " p.d_w2 = ptr_f32(base, L->d_w2_offset);\n"
2092 " p.d_b2 = ptr_f32(base, L->d_b2_offset);\n"
2093 " p.d_ln1_out = ptr_f32(base, L->d_ln1_out_offset);\n"
2094 " p.d_q = ptr_f32(base, L->d_q_offset);\n"
2095 " p.d_k = ptr_f32(base, L->d_k_offset);\n"
2096 " p.d_v = ptr_f32(base, L->d_v_offset);\n"
2097 " p.d_scores = ptr_f32(base, L->d_scores_offset);\n"
2098 " p.d_attn_out = ptr_f32(base, L->d_attn_out_offset);\n"
2099 " p.d_proj_tmp = ptr_f32(base, L->d_proj_tmp_offset);\n"
2100 " p.d_residual1 = ptr_f32(base, L->d_residual1_offset);\n"
2101 " p.d_ln2_out = ptr_f32(base, L->d_ln2_out_offset);\n"
2102 " p.d_fc1_out = ptr_f32(base, L->d_fc1_out_offset);\n"
2103 " p.d_swiglu_out = ptr_f32(base, L->d_swiglu_out_offset);\n"
2104 " p.d_mlp_out = ptr_f32(base, L->d_mlp_out_offset);\n"
2106 " const float *src = (layer == m->num_layers - 1)\n"
2108 " : ptr_f32(base, m->layers[layer + 1].d_input_offset);\n"
2109 " memcpy(p.d_output, src, (size_t)T * (size_t)aligned_D * sizeof(float));\n"
2111 " ck_layer_backward_rmsnorm_swiglu(&p);\n"
2115 " TrulyOptimalLayer *L0 = &m->layers[0];\n"
2116 " embedding_backward(tokens,\n"
2118 " ptr_f32(base, L0->d_input_offset),\n"
2119 " ptr_f32(base, m->d_token_emb_offset),\n"
2120 " ptr_f32(base, m->d_pos_emb_offset),\n"
2124 " m->context_window,\n"
2125 " m->rope_theta <= 0.0f);\n"
2128 " /* SGD update is now called separately via optimizer_step() */\n"
2133 "static int parse_int_arg(const char *s, int *out)\n"
2135 " if (!s || !out) return 0;\n"
2136 " char *end = NULL;\n"
2137 " long v = strtol(s, &end, 10);\n"
2138 " if (!end || *end != '\\0') return 0;\n"
2142 "static int parse_float_arg(const char *s, float *out)\n"
2144 " if (!s || !out) return 0;\n"
2145 " char *end = NULL;\n"
2146 " double v = strtod(s, &end);\n"
2147 " if (!end || *end != '\\0') return 0;\n"
2148 " *out = (float)v;\n"
2151 "static void print_usage(const char *prog)\n"
2153 " printf(\"Usage: %%s [options]\\n\", prog);\n"
2154 " printf(\" --dump Print layout summary (layer 0 only)\\n\");\n"
2155 " printf(\" --dump-all Print layout summary for all layers\\n\");\n"
2156 " printf(\" --no-forward Skip forward pass (layout + alloc only)\\n\");\n"
2157 " printf(\" --layers N Override num_layers\\n\");\n"
2158 " printf(\" --embed N Override embed_dim\\n\");\n"
2159 " printf(\" --intermediate N Override intermediate_size\\n\");\n"
2160 " printf(\" --heads N Override num_attention_heads\\n\");\n"
2161 " printf(\" --kv-heads N Override num_kv_heads\\n\");\n"
2162 " printf(\" --vocab N Override vocab_size\\n\");\n"
2163 " printf(\" --ctx N Override context_window\\n\");\n"
2164 " printf(\" --cores N Override num_cores\\n\");\n"
2165 " printf(\" --litmus Run LM head + CE + backward litmus\\n\");\n"
2166 " printf(\" --backward Run backward pass + SGD update (requires --tokens/--targets)\\n\");\n"
2167 " printf(\" --lr F SGD learning rate (default: 1e-3 when --backward)\\n\");\n"
2168 " printf(\" --steps N Training steps (default: 1)\\n\");\n"
2169 " printf(\" --log-steps Print loss per step during training\\n\");\n"
2170 " printf(\" --strict Enable strict parity mode (single-thread + double GEMM)\\n\");\n"
2171 " printf(\" --hidden PATH Load hidden activations [T x aligned_D] f32\\n\");\n"
2172 " printf(\" --weights PATH Load LM head weights [V x aligned_D] f32 (litmus)\\n\");\n"
2173 " printf(\" --targets PATH Load target tokens [T] int32\\n\");\n"
2174 " printf(\" --model-weights PATH Load full model weights (bump format)\\n\");\n"
2175 " printf(\" --tokens PATH Load token IDs [T] int32 and build embeddings\\n\");\n"
2176 " printf(\" --out-logits PATH Write logits [T x V] f32\\n\");\n"
2177 " printf(\" --out-dlogits PATH Write d_logits [T x V] f32\\n\");\n"
2178 " printf(\" --out-dhidden PATH Write d_hidden [T x aligned_D] f32\\n\");\n"
2179 " printf(\" --out-dweights PATH Write d_weights [V x aligned_D] f32\\n\");\n"
2180 " printf(\" --out-loss PATH Write loss (single f32)\\n\");\n"
2181 " printf(\" --out-weights PATH Write model weights (flat, no header)\\n\");\n"
2182 " printf(\" --help Show this help\\n\");\n"
2184 "static int read_floats(const char *path, float *dst, size_t count)\n"
2186 " if (!path || !dst) return -1;\n"
2187 " FILE *f = fopen(path, \"rb\");\n"
2189 " perror(\"fopen\");\n"
2192 " size_t got = fread(dst, sizeof(float), count, f);\n"
2194 " return got == count ? 0 : -1;\n"
2196 "static int read_ints(const char *path, int32_t *dst, size_t count)\n"
2198 " if (!path || !dst) return -1;\n"
2199 " FILE *f = fopen(path, \"rb\");\n"
2201 " perror(\"fopen\");\n"
2204 " size_t got = fread(dst, sizeof(int32_t), count, f);\n"
2206 " return got == count ? 0 : -1;\n"
2208 "static int read_floats_file(FILE *f, float *dst, size_t count)\n"
2210 " if (!f || !dst) return -1;\n"
2211 " size_t got = fread(dst, sizeof(float), count, f);\n"
2212 " return got == count ? 0 : -1;\n"
2214 "static int read_bytes_file(FILE *f, void *dst, size_t bytes)\n"
2216 " if (!f || !dst) return -1;\n"
2217 " size_t got = fread(dst, 1, bytes, f);\n"
2218 " return got == bytes ? 0 : -1;\n"
2220 "static int write_floats_file(FILE *f, const float *src, size_t count)\n"
2222 " if (!f || !src) return -1;\n"
2223 " size_t wrote = fwrite(src, sizeof(float), count, f);\n"
2224 " return wrote == count ? 0 : -1;\n"
2226 "static int write_bytes_file(FILE *f, const void *src, size_t bytes)\n"
2228 " if (!f || !src) return -1;\n"
2229 " size_t wrote = fwrite(src, 1, bytes, f);\n"
2230 " return wrote == bytes ? 0 : -1;\n"
2232 "static int read_weight_file(FILE *f, CKDataType dtype, void *dst, size_t n_elements)\n"
2234 " if (!f || !dst) return -1;\n"
2235 " if (dtype == CK_DT_FP32) {\n"
2236 " return read_floats_file(f, (float *)dst, n_elements);\n"
2238 " return read_bytes_file(f, dst, ck_dtype_row_bytes(dtype, n_elements));\n"
2240 "static int write_weight_file(FILE *f, CKDataType dtype, const void *src, size_t n_elements)\n"
2242 " if (!f || !src) return -1;\n"
2243 " if (dtype == CK_DT_FP32) {\n"
2244 " return write_floats_file(f, (const float *)src, n_elements);\n"
2246 " return write_bytes_file(f, src, ck_dtype_row_bytes(dtype, n_elements));\n"
2248 "static int skip_bump_header(FILE *f)\n"
2250 " if (!f) return -1;\n"
2252 " if (fread(magic, 1, 8, f) != 8) return -1;\n"
2253 " if (memcmp(magic, \"BUMPWGT3\", 8) == 0) {\n"
2254 " if (fseek(f, 128, SEEK_SET) != 0) return -1;\n"
2255 " uint32_t dtype_len = 0;\n"
2256 " if (fread(&dtype_len, sizeof(uint32_t), 1, f) != 1) return -1;\n"
2257 " if (fseek(f, (long)dtype_len, SEEK_CUR) != 0) return -1;\n"
2260 " if (memcmp(magic, \"BUMPWGT2\", 8) == 0) {\n"
2261 " if (fseek(f, 128, SEEK_SET) != 0) return -1;\n"
2264 " if (fseek(f, 0, SEEK_SET) != 0) return -1;\n"
2267 "static int load_model_weights(const char *path, TransformerModel *m)\n"
2269 " if (!path || !m || !m->memory_base) return -1;\n"
2270 " FILE *f = fopen(path, \"rb\");\n"
2272 " perror(\"fopen\");\n"
2275 " if (skip_bump_header(f) < 0) {\n"
2279 " uint8_t *base = m->memory_base;\n"
2280 " size_t aligned_intermediate = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2281 " size_t tok_elems = (size_t)m->vocab_size * m->aligned_embed_dim;\n"
2282 " if (read_weight_file(f, m->token_emb_dtype, ptr_u8(base, m->token_emb_offset), tok_elems) != 0) goto fail;\n"
2283 " if (read_floats_file(f, ptr_f32(base, m->pos_emb_offset),\n"
2284 " (size_t)m->context_window * m->aligned_embed_dim) != 0) goto fail;\n"
2286 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
2287 " TrulyOptimalLayer *L = &m->layers[layer];\n"
2288 " size_t head_w_stride = m->aligned_head_dim * m->aligned_embed_dim;\n"
2289 " size_t q_w = (size_t)m->num_attention_heads * head_w_stride;\n"
2290 " size_t kv_w = (size_t)m->num_kv_heads * head_w_stride;\n"
2291 " size_t q_b = (size_t)m->num_attention_heads * m->aligned_head_dim;\n"
2292 " size_t kv_b = (size_t)m->num_kv_heads * m->aligned_head_dim;\n"
2293 " size_t wo_w = (size_t)m->num_attention_heads * m->aligned_embed_dim * m->aligned_head_dim;\n"
2294 " size_t w1_w = (size_t)(2 * aligned_intermediate) * m->aligned_embed_dim;\n"
2295 " size_t w2_w = m->aligned_embed_dim * aligned_intermediate;\n"
2297 " if (read_floats_file(f, ptr_f32(base, L->ln1_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2298 " if (read_floats_file(f, ptr_f32(base, L->ln2_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2299 " if (read_weight_file(f, L->wq_dtype, ptr_u8(base, L->wq_offset), q_w) != 0) goto fail;\n"
2300 " if (read_floats_file(f, ptr_f32(base, L->bq_offset), q_b) != 0) goto fail;\n"
2301 " if (read_weight_file(f, L->wk_dtype, ptr_u8(base, L->wk_offset), kv_w) != 0) goto fail;\n"
2302 " if (read_floats_file(f, ptr_f32(base, L->bk_offset), kv_b) != 0) goto fail;\n"
2303 " if (read_weight_file(f, L->wv_dtype, ptr_u8(base, L->wv_offset), kv_w) != 0) goto fail;\n"
2304 " if (read_floats_file(f, ptr_f32(base, L->bv_offset), kv_b) != 0) goto fail;\n"
2305 " if (read_weight_file(f, L->wo_dtype, ptr_u8(base, L->wo_offset), wo_w) != 0) goto fail;\n"
2306 " if (read_floats_file(f, ptr_f32(base, L->bo_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2307 " if (read_weight_file(f, L->w1_dtype, ptr_u8(base, L->w1_offset), w1_w) != 0) goto fail;\n"
2308 " if (read_floats_file(f, ptr_f32(base, L->b1_offset), (size_t)(2 * aligned_intermediate)) != 0) goto fail;\n"
2309 " if (read_weight_file(f, L->w2_dtype, ptr_u8(base, L->w2_offset), w2_w) != 0) goto fail;\n"
2310 " if (read_floats_file(f, ptr_f32(base, L->b2_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2313 " if (read_floats_file(f, ptr_f32(base, m->final_ln_weight_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2314 " if (read_floats_file(f, ptr_f32(base, m->final_ln_bias_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2322 "static int save_model_weights(const char *path, const TransformerModel *m)\n"
2324 " if (!path || !m || !m->memory_base) return -1;\n"
2325 " FILE *f = fopen(path, \"wb\");\n"
2327 " perror(\"fopen\");\n"
2330 " uint8_t *base = m->memory_base;\n"
2331 " size_t aligned_intermediate = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2332 " size_t tok_elems = (size_t)m->vocab_size * m->aligned_embed_dim;\n"
2333 " if (write_weight_file(f, m->token_emb_dtype, cptr_void(base, m->token_emb_offset), tok_elems) != 0) goto fail;\n"
2334 " if (write_floats_file(f, ptr_f32(base, m->pos_emb_offset),\n"
2335 " (size_t)m->context_window * m->aligned_embed_dim) != 0) goto fail;\n"
2337 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
2338 " const TrulyOptimalLayer *L = &m->layers[layer];\n"
2339 " size_t head_w_stride = m->aligned_head_dim * m->aligned_embed_dim;\n"
2340 " size_t q_w = (size_t)m->num_attention_heads * head_w_stride;\n"
2341 " size_t kv_w = (size_t)m->num_kv_heads * head_w_stride;\n"
2342 " size_t q_b = (size_t)m->num_attention_heads * m->aligned_head_dim;\n"
2343 " size_t kv_b = (size_t)m->num_kv_heads * m->aligned_head_dim;\n"
2344 " size_t wo_w = (size_t)m->num_attention_heads * m->aligned_embed_dim * m->aligned_head_dim;\n"
2345 " size_t w1_w = (size_t)(2 * aligned_intermediate) * m->aligned_embed_dim;\n"
2346 " size_t w2_w = m->aligned_embed_dim * aligned_intermediate;\n"
2348 " if (write_floats_file(f, cptr_f32(base, L->ln1_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2349 " if (write_floats_file(f, cptr_f32(base, L->ln2_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2350 " if (write_weight_file(f, L->wq_dtype, cptr_void(base, L->wq_offset), q_w) != 0) goto fail;\n"
2351 " if (write_floats_file(f, cptr_f32(base, L->bq_offset), q_b) != 0) goto fail;\n"
2352 " if (write_weight_file(f, L->wk_dtype, cptr_void(base, L->wk_offset), kv_w) != 0) goto fail;\n"
2353 " if (write_floats_file(f, cptr_f32(base, L->bk_offset), kv_b) != 0) goto fail;\n"
2354 " if (write_weight_file(f, L->wv_dtype, cptr_void(base, L->wv_offset), kv_w) != 0) goto fail;\n"
2355 " if (write_floats_file(f, cptr_f32(base, L->bv_offset), kv_b) != 0) goto fail;\n"
2356 " if (write_weight_file(f, L->wo_dtype, cptr_void(base, L->wo_offset), wo_w) != 0) goto fail;\n"
2357 " if (write_floats_file(f, cptr_f32(base, L->bo_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2358 " if (write_weight_file(f, L->w1_dtype, cptr_void(base, L->w1_offset), w1_w) != 0) goto fail;\n"
2359 " if (write_floats_file(f, cptr_f32(base, L->b1_offset), (size_t)(2 * aligned_intermediate)) != 0) goto fail;\n"
2360 " if (write_weight_file(f, L->w2_dtype, cptr_void(base, L->w2_offset), w2_w) != 0) goto fail;\n"
2361 " if (write_floats_file(f, cptr_f32(base, L->b2_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2364 " if (write_floats_file(f, cptr_f32(base, m->final_ln_weight_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2365 " if (write_floats_file(f, cptr_f32(base, m->final_ln_bias_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2373 "static void embed_tokens(const TransformerModel *m, const int32_t *tokens, int token_count)\n"
2375 " if (!m || !m->memory_base || !tokens) return;\n"
2376 " const uint8_t *base = m->memory_base;\n"
2377 " float *out = ptr_f32((uint8_t *)base, m->embedded_input_offset);\n"
2378 " const float *tok_f32 = cptr_f32(base, m->token_emb_offset);\n"
2379 " const uint8_t *tok_q = (const uint8_t *)cptr_void(base, m->token_emb_offset);\n"
2380 " const float *pos = cptr_f32(base, m->pos_emb_offset);\n"
2381 " int T = m->context_window;\n"
2382 " int D = m->embed_dim;\n"
2383 " int aligned_D = (int)m->aligned_embed_dim;\n"
2384 " for (int t = 0; t < T; ++t) {\n"
2385 " float *dst = out + (size_t)t * aligned_D;\n"
2386 " if (t < token_count) {\n"
2387 " int id = tokens[t];\n"
2388 " if (id < 0 || id >= m->vocab_size) id = 0;\n"
2389 " if (m->token_emb_dtype == CK_DT_Q4_K) {\n"
2390 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_D);\n"
2391 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2392 " dequant_q4_k_row(row, dst, (size_t)aligned_D);\n"
2393 " } else if (m->token_emb_dtype == CK_DT_Q6_K) {\n"
2394 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_D);\n"
2395 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2396 " dequant_q6_k_row(row, dst, (size_t)aligned_D);\n"
2398 " const float *src = tok_f32 + (size_t)id * aligned_D;\n"
2399 " memcpy(dst, src, (size_t)D * sizeof(float));\n"
2401 " if (aligned_D > D) {\n"
2402 " memset(dst + D, 0, (size_t)(aligned_D - D) * sizeof(float));\n"
2404 " if (m->rope_theta <= 0.0f) {\n"
2405 " const float *p = pos + (size_t)t * aligned_D;\n"
2406 " for (int d = 0; d < D; ++d) {\n"
2407 " dst[d] += p[d];\n"
2411 " memset(dst, 0, (size_t)aligned_D * sizeof(float));\n"
2415 "static void embed_token_at(const TransformerModel *m, int32_t token, int t)\n"
2417 " if (!m || !m->memory_base) return;\n"
2418 " if (t < 0 || t >= m->context_window) return;\n"
2419 " const uint8_t *base = m->memory_base;\n"
2420 " float *out = ptr_f32((uint8_t *)base, m->embedded_input_offset);\n"
2421 " const float *tok_f32 = cptr_f32(base, m->token_emb_offset);\n"
2422 " const uint8_t *tok_q = (const uint8_t *)cptr_void(base, m->token_emb_offset);\n"
2423 " const float *pos = cptr_f32(base, m->pos_emb_offset);\n"
2424 " int D = m->embed_dim;\n"
2425 " int aligned_D = (int)m->aligned_embed_dim;\n"
2426 " int id = (int)token;\n"
2427 " if (id < 0 || id >= m->vocab_size) id = 0;\n"
2428 " float *dst = out + (size_t)t * aligned_D;\n"
2429 " if (m->token_emb_dtype == CK_DT_Q4_K) {\n"
2430 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_D);\n"
2431 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2432 " dequant_q4_k_row(row, dst, (size_t)aligned_D);\n"
2433 " } else if (m->token_emb_dtype == CK_DT_Q6_K) {\n"
2434 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_D);\n"
2435 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2436 " dequant_q6_k_row(row, dst, (size_t)aligned_D);\n"
2438 " const float *src = tok_f32 + (size_t)id * aligned_D;\n"
2439 " memcpy(dst, src, (size_t)D * sizeof(float));\n"
2441 " if (aligned_D > D) {\n"
2442 " memset(dst + D, 0, (size_t)(aligned_D - D) * sizeof(float));\n"
2444 " if (m->rope_theta <= 0.0f) {\n"
2445 " const float *p = pos + (size_t)t * aligned_D;\n"
2446 " for (int d = 0; d < D; ++d) {\n"
2447 " dst[d] += p[d];\n"
2451 "static int write_floats(const char *path, const float *src, size_t count)\n"
2453 " if (!path || !src) return -1;\n"
2454 " FILE *f = fopen(path, \"wb\");\n"
2456 " perror(\"fopen\");\n"
2459 " size_t wrote = fwrite(src, sizeof(float), count, f);\n"
2461 " return wrote == count ? 0 : -1;\n"
2463 "static int write_float_scalar(const char *path, float v)\n"
2465 " if (!path) return -1;\n"
2466 " FILE *f = fopen(path, \"wb\");\n"
2468 " perror(\"fopen\");\n"
2471 " size_t wrote = fwrite(&v, sizeof(float), 1, f);\n"
2473 " return wrote == 1 ? 0 : -1;\n"
2475 "static void lm_head_forward(const float *hidden,\n"
2476 " const float *weights,\n"
2478 " int T, int V, int D, int aligned_D)\n"
2480 " for (int t = 0; t < T; ++t) {\n"
2481 " const float *h = hidden + (size_t)t * aligned_D;\n"
2482 " float *out = logits + (size_t)t * V;\n"
2483 " for (int v = 0; v < V; ++v) {\n"
2484 " const float *w = weights + (size_t)v * aligned_D;\n"
2485 " float sum = 0.0f;\n"
2486 " for (int d = 0; d < D; ++d) {\n"
2487 " sum += h[d] * w[d];\n"
2493 "static void softmax_cross_entropy(const float *logits,\n"
2494 " const int32_t *targets,\n"
2496 " float *d_logits,\n"
2497 " float *loss_out)\n"
2499 " double total = 0.0;\n"
2500 " for (int t = 0; t < T; ++t) {\n"
2501 " const float *row = logits + (size_t)t * V;\n"
2502 " float *drow = d_logits + (size_t)t * V;\n"
2503 " int target = targets[t];\n"
2504 " float max_logit = row[0];\n"
2505 " for (int v = 1; v < V; ++v) {\n"
2506 " if (row[v] > max_logit) max_logit = row[v];\n"
2508 " double sum_exp = 0.0;\n"
2509 " for (int v = 0; v < V; ++v) {\n"
2510 " drow[v] = expf(row[v] - max_logit);\n"
2511 " sum_exp += drow[v];\n"
2513 " float inv_sum = 1.0f / (float)sum_exp;\n"
2514 " for (int v = 0; v < V; ++v) {\n"
2515 " drow[v] *= inv_sum;\n"
2517 " double logsum = (double)max_logit + log(sum_exp);\n"
2518 " total += logsum - (double)row[target];\n"
2519 " drow[target] -= 1.0f;\n"
2520 " float scale = 1.0f / (float)T;\n"
2521 " for (int v = 0; v < V; ++v) {\n"
2522 " drow[v] *= scale;\n"
2525 " if (loss_out) {\n"
2526 " *loss_out = (float)(total / (double)T);\n"
2529 "static void lm_head_backward(const float *hidden,\n"
2530 " const float *weights,\n"
2531 " const float *d_logits,\n"
2532 " float *d_hidden,\n"
2533 " float *d_weights,\n"
2534 " int T, int V, int D, int aligned_D)\n"
2536 " size_t dh_count = (size_t)T * aligned_D;\n"
2537 " size_t dw_count = (size_t)V * aligned_D;\n"
2538 " for (size_t i = 0; i < dh_count; ++i) d_hidden[i] = 0.0f;\n"
2539 " for (size_t i = 0; i < dw_count; ++i) d_weights[i] = 0.0f;\n"
2540 " for (int t = 0; t < T; ++t) {\n"
2541 " const float *dlog = d_logits + (size_t)t * V;\n"
2542 " for (int d = 0; d < D; ++d) {\n"
2543 " double sum = 0.0;\n"
2544 " for (int v = 0; v < V; ++v) {\n"
2545 " sum += (double)dlog[v] * (double)weights[(size_t)v * aligned_D + d];\n"
2547 " d_hidden[(size_t)t * aligned_D + d] = (float)sum;\n"
2550 " for (int v = 0; v < V; ++v) {\n"
2551 " float *dw = d_weights + (size_t)v * aligned_D;\n"
2552 " for (int d = 0; d < D; ++d) {\n"
2553 " double sum = 0.0;\n"
2554 " for (int t = 0; t < T; ++t) {\n"
2555 " sum += (double)d_logits[(size_t)t * V + v] * (double)hidden[(size_t)t * aligned_D + d];\n"
2557 " dw[d] = (float)sum;\n"
2563 "static void dump_layer_offsets(const TransformerModel *m, int layer)\n"
2565 " const TrulyOptimalLayer *L = &m->layers[layer];\n"
2566 " printf(\"Layer %%d offsets (bytes):\\n\", layer);\n"
2567 " printf(\" ln1_gamma=%%zu ln2_gamma=%%zu wq=%%zu wk=%%zu wv=%%zu wo=%%zu w1=%%zu w2=%%zu\\n\",\n"
2568 " L->ln1_gamma_offset, L->ln2_gamma_offset, L->wq_offset, L->wk_offset,\n"
2569 " L->wv_offset, L->wo_offset, L->w1_offset, L->w2_offset);\n"
2570 " printf(\" ln1_out=%%zu q=%%zu k=%%zu v=%%zu scores=%%zu attn_out=%%zu\\n\",\n"
2571 " L->ln1_out_offset, L->q_offset, L->k_offset, L->v_offset,\n"
2572 " L->scores_offset, L->attn_out_offset);\n"
2573 " printf(\" proj_tmp=%%zu residual1=%%zu ln2_out=%%zu fc1_out=%%zu swiglu_out=%%zu mlp_out=%%zu output=%%zu\\n\",\n"
2574 " L->proj_tmp_offset, L->residual1_offset, L->ln2_out_offset,\n"
2575 " L->fc1_out_offset, L->swiglu_out_offset, L->mlp_out_offset, L->output_offset);\n"
2577 "static void dump_layout(const TransformerModel *m, int dump_all)\n"
2579 " size_t bytes = m->total_bytes;\n"
2580 " printf(\"Model config:\\n\");\n"
2581 " printf(\" layers=%%d embed=%%d intermediate=%%d heads=%%d kv_heads=%%d\\n\",\n"
2582 " m->num_layers, m->embed_dim, m->intermediate_size, m->num_attention_heads, m->num_kv_heads);\n"
2583 " printf(\" head_dim=%%d vocab=%%d ctx=%%d cores=%%d\\n\",\n"
2584 " m->head_dim, m->vocab_size, m->context_window, m->num_cores);\n"
2585 " printf(\" eps=%%.6g rope_theta=%%.6g\\n\", m->rms_norm_eps, m->rope_theta);\n"
2586 " printf(\"Aligned dims (elements): embed=%%zu head=%%zu ctx=%%zu\\n\",\n"
2587 " m->aligned_embed_dim, m->aligned_head_dim, m->aligned_attn_context_window);\n"
2588 " printf(\"Memory: total_bytes=%%zu\\n\", bytes);\n"
2589 " printf(\"Global offsets (bytes): token=%%zu pos=%%zu embedded=%%zu layers_start=%%zu\\n\",\n"
2590 " m->token_emb_offset, m->pos_emb_offset, m->embedded_input_offset, m->layers_start_offset);\n"
2591 " printf(\"Final offsets (bytes): final_ln_w=%%zu final_ln_b=%%zu final_ln_mean=%%zu final_ln_rstd=%%zu\\n\",\n"
2592 " m->final_ln_weight_offset, m->final_ln_bias_offset,\n"
2593 " m->final_ln_mean_offset, m->final_ln_rstd_offset);\n"
2594 " printf(\"LM/logits offsets (bytes): lm_head=%%zu logits=%%zu\\n\",\n"
2595 " m->lm_head_weight_offset, m->logits_offset);\n"
2596 " if (m->num_layers > 0) {\n"
2597 " dump_layer_offsets(m, 0);\n"
2598 " if (dump_all) {\n"
2599 " for (int i = 1; i < m->num_layers; ++i) {\n"
2600 " dump_layer_offsets(m, i);\n"
2609 "int main(int argc, char **argv)\n"
2612 " int dump_all = 0;\n"
2613 " int no_forward = 0;\n"
2614 " int run_litmus = 0;\n"
2615 " int run_backward = 0;\n"
2616 " const char *litmus_hidden = NULL;\n"
2617 " const char *litmus_weights = NULL;\n"
2618 " const char *litmus_targets = NULL;\n"
2619 " const char *model_weights = NULL;\n"
2620 " const char *tokens_path = NULL;\n"
2621 " const char *out_logits = NULL;\n"
2622 " const char *out_dlogits = NULL;\n"
2623 " const char *out_dhidden = NULL;\n"
2624 " const char *out_dweights = NULL;\n"
2625 " const char *out_loss = NULL;\n"
2626 " const char *out_weights = NULL;\n"
2628 " int log_steps = 0;\n"
2629 " int strict = 0;\n"
2630 " int32_t *tokens = NULL;\n"
2631 " int32_t *targets = NULL;\n"
2632 " TransformerModel m = {0};\n"
2633 " memcpy(m.magic, \"BUMPWGT3\", 8);\n"
2635 " m.model_type = 0;\n"
2636 " m.num_layers = %d;\n"
2637 " m.embed_dim = %d;\n"
2638 " m.intermediate_size = %d;\n"
2639 " m.num_attention_heads = %d;\n"
2640 " m.num_kv_heads = %d;\n"
2641 " m.vocab_size = %d;\n"
2642 " m.context_window = %d;\n"
2643 " m.rms_norm_eps = %.9g;\n"
2644 " m.rope_theta = %.9g;\n"
2645 " m.num_cores = 1;\n"
2646 " m.task_type = TASK_LM;\n"
2647 " m.optimizer = OPTIMIZER_SGD;\n"
2648 " m.learning_rate = 0.0f;\n"
2649 " for (int i = 1; i < argc; ++i) {\n"
2650 " if (strcmp(argv[i], \"--dump\") == 0) {\n"
2654 " if (strcmp(argv[i], \"--dump-all\") == 0) {\n"
2659 " if (strcmp(argv[i], \"--no-forward\") == 0) {\n"
2660 " no_forward = 1;\n"
2663 " if (strcmp(argv[i], \"--strict\") == 0) {\n"
2667 " if (strcmp(argv[i], \"--litmus\") == 0) {\n"
2668 " run_litmus = 1;\n"
2671 " if (strcmp(argv[i], \"--backward\") == 0) {\n"
2672 " run_backward = 1;\n"
2675 " if (strcmp(argv[i], \"--lr\") == 0 && i + 1 < argc) {\n"
2676 " parse_float_arg(argv[++i], &m.learning_rate);\n"
2679 " if (strcmp(argv[i], \"--help\") == 0) {\n"
2680 " print_usage(argv[0]);\n"
2683 " if (strcmp(argv[i], \"--hidden\") == 0 && i + 1 < argc) {\n"
2684 " litmus_hidden = argv[++i];\n"
2687 " if (strcmp(argv[i], \"--weights\") == 0 && i + 1 < argc) {\n"
2688 " litmus_weights = argv[++i];\n"
2691 " if (strcmp(argv[i], \"--targets\") == 0 && i + 1 < argc) {\n"
2692 " litmus_targets = argv[++i];\n"
2695 " if (strcmp(argv[i], \"--model-weights\") == 0 && i + 1 < argc) {\n"
2696 " model_weights = argv[++i];\n"
2699 " if (strcmp(argv[i], \"--tokens\") == 0 && i + 1 < argc) {\n"
2700 " tokens_path = argv[++i];\n"
2703 " if (strcmp(argv[i], \"--out-logits\") == 0 && i + 1 < argc) {\n"
2704 " out_logits = argv[++i];\n"
2707 " if (strcmp(argv[i], \"--out-dlogits\") == 0 && i + 1 < argc) {\n"
2708 " out_dlogits = argv[++i];\n"
2711 " if (strcmp(argv[i], \"--out-dhidden\") == 0 && i + 1 < argc) {\n"
2712 " out_dhidden = argv[++i];\n"
2715 " if (strcmp(argv[i], \"--out-dweights\") == 0 && i + 1 < argc) {\n"
2716 " out_dweights = argv[++i];\n"
2719 " if (strcmp(argv[i], \"--out-loss\") == 0 && i + 1 < argc) {\n"
2720 " out_loss = argv[++i];\n"
2723 " if (strcmp(argv[i], \"--out-weights\") == 0 && i + 1 < argc) {\n"
2724 " out_weights = argv[++i];\n"
2727 " if (strcmp(argv[i], \"--steps\") == 0 && i + 1 < argc) {\n"
2728 " parse_int_arg(argv[++i], &steps);\n"
2731 " if (strcmp(argv[i], \"--log-steps\") == 0) {\n"
2735 " if (strcmp(argv[i], \"--layers\") == 0 && i + 1 < argc) {\n"
2736 " parse_int_arg(argv[++i], &m.num_layers);\n"
2739 " if (strcmp(argv[i], \"--embed\") == 0 && i + 1 < argc) {\n"
2740 " parse_int_arg(argv[++i], &m.embed_dim);\n"
2743 " if (strcmp(argv[i], \"--intermediate\") == 0 && i + 1 < argc) {\n"
2744 " parse_int_arg(argv[++i], &m.intermediate_size);\n"
2747 " if (strcmp(argv[i], \"--heads\") == 0 && i + 1 < argc) {\n"
2748 " parse_int_arg(argv[++i], &m.num_attention_heads);\n"
2751 " if (strcmp(argv[i], \"--kv-heads\") == 0 && i + 1 < argc) {\n"
2752 " parse_int_arg(argv[++i], &m.num_kv_heads);\n"
2755 " if (strcmp(argv[i], \"--vocab\") == 0 && i + 1 < argc) {\n"
2756 " parse_int_arg(argv[++i], &m.vocab_size);\n"
2759 " if (strcmp(argv[i], \"--ctx\") == 0 && i + 1 < argc) {\n"
2760 " parse_int_arg(argv[++i], &m.context_window);\n"
2763 " if (strcmp(argv[i], \"--cores\") == 0 && i + 1 < argc) {\n"
2764 " parse_int_arg(argv[++i], &m.num_cores);\n"
2767 " fprintf(stderr, \"Unknown or invalid arg: %%s\\n\", argv[i]);\n"
2768 " print_usage(argv[0]);\n"
2772 " ck_set_strict_parity(1);\n"
2774 " if (run_backward && m.learning_rate == 0.0f) {\n"
2775 " m.learning_rate = 1e-3f;\n"
2777 " m.training_enabled = run_backward;\n"
2778 " m.weight_dtype = CK_DT_FP32;\n"
2780 " const char *wd = getenv(\"CK_WEIGHT_DTYPE\");\n"
2782 " if (strcmp(wd, \"q4_k\") == 0 || strcmp(wd, \"q4_k_m\") == 0 ||\n"
2783 " strcmp(wd, \"Q4_K\") == 0 || strcmp(wd, \"Q4_K_M\") == 0) {\n"
2784 " m.weight_dtype = CK_DT_Q4_K;\n"
2785 " } else if (strcmp(wd, \"q6_k\") == 0 || strcmp(wd, \"q6_k_l\") == 0 ||\n"
2786 " strcmp(wd, \"Q6_K\") == 0 || strcmp(wd, \"Q6_K_L\") == 0) {\n"
2787 " m.weight_dtype = CK_DT_Q6_K;\n"
2791 " init_weight_dtypes_uniform(&m, m.weight_dtype);\n"
2792 " refresh_weight_flags(&m);\n"
2793 " if (model_weights) {\n"
2794 " int dtype_rc = load_weight_dtypes(model_weights, &m);\n"
2795 " if (dtype_rc < 0) {\n"
2796 " fprintf(stderr, \"failed to read weight dtype table\\n\");\n"
2800 " if (m.training_enabled && m.weights_quantized) {\n"
2801 " fprintf(stderr, \"Quantized weights are inference-only; disable training\\n\");\n"
2804 " if (layout_model(&m) != 0) {\n"
2805 " fprintf(stderr, \"layout_model failed\\n\");\n"
2808 " if (model_weights) {\n"
2809 " if (load_model_weights(model_weights, &m) != 0) {\n"
2810 " fprintf(stderr, \"failed to load model weights\\n\");\n"
2814 " if (tokens_path) {\n"
2815 " int T = m.context_window;\n"
2816 " tokens = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2818 " fprintf(stderr, \"failed to alloc tokens\\n\");\n"
2821 " if (read_ints(tokens_path, tokens, (size_t)T) != 0) {\n"
2822 " fprintf(stderr, \"failed to read tokens\\n\");\n"
2827 " if (!run_backward) {\n"
2828 " embed_tokens(&m, tokens, T);\n"
2833 " if (run_backward) {\n"
2834 " if (!litmus_targets) {\n"
2835 " fprintf(stderr, \"backward requires --targets\\n\");\n"
2838 " int T = m.context_window;\n"
2839 " targets = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2840 " if (!targets) {\n"
2841 " fprintf(stderr, \"failed to alloc targets\\n\");\n"
2844 " if (read_ints(litmus_targets, targets, (size_t)T) != 0) {\n"
2845 " fprintf(stderr, \"failed to read targets\\n\");\n"
2847 " targets = NULL;\n"
2852 " dump_layout(&m, dump_all);\n"
2854 " if (run_litmus) {\n"
2855 " if (!litmus_hidden || !litmus_weights || !litmus_targets) {\n"
2856 " fprintf(stderr, \"litmus requires --hidden, --weights, and --targets\\n\");\n"
2859 " int T = m.context_window;\n"
2860 " int V = m.vocab_size;\n"
2861 " int D = m.embed_dim;\n"
2862 " int aligned_D = (int)m.aligned_embed_dim;\n"
2863 " float *hidden = ptr_f32(m.memory_base, m.final_output_offset);\n"
2864 " float *weights = ptr_f32(m.memory_base, m.lm_head_weight_offset);\n"
2865 " float *logits = ptr_f32(m.memory_base, m.logits_offset);\n"
2866 " if (read_floats(litmus_hidden, hidden, (size_t)T * aligned_D) != 0) {\n"
2867 " fprintf(stderr, \"failed to read hidden\\n\");\n"
2870 " if (read_floats(litmus_weights, weights, (size_t)V * aligned_D) != 0) {\n"
2871 " fprintf(stderr, \"failed to read weights\\n\");\n"
2874 " int32_t *targets = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2875 " if (!targets) {\n"
2876 " fprintf(stderr, \"failed to alloc targets\\n\");\n"
2879 " if (read_ints(litmus_targets, targets, (size_t)T) != 0) {\n"
2880 " fprintf(stderr, \"failed to read targets\\n\");\n"
2884 " float *d_logits = (float *)calloc((size_t)T * V, sizeof(float));\n"
2885 " float *d_hidden = (float *)calloc((size_t)T * aligned_D, sizeof(float));\n"
2886 " float *d_weights = (float *)calloc((size_t)V * aligned_D, sizeof(float));\n"
2887 " if (!d_logits || !d_hidden || !d_weights) {\n"
2888 " fprintf(stderr, \"failed to alloc grads\\n\");\n"
2890 " free(d_logits);\n"
2891 " free(d_hidden);\n"
2892 " free(d_weights);\n"
2895 " lm_head_forward(hidden, weights, logits, T, V, D, aligned_D);\n"
2896 " float loss = 0.0f;\n"
2897 " softmax_cross_entropy(logits, targets, T, V, d_logits, &loss);\n"
2898 " lm_head_backward(hidden, weights, d_logits, d_hidden, d_weights, T, V, D, aligned_D);\n"
2899 " if (out_logits) write_floats(out_logits, logits, (size_t)T * V);\n"
2900 " if (out_dlogits) write_floats(out_dlogits, d_logits, (size_t)T * V);\n"
2901 " if (out_dhidden) write_floats(out_dhidden, d_hidden, (size_t)T * aligned_D);\n"
2902 " if (out_dweights) write_floats(out_dweights, d_weights, (size_t)V * aligned_D);\n"
2903 " if (out_loss) write_float_scalar(out_loss, loss);\n"
2904 " if (!out_loss) printf(\"loss=%%.6f\\n\", loss);\n"
2906 " free(d_logits);\n"
2907 " free(d_hidden);\n"
2908 " free(d_weights);\n"
2909 " ck_huge_free(m.memory_base, m.total_bytes);\n"
2910 " free(m.layers);\n"
2913 " // TODO: load weights into m.memory_base using the offsets above.\n"
2914 " // TODO: write token/pos embeddings into embedded_input_offset.\n"
2915 " if (!run_backward) {\n"
2916 " if (!no_forward) {\n"
2917 " run_model_forward(&m);\n"
2920 " if (!tokens || !targets) {\n"
2921 " fprintf(stderr, \"backward requires --tokens and --targets\\n\");\n"
2924 " if (steps < 1) steps = 1;\n"
2925 " float loss = 0.0f;\n"
2926 " for (int step = 0; step < steps; ++step) {\n"
2927 " embed_tokens(&m, tokens, m.context_window);\n"
2928 " run_model_forward(&m);\n"
2929 " if (run_model_backward(&m, tokens, targets, &loss) != 0) {\n"
2930 " fprintf(stderr, \"backward failed\\n\");\n"
2933 " if (log_steps) {\n"
2934 " printf(\"step %%d loss=%%.6f\\n\", step, loss);\n"
2937 " if (out_loss) {\n"
2938 " write_float_scalar(out_loss, loss);\n"
2941 " if (out_logits) {\n"
2942 " write_floats(out_logits, ptr_f32(m.memory_base, m.logits_offset),\n"
2943 " (size_t)m.context_window * (size_t)m.vocab_size);\n"
2945 " if (out_weights) {\n"
2946 " if (save_model_weights(out_weights, &m) != 0) {\n"
2947 " fprintf(stderr, \"failed to save model weights\\n\");\n"
2951 " ck_huge_free(m.memory_base, m.total_bytes);\n"
2952 " free(m.layers);\n"