← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_codegen.h File Reference
#include "ckernel_ir.h"
#include <stdio.h>

Go to the source code of this file.

Enumerations

enum  CKEmitMode { CK_EMIT_STANDALONE = 0 , CK_EMIT_LIBRARY = 1 }
 

Functions

void ck_codegen_c_skeleton (const CKIRGraph *forward, const CKIRGraph *backward, FILE *out)
 
int ck_codegen_emit_runtime (const CKIRGraph *forward, const char *path, CKEmitMode mode)
 

Enumeration Type Documentation

◆ CKEmitMode

enum CKEmitMode

Code generation output mode.

Enumerator
CK_EMIT_STANDALONE 
CK_EMIT_LIBRARY 

Definition at line 11 of file ckernel_codegen.h.

11  {
12  CK_EMIT_STANDALONE = 0, /* Emit with main() for standalone executable */
13  CK_EMIT_LIBRARY = 1, /* Emit as library with API functions, no main() */
14 } CKEmitMode;
CKEmitMode
@ CK_EMIT_STANDALONE
@ CK_EMIT_LIBRARY

Function Documentation

◆ ck_codegen_c_skeleton()

void ck_codegen_c_skeleton ( const CKIRGraph *  forward,
const CKIRGraph *  backward,
FILE *  out 
)

Emit a C skeleton for forward + backward execution based on the IR.

This does not yet generate full pointer arithmetic or memory planning. It is intended as a starting point that:

  • Defines a model config / runtime context
  • Shows a per-layer forward loop over IR nodes
  • Sketches a backward loop over the backward IR

Definition at line 613 of file ckernel_codegen.c.

616 {
617  if (!forward || !out) {
618  return;
619  }
620 
621  fprintf(out,
622  "/* Auto-generated skeleton from CKIRGraph.\n"
623  " * This file sketches the structure of the forward and backward\n"
624  " * execution for a decoder-only transformer. It is NOT yet a\n"
625  " * complete, runnable implementation. You can use it as a\n"
626  " * starting point to wire buffers, kernel calls, and memory layout.\n"
627  " */\n\n");
628 
629  fprintf(out, "#include \"ckernel_engine.h\"\n");
630  fprintf(out, "#include \"ckernel_model.h\"\n");
631  fprintf(out, "#include \"ckernel_alloc.h\"\n\n");
632 
633  /* Forward function */
634  fprintf(out,
635  "void run_decoder_forward(TransformerModel *model /*, inputs, etc. */)\n"
636  "{\n"
637  " for (int layer = 0; layer < model->cfg.num_layers; ++layer) {\n"
638  " /* Forward pass for layer */\n");
639 
640  int nodes_per_layer = 0;
641  if (forward->num_nodes > 0) {
642  int l0 = forward->nodes[0].id.layer;
643  for (int i = 0; i < forward->num_nodes; ++i) {
644  if (forward->nodes[i].id.layer != l0) {
645  break;
646  }
647  nodes_per_layer++;
648  }
649  }
650 
651  if (nodes_per_layer <= 0) {
652  nodes_per_layer = forward->num_nodes;
653  }
654 
655  fprintf(out, " /* This layer has %d IR nodes */\n", nodes_per_layer);
656 
657  for (int i = 0; i < nodes_per_layer; ++i) {
658  const CKIRNode *n = &forward->nodes[i];
659  fprintf(out, " // L%%d: %s\n", op_name(n->op));
660  fprintf(out,
661  " // outputs: [");
662  for (int o = 0; o < n->n_outputs; ++o) {
663  if (o > 0) fprintf(out, ", ");
664  fprintf(out, "L%%d:N%d:%d", n->id.node, o);
665  }
666  fprintf(out, "]\n");
667  fprintf(out, " // inputs : [");
668  for (int j = 0; j < n->n_inputs; ++j) {
669  const CKInputRef *inp = &n->inputs[j];
670  if (j > 0) fprintf(out, ", ");
671  if (inp->producer.node == 0xFFFFu) {
672  fprintf(out, "IN");
673  } else {
674  fprintf(out, "L%%d:N%u:%u",
675  (unsigned)inp->producer.node,
676  (unsigned)inp->out_index);
677  }
678  }
679  fprintf(out, "]\n");
680  fprintf(out,
681  " // TODO: bind buffers/weights and call %s kernel here\n\n",
682  op_name(n->op));
683  }
684 
685  fprintf(out,
686  " } /* end for layer */\n"
687  "}\n\n");
688 
689  /* Backward skeleton */
690  if (backward && backward->nodes && backward->num_nodes > 0) {
691  fprintf(out,
692  "void run_decoder_backward(TransformerModel *model /*, grads, etc. */)\n"
693  "{\n"
694  " for (int layer = model->cfg.num_layers - 1; layer >= 0; --layer) {\n"
695  " /* Backward pass for layer */\n");
696 
697  int bwd_per_layer = 0;
698  int l0 = backward->nodes[0].id.layer;
699  for (int i = 0; i < backward->num_nodes; ++i) {
700  if (backward->nodes[i].id.layer != l0) break;
701  bwd_per_layer++;
702  }
703  if (bwd_per_layer <= 0) bwd_per_layer = backward->num_nodes;
704 
705  fprintf(out, " /* This layer has %d backward IR nodes */\n", bwd_per_layer);
706 
707  for (int i = 0; i < bwd_per_layer; ++i) {
708  const CKIRNode *n = &backward->nodes[i];
709  fprintf(out, " // L%%d: %s\n", op_name(n->op));
710  fprintf(out,
711  " // TODO: wire gradient tensors and call %s kernel here\n\n",
712  op_name(n->op));
713  }
714 
715  fprintf(out,
716  " } /* end for layer */\n"
717  "}\n\n");
718  }
719 
720  fprintf(out,
721  "int main(int argc, char **argv)\n"
722  "{\n"
723  " (void)argc; (void)argv;\n"
724  " TransformerModel model = {0};\n"
725  " model.cfg.num_layers = %d;\n"
726  " model.cfg.hidden_size = %d;\n"
727  " model.cfg.intermediate_size = %d;\n"
728  " model.cfg.num_heads = %d;\n"
729  " model.cfg.num_kv_heads = %d;\n"
730  " model.cfg.vocab_size = %d;\n"
731  " model.cfg.context_window = %d;\n"
732  " model.cfg.rms_norm_eps = %.9g;\n"
733  " model.cfg.rope_theta = %.9g;\n"
734  " layout_transformer_from_ir(&model, NULL); /* TODO: pass IR if needed */\n"
735  " size_t bytes = model.total_bytes;\n"
736  " model.memory_base = (uint8_t *)ck_huge_alloc(bytes);\n"
737  " if (!model.memory_base) {\n"
738  " fprintf(stderr, \"Failed to allocate %%zu bytes for model\\n\", bytes);\n"
739  " return 1;\n"
740  " }\n"
741  " // TODO: load weights into model.memory_base based on offsets\n"
742  " run_decoder_forward(&model);\n"
743  " // TODO: run_decoder_backward(&model) when training\n"
744  " ck_huge_free(model.memory_base, bytes);\n"
745  " return 0;\n"
746  "}\n",
747  forward->config.num_layers,
748  forward->config.hidden_size,
749  forward->config.intermediate_size,
750  forward->config.num_heads,
751  forward->config.num_kv_heads,
752  forward->config.vocab_size,
753  forward->config.context_window,
754  forward->config.rms_norm_eps,
755  forward->config.rope_theta);
756 }
static const char * op_name(CKOpType op)
CKIRNode * nodes
Definition: ckernel_ir.h:75
int num_nodes
Definition: ckernel_ir.h:74
CKModelConfig config
Definition: ckernel_ir.h:73
CKOpType op
Definition: ckernel_ir.h:65
CKInputRef inputs[4]
Definition: ckernel_ir.h:66
CKKernelId id
Definition: ckernel_ir.h:64
uint8_t n_inputs
Definition: ckernel_ir.h:67
uint8_t n_outputs
Definition: ckernel_ir.h:68
CKKernelId producer
Definition: ckernel_ir.h:59
uint8_t out_index
Definition: ckernel_ir.h:60
uint16_t node
Definition: ckernel_ir.h:55
uint16_t layer
Definition: ckernel_ir.h:54
int context_window
Definition: ckernel_ir.h:30
int intermediate_size
Definition: ck_model_api.h:37
float rms_norm_eps
Definition: ckernel_ir.h:31
float rope_theta
Definition: ckernel_ir.h:32

References CKIRGraph::config, CKModelConfig::context_window, CKModelConfig::hidden_size, CKIRNode::id, CKIRNode::inputs, CKModelConfig::intermediate_size, CKKernelId::layer, CKIRNode::n_inputs, CKIRNode::n_outputs, CKKernelId::node, CKIRGraph::nodes, CKModelConfig::num_heads, CKModelConfig::num_kv_heads, CKModelConfig::num_layers, CKIRGraph::num_nodes, CKIRNode::op, op_name(), CKInputRef::out_index, CKInputRef::producer, CKModelConfig::rms_norm_eps, CKModelConfig::rope_theta, and CKModelConfig::vocab_size.

Referenced by main().

◆ ck_codegen_emit_runtime()

int ck_codegen_emit_runtime ( const CKIRGraph *  forward,
const char *  path,
CKEmitMode  mode 
)

Emit a C runtime file that stitches kernels for the given forward IR.

Parameters
forward  The forward IR graph
path  Output file path
mode  CK_EMIT_STANDALONE for executable with main(), CK_EMIT_LIBRARY for shared object with API functions

Returns 0 on success, non-zero on failure.

Definition at line 1441 of file ckernel_codegen.c.

1442 {
1443  if (!forward || !path) {
1444  return -1;
1445  }
1446  if (ck_ir_validate_supported(forward) != 0) {
1447  return -1;
1448  }
1449 
1450  FILE *out = fopen(path, "wb");
1451  if (!out) {
1452  fprintf(stderr, "ck_codegen_emit_runtime: failed to open %s: %s\n",
1453  path, strerror(errno));
1454  return -1;
1455  }
1456 
1457  if (emit_runtime_preamble(out) != 0) {
1458  fclose(out);
1459  return -1;
1460  }
1461 
1462  fprintf(out,
1463  "typedef enum {\n"
1464  " TASK_LM = 0,\n"
1465  " TASK_SEQ_CLS = 1\n"
1466  "} TaskType;\n\n"
1467  "typedef enum {\n"
1468  " OPTIMIZER_SGD = 0,\n"
1469  " OPTIMIZER_ADAM = 1\n"
1470  "} OptimizerType;\n\n"
1471  "typedef struct {\n"
1472  " size_t total_gradient_floats;\n"
1473  "} GradientStorage;\n\n");
1474 
1476  emit_model_struct(out);
1477 
1478  fprintf(out,
1479  "static int ensure_layers_allocated(TransformerModel *m)\n"
1480  "{\n"
1481  " if (!m) return -1;\n"
1482  " if (!m->layers && m->num_layers > 0) {\n"
1483  " m->layers = (TrulyOptimalLayer *)calloc((size_t)m->num_layers, sizeof(TrulyOptimalLayer));\n"
1484  " if (!m->layers) return -1;\n"
1485  " }\n"
1486  " return 0;\n"
1487  "}\n\n"
1488  "static void init_weight_dtypes_uniform(TransformerModel *m, CKDataType dt)\n"
1489  "{\n"
1490  " if (!m) return;\n"
1491  " m->token_emb_dtype = dt;\n"
1492  " m->lm_head_weight_dtype = dt;\n"
1493  " m->pos_emb_dtype = CK_DT_FP32;\n"
1494  " if (ensure_layers_allocated(m) != 0) return;\n"
1495  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1496  " TrulyOptimalLayer *L = &m->layers[layer];\n"
1497  " L->wq_dtype = dt;\n"
1498  " L->wk_dtype = dt;\n"
1499  " L->wv_dtype = dt;\n"
1500  " L->wo_dtype = dt;\n"
1501  " L->w1_dtype = dt;\n"
1502  " L->w2_dtype = dt;\n"
1503  " }\n"
1504  "}\n\n"
1505  "static void refresh_weight_flags(TransformerModel *m)\n"
1506  "{\n"
1507  " if (!m) return;\n"
1508  " CKDataType base = m->token_emb_dtype;\n"
1509  " int mixed = 0;\n"
1510  " int quant = ck_dtype_is_quantized(base);\n"
1511  " if (m->lm_head_weight_dtype != base) mixed = 1;\n"
1512  " if (ck_dtype_is_quantized(m->lm_head_weight_dtype)) quant = 1;\n"
1513  " if (m->layers) {\n"
1514  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1515  " TrulyOptimalLayer *L = &m->layers[layer];\n"
1516  " if (L->wq_dtype != base || L->wk_dtype != base || L->wv_dtype != base ||\n"
1517  " L->wo_dtype != base || L->w1_dtype != base || L->w2_dtype != base) {\n"
1518  " mixed = 1;\n"
1519  " }\n"
1520  " if (ck_dtype_is_quantized(L->wq_dtype) || ck_dtype_is_quantized(L->wk_dtype) ||\n"
1521  " ck_dtype_is_quantized(L->wv_dtype) || ck_dtype_is_quantized(L->wo_dtype) ||\n"
1522  " ck_dtype_is_quantized(L->w1_dtype) || ck_dtype_is_quantized(L->w2_dtype)) {\n"
1523  " quant = 1;\n"
1524  " }\n"
1525  " }\n"
1526  " }\n"
1527  " m->weights_mixed = mixed ? true : false;\n"
1528  " m->weights_quantized = quant ? true : false;\n"
1529  " if (!mixed) {\n"
1530  " m->weight_dtype = base;\n"
1531  " }\n"
1532  "}\n\n"
1533  "static int load_weight_dtypes(const char *path, TransformerModel *m)\n"
1534  "{\n"
1535  " if (!path || !m) return -1;\n"
1536  " FILE *f = fopen(path, \"rb\");\n"
1537  " if (!f) return -1;\n"
1538  " char magic[8];\n"
1539  " if (fread(magic, 1, 8, f) != 8) {\n"
1540  " fclose(f);\n"
1541  " return -1;\n"
1542  " }\n"
1543  " if (memcmp(magic, \"BUMPWGT3\", 8) != 0) {\n"
1544  " fclose(f);\n"
1545  " return 0;\n"
1546  " }\n"
1547  " uint32_t version = 0;\n"
1548  " if (fread(&version, sizeof(uint32_t), 1, f) != 1) {\n"
1549  " fclose(f);\n"
1550  " return -1;\n"
1551  " }\n"
1552  " if (version < 3) {\n"
1553  " fclose(f);\n"
1554  " return -1;\n"
1555  " }\n"
1556  " if (fseek(f, 128, SEEK_SET) != 0) {\n"
1557  " fclose(f);\n"
1558  " return -1;\n"
1559  " }\n"
1560  " uint32_t dtype_len = 0;\n"
1561  " if (fread(&dtype_len, sizeof(uint32_t), 1, f) != 1) {\n"
1562  " fclose(f);\n"
1563  " return -1;\n"
1564  " }\n"
1565  " if (dtype_len == 0) {\n"
1566  " fclose(f);\n"
1567  " return -1;\n"
1568  " }\n"
1569  " uint8_t *dtype_buf = (uint8_t *)malloc(dtype_len);\n"
1570  " if (!dtype_buf) {\n"
1571  " fclose(f);\n"
1572  " return -1;\n"
1573  " }\n"
1574  " if (fread(dtype_buf, 1, dtype_len, f) != dtype_len) {\n"
1575  " free(dtype_buf);\n"
1576  " fclose(f);\n"
1577  " return -1;\n"
1578  " }\n"
1579  " fclose(f);\n"
1580  "\n"
1581  " size_t expected = (size_t)m->num_layers * 14u + 4u;\n"
1582  " if (dtype_len != expected) {\n"
1583  " free(dtype_buf);\n"
1584  " return -1;\n"
1585  " }\n"
1586  " if (ensure_layers_allocated(m) != 0) {\n"
1587  " free(dtype_buf);\n"
1588  " return -1;\n"
1589  " }\n"
1590  "\n"
1591  " size_t idx = 0;\n"
1592  " CKDataType token_dt = (CKDataType)dtype_buf[idx++];\n"
1593  " CKDataType pos_dt = (CKDataType)dtype_buf[idx++];\n"
1594  " if (pos_dt != CK_DT_FP32) {\n"
1595  " free(dtype_buf);\n"
1596  " return -1;\n"
1597  " }\n"
1598  " if (token_dt != CK_DT_FP32 && token_dt != CK_DT_Q4_K && token_dt != CK_DT_Q6_K) {\n"
1599  " free(dtype_buf);\n"
1600  " return -1;\n"
1601  " }\n"
1602  " m->token_emb_dtype = token_dt;\n"
1603  " m->lm_head_weight_dtype = token_dt;\n"
1604  " m->pos_emb_dtype = pos_dt;\n"
1605  "\n"
1606  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1607  " CKDataType ln1_dt = (CKDataType)dtype_buf[idx++];\n"
1608  " CKDataType ln2_dt = (CKDataType)dtype_buf[idx++];\n"
1609  " CKDataType wq_dt = (CKDataType)dtype_buf[idx++];\n"
1610  " CKDataType bq_dt = (CKDataType)dtype_buf[idx++];\n"
1611  " CKDataType wk_dt = (CKDataType)dtype_buf[idx++];\n"
1612  " CKDataType bk_dt = (CKDataType)dtype_buf[idx++];\n"
1613  " CKDataType wv_dt = (CKDataType)dtype_buf[idx++];\n"
1614  " CKDataType bv_dt = (CKDataType)dtype_buf[idx++];\n"
1615  " CKDataType wo_dt = (CKDataType)dtype_buf[idx++];\n"
1616  " CKDataType bo_dt = (CKDataType)dtype_buf[idx++];\n"
1617  " CKDataType w1_dt = (CKDataType)dtype_buf[idx++];\n"
1618  " CKDataType b1_dt = (CKDataType)dtype_buf[idx++];\n"
1619  " CKDataType w2_dt = (CKDataType)dtype_buf[idx++];\n"
1620  " CKDataType b2_dt = (CKDataType)dtype_buf[idx++];\n"
1621  "\n"
1622  " if (ln1_dt != CK_DT_FP32 || ln2_dt != CK_DT_FP32 ||\n"
1623  " bq_dt != CK_DT_FP32 || bk_dt != CK_DT_FP32 ||\n"
1624  " bv_dt != CK_DT_FP32 || bo_dt != CK_DT_FP32 ||\n"
1625  " b1_dt != CK_DT_FP32 || b2_dt != CK_DT_FP32) {\n"
1626  " free(dtype_buf);\n"
1627  " return -1;\n"
1628  " }\n"
1629  " if ((wq_dt != CK_DT_FP32 && wq_dt != CK_DT_Q4_K && wq_dt != CK_DT_Q6_K) ||\n"
1630  " (wk_dt != CK_DT_FP32 && wk_dt != CK_DT_Q4_K && wk_dt != CK_DT_Q6_K) ||\n"
1631  " (wv_dt != CK_DT_FP32 && wv_dt != CK_DT_Q4_K && wv_dt != CK_DT_Q6_K) ||\n"
1632  " (wo_dt != CK_DT_FP32 && wo_dt != CK_DT_Q4_K && wo_dt != CK_DT_Q6_K) ||\n"
1633  " (w1_dt != CK_DT_FP32 && w1_dt != CK_DT_Q4_K && w1_dt != CK_DT_Q6_K) ||\n"
1634  " (w2_dt != CK_DT_FP32 && w2_dt != CK_DT_Q4_K && w2_dt != CK_DT_Q6_K)) {\n"
1635  " free(dtype_buf);\n"
1636  " return -1;\n"
1637  " }\n"
1638  "\n"
1639  " TrulyOptimalLayer *L = &m->layers[layer];\n"
1640  " L->wq_dtype = wq_dt;\n"
1641  " L->wk_dtype = wk_dt;\n"
1642  " L->wv_dtype = wv_dt;\n"
1643  " L->wo_dtype = wo_dt;\n"
1644  " L->w1_dtype = w1_dt;\n"
1645  " L->w2_dtype = w2_dt;\n"
1646  " }\n"
1647  "\n"
1648  " CKDataType final_norm_dt = (CKDataType)dtype_buf[idx++];\n"
1649  " CKDataType final_bias_dt = (CKDataType)dtype_buf[idx++];\n"
1650  " free(dtype_buf);\n"
1651  " if (final_norm_dt != CK_DT_FP32 || final_bias_dt != CK_DT_FP32) {\n"
1652  " return -1;\n"
1653  " }\n"
1654  "\n"
1655  " refresh_weight_flags(m);\n"
1656  " return 1;\n"
1657  "}\n\n"
1658  "\n"
1659  "static int layout_model(TransformerModel *m)\n"
1660  "{\n"
1661  " if (!m) return -1;\n"
1662  " if (m->num_attention_heads <= 0 || m->embed_dim <= 0) return -1;\n"
1663  " if (m->num_kv_heads <= 0) m->num_kv_heads = m->num_attention_heads;\n"
1664  " if (m->num_attention_heads %% m->num_kv_heads != 0) return -1;\n"
1665  " if (m->context_window <= 0) m->context_window = 1;\n"
1666  " if (m->vocab_size <= 0) m->vocab_size = 1;\n"
1667  " if (m->intermediate_size <= 0) return -1;\n"
1668  " m->head_dim = m->embed_dim / m->num_attention_heads;\n"
1669  " if (m->rms_norm_eps <= 0.0f) m->rms_norm_eps = 1e-5f;\n"
1670  " if (m->rope_theta < 0.0f) m->rope_theta = 0.0f;\n"
1671  " if (m->rope_theta > 0.0f && (m->head_dim %% 2 != 0)) return -1;\n"
1672  " if (m->elem_bytes == 0) m->elem_bytes = sizeof(float);\n"
1673  " size_t elem_bytes = m->elem_bytes;\n"
1674  " m->aligned_embed_dim = align_up_elems((size_t)m->embed_dim, elem_bytes, CACHELINE_BYTES);\n"
1675  " m->aligned_head_dim = align_up_elems((size_t)m->head_dim, elem_bytes, CACHELINE_BYTES);\n"
1676  " m->aligned_attn_context_window = align_up_elems((size_t)m->context_window, elem_bytes, CACHELINE_BYTES);\n"
1677  " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, elem_bytes, CACHELINE_BYTES);\n"
1678  " if (ensure_layers_allocated(m) != 0) return -1;\n"
1679  " if (m->weights_quantized) {\n"
1680  " /* K-quant weights require K dimension to be a multiple of 256. */\n"
1681  " if ((m->aligned_embed_dim %% 256) != 0) return -1;\n"
1682  " if ((aligned_intermediate_dim %% 256) != 0) return -1;\n"
1683  " int wo_quant = 0;\n"
1684  " if (m->layers) {\n"
1685  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1686  " if (ck_dtype_is_quantized(m->layers[layer].wo_dtype)) {\n"
1687  " wo_quant = 1;\n"
1688  " break;\n"
1689  " }\n"
1690  " }\n"
1691  " }\n"
1692  " if (wo_quant && (size_t)m->num_attention_heads * m->aligned_head_dim != m->aligned_embed_dim) return -1;\n"
1693  " }\n"
1694  "\n"
1695  " if (m->num_cores <= 0) m->num_cores = 1;\n"
1696  " m->tokens_per_core = (m->context_window + m->num_cores - 1) / m->num_cores;\n"
1697  "\n"
1698  " size_t off = 0;\n");
1700  fprintf(out,
1701  " m->layers_start_offset = off;\n"
1702  "\n"
1703  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1704  " TrulyOptimalLayer *L = &m->layers[layer];\n");
1706  fprintf(out,
1707  " }\n"
1708  "\n");
1709  {
1710  const char *stride_field = ck_first_layer_buffer_name();
1711  fprintf(out,
1712  " if (m->num_layers > 1) {\n"
1713  " m->layer_stride = m->layers[1].%s_offset - m->layers[0].%s_offset;\n"
1714  " } else {\n"
1715  " m->layer_stride = 0;\n"
1716  " }\n",
1717  stride_field, stride_field);
1718  }
1720  fprintf(out,
1721  " m->total_bytes = align_up_bytes(off, CACHELINE_BYTES);\n"
1722  " m->memory_base = (uint8_t *)ck_huge_alloc(m->total_bytes);\n"
1723  " if (!m->memory_base) return -1;\n"
1724  " if (m->rope_theta > 0.0f) {\n"
1725  " rope_precompute_cache(ptr_f32(m->memory_base, m->rope_cos_cache_offset),\n"
1726  " ptr_f32(m->memory_base, m->rope_sin_cache_offset),\n"
1727  " m->context_window,\n"
1728  " m->head_dim,\n"
1729  " m->rope_theta);\n"
1730  " }\n"
1731  " return 0;\n"
1732  "}\n\n");
1733 
1734  fprintf(out,
1735  "static void lm_head_forward(const float *hidden,\n"
1736  " const float *weights,\n"
1737  " float *logits,\n"
1738  " int T, int V, int D, int aligned_D);\n"
1739  "static void lm_head_backward(const float *hidden,\n"
1740  " const float *weights,\n"
1741  " const float *d_logits,\n"
1742  " float *d_hidden,\n"
1743  " float *d_weights,\n"
1744  " int T, int V, int D, int aligned_D);\n"
1745  "static void softmax_cross_entropy(const float *logits,\n"
1746  " const int32_t *targets,\n"
1747  " int T, int V,\n"
1748  " float *d_logits,\n"
1749  " float *loss_out);\n\n");
1750 
1751  fprintf(out,
1752  "static void run_model_forward(TransformerModel *m)\n"
1753  "{\n"
1754  " uint8_t *base = m->memory_base;\n"
1755  " float *current = ptr_f32(base, m->embedded_input_offset);\n"
1756  " int aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
1757  " int T = m->active_tokens > 0 ? m->active_tokens : m->context_window;\n"
1758  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1759  " TrulyOptimalLayer *L = &m->layers[layer];\n"
1760  " if (!m->weights_mixed && m->weight_dtype == CK_DT_Q4_K) {\n"
1761  " CKLayerForwardParamsQ4K p = {0};\n"
1762  " p.tokens = T;\n"
1763  " p.embed_dim = m->embed_dim;\n"
1764  " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1765  " p.num_heads = m->num_attention_heads;\n"
1766  " p.num_kv_heads = m->num_kv_heads;\n"
1767  " p.head_dim = m->head_dim;\n"
1768  " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1769  " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1770  " p.intermediate_dim = m->intermediate_size;\n"
1771  " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1772  " p.eps = m->rms_norm_eps;\n"
1773  " p.rope_pos_offset = 0;\n"
1774  " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1775  " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1776  " p.input = current;\n"
1777  " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1778  " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1779  " p.wq = cptr_void(base, L->wq_offset);\n"
1780  " p.bq = cptr_f32(base, L->bq_offset);\n"
1781  " p.wk = cptr_void(base, L->wk_offset);\n"
1782  " p.bk = cptr_f32(base, L->bk_offset);\n"
1783  " p.wv = cptr_void(base, L->wv_offset);\n"
1784  " p.bv = cptr_f32(base, L->bv_offset);\n"
1785  " p.wo = cptr_void(base, L->wo_offset);\n"
1786  " p.bo = cptr_f32(base, L->bo_offset);\n"
1787  " p.w1 = cptr_void(base, L->w1_offset);\n"
1788  " p.b1 = cptr_f32(base, L->b1_offset);\n"
1789  " p.w2 = cptr_void(base, L->w2_offset);\n"
1790  " p.b2 = cptr_f32(base, L->b2_offset);\n"
1791  " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1792  " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1793  " p.q = ptr_f32(base, L->q_offset);\n"
1794  " p.k = ptr_f32(base, L->k_offset);\n"
1795  " p.v = ptr_f32(base, L->v_offset);\n"
1796  " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1797  " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1798  " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1799  " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1800  " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1801  " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1802  " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1803  " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1804  " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1805  " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1806  " p.output = ptr_f32(base, L->output_offset);\n"
1807  " ck_layer_forward_rmsnorm_swiglu_q4_k(&p);\n"
1808  " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1809  " kv_cache_repack_head_major_inplace(p.k,\n"
1810  " p.num_kv_heads,\n"
1811  " T,\n"
1812  " m->kv_cache_capacity,\n"
1813  " p.aligned_head_dim);\n"
1814  " kv_cache_repack_head_major_inplace(p.v,\n"
1815  " p.num_kv_heads,\n"
1816  " T,\n"
1817  " m->kv_cache_capacity,\n"
1818  " p.aligned_head_dim);\n"
1819  " }\n"
1820  " current = p.output;\n"
1821  " } else if (m->weights_quantized) {\n"
1822  " CKLayerForwardParamsQ4K p = {0};\n"
1823  " p.tokens = T;\n"
1824  " p.embed_dim = m->embed_dim;\n"
1825  " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1826  " p.num_heads = m->num_attention_heads;\n"
1827  " p.num_kv_heads = m->num_kv_heads;\n"
1828  " p.head_dim = m->head_dim;\n"
1829  " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1830  " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1831  " p.intermediate_dim = m->intermediate_size;\n"
1832  " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1833  " p.eps = m->rms_norm_eps;\n"
1834  " p.rope_pos_offset = 0;\n"
1835  " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1836  " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1837  " p.input = current;\n"
1838  " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1839  " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1840  " p.wq = cptr_void(base, L->wq_offset);\n"
1841  " p.bq = cptr_f32(base, L->bq_offset);\n"
1842  " p.wk = cptr_void(base, L->wk_offset);\n"
1843  " p.bk = cptr_f32(base, L->bk_offset);\n"
1844  " p.wv = cptr_void(base, L->wv_offset);\n"
1845  " p.bv = cptr_f32(base, L->bv_offset);\n"
1846  " p.wo = cptr_void(base, L->wo_offset);\n"
1847  " p.bo = cptr_f32(base, L->bo_offset);\n"
1848  " p.w1 = cptr_void(base, L->w1_offset);\n"
1849  " p.b1 = cptr_f32(base, L->b1_offset);\n"
1850  " p.w2 = cptr_void(base, L->w2_offset);\n"
1851  " p.b2 = cptr_f32(base, L->b2_offset);\n"
1852  " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1853  " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1854  " p.q = ptr_f32(base, L->q_offset);\n"
1855  " p.k = ptr_f32(base, L->k_offset);\n"
1856  " p.v = ptr_f32(base, L->v_offset);\n"
1857  " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1858  " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1859  " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1860  " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1861  " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1862  " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1863  " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1864  " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1865  " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1866  " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1867  " p.output = ptr_f32(base, L->output_offset);\n"
1868  " p.wq_dtype = L->wq_dtype;\n"
1869  " p.wk_dtype = L->wk_dtype;\n"
1870  " p.wv_dtype = L->wv_dtype;\n"
1871  " p.wo_dtype = L->wo_dtype;\n"
1872  " p.w1_dtype = L->w1_dtype;\n"
1873  " p.w2_dtype = L->w2_dtype;\n"
1874  " ck_layer_forward_rmsnorm_swiglu_quant(&p);\n"
1875  " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1876  " kv_cache_repack_head_major_inplace(p.k,\n"
1877  " p.num_kv_heads,\n"
1878  " T,\n"
1879  " m->kv_cache_capacity,\n"
1880  " p.aligned_head_dim);\n"
1881  " kv_cache_repack_head_major_inplace(p.v,\n"
1882  " p.num_kv_heads,\n"
1883  " T,\n"
1884  " m->kv_cache_capacity,\n"
1885  " p.aligned_head_dim);\n"
1886  " }\n"
1887  " current = p.output;\n"
1888  " } else {\n"
1889  " CKLayerForwardParams p = {0};\n"
1890  " p.tokens = T;\n"
1891  " p.embed_dim = m->embed_dim;\n"
1892  " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1893  " p.num_heads = m->num_attention_heads;\n"
1894  " p.num_kv_heads = m->num_kv_heads;\n"
1895  " p.head_dim = m->head_dim;\n"
1896  " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1897  " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1898  " p.intermediate_dim = m->intermediate_size;\n"
1899  " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1900  " p.eps = m->rms_norm_eps;\n"
1901  " p.rope_pos_offset = 0;\n"
1902  " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1903  " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1904  " p.input = current;\n"
1905  " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1906  " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1907  " p.wq = cptr_f32(base, L->wq_offset);\n"
1908  " p.bq = cptr_f32(base, L->bq_offset);\n"
1909  " p.wk = cptr_f32(base, L->wk_offset);\n"
1910  " p.bk = cptr_f32(base, L->bk_offset);\n"
1911  " p.wv = cptr_f32(base, L->wv_offset);\n"
1912  " p.bv = cptr_f32(base, L->bv_offset);\n"
1913  " p.wo = cptr_f32(base, L->wo_offset);\n"
1914  " p.bo = cptr_f32(base, L->bo_offset);\n"
1915  " p.w1 = cptr_f32(base, L->w1_offset);\n"
1916  " p.b1 = cptr_f32(base, L->b1_offset);\n"
1917  " p.w2 = cptr_f32(base, L->w2_offset);\n"
1918  " p.b2 = cptr_f32(base, L->b2_offset);\n"
1919  " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1920  " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1921  " p.q = ptr_f32(base, L->q_offset);\n"
1922  " p.k = ptr_f32(base, L->k_offset);\n"
1923  " p.v = ptr_f32(base, L->v_offset);\n"
1924  " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1925  " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1926  " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1927  " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1928  " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1929  " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1930  " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1931  " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1932  " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1933  " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1934  " p.output = ptr_f32(base, L->output_offset);\n"
1935  " ck_layer_forward_rmsnorm_swiglu(&p);\n"
1936  " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1937  " kv_cache_repack_head_major_inplace(p.k,\n"
1938  " p.num_kv_heads,\n"
1939  " T,\n"
1940  " m->kv_cache_capacity,\n"
1941  " p.aligned_head_dim);\n"
1942  " kv_cache_repack_head_major_inplace(p.v,\n"
1943  " p.num_kv_heads,\n"
1944  " T,\n"
1945  " m->kv_cache_capacity,\n"
1946  " p.aligned_head_dim);\n"
1947  " }\n"
1948  " current = p.output;\n"
1949  " }\n"
1950  " }\n"
1951  " float *final_out = ptr_f32(base, m->final_output_offset);\n"
1952  " rmsnorm_forward(current,\n"
1953  " cptr_f32(base, m->final_ln_weight_offset),\n"
1954  " final_out,\n"
1955  " ptr_f32(base, m->final_ln_rstd_offset),\n"
1956  " T,\n"
1957  " m->embed_dim,\n"
1958  " (int)m->aligned_embed_dim,\n"
1959  " m->rms_norm_eps);\n"
1960  " if (m->vocab_size > 0) {\n"
1961  " if (m->lm_head_weight_dtype == CK_DT_Q4_K) {\n"
1962  " gemm_nt_q4_k(final_out,\n"
1963  " cptr_void(base, m->lm_head_weight_offset),\n"
1964  " NULL,\n"
1965  " ptr_f32(base, m->logits_offset),\n"
1966  " T,\n"
1967  " m->vocab_size,\n"
1968  " (int)m->aligned_embed_dim);\n"
1969  " } else if (m->lm_head_weight_dtype == CK_DT_Q6_K) {\n"
1970  " gemm_nt_q6_k(final_out,\n"
1971  " cptr_void(base, m->lm_head_weight_offset),\n"
1972  " NULL,\n"
1973  " ptr_f32(base, m->logits_offset),\n"
1974  " T,\n"
1975  " m->vocab_size,\n"
1976  " (int)m->aligned_embed_dim);\n"
1977  " } else {\n"
1978  " lm_head_forward(final_out,\n"
1979  " cptr_f32(base, m->lm_head_weight_offset),\n"
1980  " ptr_f32(base, m->logits_offset),\n"
1981  " T,\n"
1982  " m->vocab_size,\n"
1983  " m->embed_dim,\n"
1984  " (int)m->aligned_embed_dim);\n"
1985  " }\n"
1986  " }\n"
1987  "}\n\n");
1988 
1989  emit_zero_grad(out);
1990  emit_sgd_update(out);
1991 
1992  fprintf(out,
1993  "static int run_model_backward(TransformerModel *m,\n"
1994  " const int32_t *tokens,\n"
1995  " const int32_t *targets,\n"
1996  " float *loss_out)\n"
1997  "{\n"
1998  " if (!m || !m->training_enabled) return 0;\n"
1999  " if (!tokens || !targets) return -1;\n"
2000  " if (m->num_layers <= 0) return -1;\n"
2001  " int T = m->active_tokens > 0 ? m->active_tokens : m->context_window;\n"
2002  " int V = m->vocab_size;\n"
2003  " int D = m->embed_dim;\n"
2004  " int aligned_D = (int)m->aligned_embed_dim;\n"
2005  " uint8_t *base = m->memory_base;\n"
2006  "\n"
2007  " zero_grad(m);\n"
2008  "\n"
2009  " float *final_out = ptr_f32(base, m->final_output_offset);\n"
2010  " float *logits = ptr_f32(base, m->logits_offset);\n"
2011  " float *d_logits = ptr_f32(base, m->d_logits_offset);\n"
2012  " float *d_final_out = ptr_f32(base, m->d_final_output_offset);\n"
2013  " float *d_final_in = ptr_f32(base, m->d_final_input_offset);\n"
2014  "\n"
2015  " float loss = 0.0f;\n"
2016  " softmax_cross_entropy(logits, targets, T, V, d_logits, &loss);\n"
2017  " if (loss_out) {\n"
2018  " *loss_out = loss;\n"
2019  " }\n"
2020  " lm_head_backward(final_out,\n"
2021  " cptr_f32(base, m->lm_head_weight_offset),\n"
2022  " d_logits,\n"
2023  " d_final_out,\n"
2024  " ptr_f32(base, m->d_token_emb_offset),\n"
2025  " T, V, D, aligned_D);\n"
2026  " rmsnorm_backward(d_final_out,\n"
2027  " ptr_f32(base, m->layers[m->num_layers - 1].output_offset),\n"
2028  " cptr_f32(base, m->final_ln_weight_offset),\n"
2029  " ptr_f32(base, m->final_ln_rstd_offset),\n"
2030  " d_final_in,\n"
2031  " ptr_f32(base, m->d_final_ln_weight_offset),\n"
2032  " T, D, aligned_D);\n"
2033  "\n"
2034  " for (int layer = m->num_layers - 1; layer >= 0; --layer) {\n"
2035  " TrulyOptimalLayer *L = &m->layers[layer];\n"
2036  " CKLayerBackwardParams p = {0};\n"
2037  " p.tokens = T;\n"
2038  " p.embed_dim = m->embed_dim;\n"
2039  " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
2040  " p.num_heads = m->num_attention_heads;\n"
2041  " p.num_kv_heads = m->num_kv_heads;\n"
2042  " p.head_dim = m->head_dim;\n"
2043  " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
2044  " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
2045  " p.intermediate_dim = m->intermediate_size;\n"
2046  " p.aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2047  " p.eps = m->rms_norm_eps;\n"
2048  " p.rope_pos_offset = 0;\n"
2049  " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
2050  " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
2051  " p.input = (layer == 0) ? ptr_f32(base, m->embedded_input_offset)\n"
2052  " : ptr_f32(base, m->layers[layer - 1].output_offset);\n"
2053  " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
2054  " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
2055  " p.ln1_out = cptr_f32(base, L->ln1_out_offset);\n"
2056  " p.ln1_rstd = cptr_f32(base, L->ln1_rstd_offset);\n"
2057  " p.ln2_out = cptr_f32(base, L->ln2_out_offset);\n"
2058  " p.ln2_rstd = cptr_f32(base, L->ln2_rstd_offset);\n"
2059  " p.wq = cptr_f32(base, L->wq_offset);\n"
2060  " p.bq = cptr_f32(base, L->bq_offset);\n"
2061  " p.wk = cptr_f32(base, L->wk_offset);\n"
2062  " p.bk = cptr_f32(base, L->bk_offset);\n"
2063  " p.wv = cptr_f32(base, L->wv_offset);\n"
2064  " p.bv = cptr_f32(base, L->bv_offset);\n"
2065  " p.wo = cptr_f32(base, L->wo_offset);\n"
2066  " p.bo = cptr_f32(base, L->bo_offset);\n"
2067  " p.w1 = cptr_f32(base, L->w1_offset);\n"
2068  " p.b1 = cptr_f32(base, L->b1_offset);\n"
2069  " p.w2 = cptr_f32(base, L->w2_offset);\n"
2070  " p.b2 = cptr_f32(base, L->b2_offset);\n"
2071  " p.q = cptr_f32(base, L->q_offset);\n"
2072  " p.k = cptr_f32(base, L->k_offset);\n"
2073  " p.v = cptr_f32(base, L->v_offset);\n"
2074  " p.scores = L->scores_offset ? cptr_f32(base, L->scores_offset) : NULL;\n"
2075  " p.attn_out = cptr_f32(base, L->attn_out_offset);\n"
2076  " p.residual1 = cptr_f32(base, L->residual1_offset);\n"
2077  " p.fc1_out = cptr_f32(base, L->fc1_out_offset);\n"
2078  " p.swiglu_out = cptr_f32(base, L->swiglu_out_offset);\n"
2079  " p.d_output = ptr_f32(base, L->d_output_offset);\n"
2080  " p.d_input = ptr_f32(base, L->d_input_offset);\n"
2081  " p.d_ln1_gamma = ptr_f32(base, L->d_ln1_gamma_offset);\n"
2082  " p.d_ln2_gamma = ptr_f32(base, L->d_ln2_gamma_offset);\n"
2083  " p.d_wq = ptr_f32(base, L->d_wq_offset);\n"
2084  " p.d_bq = ptr_f32(base, L->d_bq_offset);\n"
2085  " p.d_wk = ptr_f32(base, L->d_wk_offset);\n"
2086  " p.d_bk = ptr_f32(base, L->d_bk_offset);\n"
2087  " p.d_wv = ptr_f32(base, L->d_wv_offset);\n"
2088  " p.d_bv = ptr_f32(base, L->d_bv_offset);\n"
2089  " p.d_wo = ptr_f32(base, L->d_wo_offset);\n"
2090  " p.d_bo = ptr_f32(base, L->d_bo_offset);\n"
2091  " p.d_w1 = ptr_f32(base, L->d_w1_offset);\n"
2092  " p.d_b1 = ptr_f32(base, L->d_b1_offset);\n"
2093  " p.d_w2 = ptr_f32(base, L->d_w2_offset);\n"
2094  " p.d_b2 = ptr_f32(base, L->d_b2_offset);\n"
2095  " p.d_ln1_out = ptr_f32(base, L->d_ln1_out_offset);\n"
2096  " p.d_q = ptr_f32(base, L->d_q_offset);\n"
2097  " p.d_k = ptr_f32(base, L->d_k_offset);\n"
2098  " p.d_v = ptr_f32(base, L->d_v_offset);\n"
2099  " p.d_scores = ptr_f32(base, L->d_scores_offset);\n"
2100  " p.d_attn_out = ptr_f32(base, L->d_attn_out_offset);\n"
2101  " p.d_proj_tmp = ptr_f32(base, L->d_proj_tmp_offset);\n"
2102  " p.d_residual1 = ptr_f32(base, L->d_residual1_offset);\n"
2103  " p.d_ln2_out = ptr_f32(base, L->d_ln2_out_offset);\n"
2104  " p.d_fc1_out = ptr_f32(base, L->d_fc1_out_offset);\n"
2105  " p.d_swiglu_out = ptr_f32(base, L->d_swiglu_out_offset);\n"
2106  " p.d_mlp_out = ptr_f32(base, L->d_mlp_out_offset);\n"
2107  "\n"
2108  " const float *src = (layer == m->num_layers - 1)\n"
2109  " ? d_final_in\n"
2110  " : ptr_f32(base, m->layers[layer + 1].d_input_offset);\n"
2111  " memcpy(p.d_output, src, (size_t)T * (size_t)aligned_D * sizeof(float));\n"
2112  "\n"
2113  " ck_layer_backward_rmsnorm_swiglu(&p);\n"
2114  " }\n"
2115  "\n"
2116  " {\n"
2117  " TrulyOptimalLayer *L0 = &m->layers[0];\n"
2118  " embedding_backward(tokens,\n"
2119  " T,\n"
2120  " ptr_f32(base, L0->d_input_offset),\n"
2121  " ptr_f32(base, m->d_token_emb_offset),\n"
2122  " ptr_f32(base, m->d_pos_emb_offset),\n"
2123  " m->vocab_size,\n"
2124  " m->embed_dim,\n"
2125  " aligned_D,\n"
2126  " m->context_window,\n"
2127  " m->rope_theta <= 0.0f);\n"
2128  " }\n"
2129  "\n"
2130  " /* SGD update is now called separately via optimizer_step() */\n"
2131  " return 0;\n"
2132  "}\n\n");
2133 
2134  fprintf(out,
2135  "static int parse_int_arg(const char *s, int *out)\n"
2136  "{\n"
2137  " if (!s || !out) return 0;\n"
2138  " char *end = NULL;\n"
2139  " long v = strtol(s, &end, 10);\n"
2140  " if (!end || *end != '\\0') return 0;\n"
2141  " *out = (int)v;\n"
2142  " return 1;\n"
2143  "}\n\n"
2144  "static int parse_float_arg(const char *s, float *out)\n"
2145  "{\n"
2146  " if (!s || !out) return 0;\n"
2147  " char *end = NULL;\n"
2148  " double v = strtod(s, &end);\n"
2149  " if (!end || *end != '\\0') return 0;\n"
2150  " *out = (float)v;\n"
2151  " return 1;\n"
2152  "}\n\n"
2153  "static void print_usage(const char *prog)\n"
2154  "{\n"
2155  " printf(\"Usage: %%s [options]\\n\", prog);\n"
2156  " printf(\" --dump Print layout summary (layer 0 only)\\n\");\n"
2157  " printf(\" --dump-all Print layout summary for all layers\\n\");\n"
2158  " printf(\" --no-forward Skip forward pass (layout + alloc only)\\n\");\n"
2159  " printf(\" --layers N Override num_layers\\n\");\n"
2160  " printf(\" --embed N Override embed_dim\\n\");\n"
2161  " printf(\" --intermediate N Override intermediate_size\\n\");\n"
2162  " printf(\" --heads N Override num_attention_heads\\n\");\n"
2163  " printf(\" --kv-heads N Override num_kv_heads\\n\");\n"
2164  " printf(\" --vocab N Override vocab_size\\n\");\n"
2165  " printf(\" --ctx N Override context_window\\n\");\n"
2166  " printf(\" --cores N Override num_cores\\n\");\n"
2167  " printf(\" --litmus Run LM head + CE + backward litmus\\n\");\n"
2168  " printf(\" --backward Run backward pass + SGD update (requires --tokens/--targets)\\n\");\n"
2169  " printf(\" --lr F SGD learning rate (default: 1e-3 when --backward)\\n\");\n"
2170  " printf(\" --steps N Training steps (default: 1)\\n\");\n"
2171  " printf(\" --log-steps Print loss per step during training\\n\");\n"
2172  " printf(\" --strict Enable strict parity mode (single-thread + double GEMM)\\n\");\n"
2173  " printf(\" --hidden PATH Load hidden activations [T x aligned_D] f32\\n\");\n"
2174  " printf(\" --weights PATH Load LM head weights [V x aligned_D] f32 (litmus)\\n\");\n"
2175  " printf(\" --targets PATH Load target tokens [T] int32\\n\");\n"
2176  " printf(\" --model-weights PATH Load full model weights (bump format)\\n\");\n"
2177  " printf(\" --tokens PATH Load token IDs [T] int32 and build embeddings\\n\");\n"
2178  " printf(\" --out-logits PATH Write logits [T x V] f32\\n\");\n"
2179  " printf(\" --out-dlogits PATH Write d_logits [T x V] f32\\n\");\n"
2180  " printf(\" --out-dhidden PATH Write d_hidden [T x aligned_D] f32\\n\");\n"
2181  " printf(\" --out-dweights PATH Write d_weights [V x aligned_D] f32\\n\");\n"
2182  " printf(\" --out-loss PATH Write loss (single f32)\\n\");\n"
2183  " printf(\" --out-weights PATH Write model weights (flat, no header)\\n\");\n"
2184  " printf(\" --help Show this help\\n\");\n"
2185  "}\n\n"
2186  "static int read_floats(const char *path, float *dst, size_t count)\n"
2187  "{\n"
2188  " if (!path || !dst) return -1;\n"
2189  " FILE *f = fopen(path, \"rb\");\n"
2190  " if (!f) {\n"
2191  " perror(\"fopen\");\n"
2192  " return -1;\n"
2193  " }\n"
2194  " size_t got = fread(dst, sizeof(float), count, f);\n"
2195  " fclose(f);\n"
2196  " return got == count ? 0 : -1;\n"
2197  "}\n\n"
2198  "static int read_ints(const char *path, int32_t *dst, size_t count)\n"
2199  "{\n"
2200  " if (!path || !dst) return -1;\n"
2201  " FILE *f = fopen(path, \"rb\");\n"
2202  " if (!f) {\n"
2203  " perror(\"fopen\");\n"
2204  " return -1;\n"
2205  " }\n"
2206  " size_t got = fread(dst, sizeof(int32_t), count, f);\n"
2207  " fclose(f);\n"
2208  " return got == count ? 0 : -1;\n"
2209  "}\n\n"
2210  "static int read_floats_file(FILE *f, float *dst, size_t count)\n"
2211  "{\n"
2212  " if (!f || !dst) return -1;\n"
2213  " size_t got = fread(dst, sizeof(float), count, f);\n"
2214  " return got == count ? 0 : -1;\n"
2215  "}\n\n"
2216  "static int read_bytes_file(FILE *f, void *dst, size_t bytes)\n"
2217  "{\n"
2218  " if (!f || !dst) return -1;\n"
2219  " size_t got = fread(dst, 1, bytes, f);\n"
2220  " return got == bytes ? 0 : -1;\n"
2221  "}\n\n"
2222  "static int write_floats_file(FILE *f, const float *src, size_t count)\n"
2223  "{\n"
2224  " if (!f || !src) return -1;\n"
2225  " size_t wrote = fwrite(src, sizeof(float), count, f);\n"
2226  " return wrote == count ? 0 : -1;\n"
2227  "}\n\n"
2228  "static int write_bytes_file(FILE *f, const void *src, size_t bytes)\n"
2229  "{\n"
2230  " if (!f || !src) return -1;\n"
2231  " size_t wrote = fwrite(src, 1, bytes, f);\n"
2232  " return wrote == bytes ? 0 : -1;\n"
2233  "}\n\n"
2234  "static int read_weight_file(FILE *f, CKDataType dtype, void *dst, size_t n_elements)\n"
2235  "{\n"
2236  " if (!f || !dst) return -1;\n"
2237  " if (dtype == CK_DT_FP32) {\n"
2238  " return read_floats_file(f, (float *)dst, n_elements);\n"
2239  " }\n"
2240  " return read_bytes_file(f, dst, ck_dtype_row_bytes(dtype, n_elements));\n"
2241  "}\n\n"
2242  "static int write_weight_file(FILE *f, CKDataType dtype, const void *src, size_t n_elements)\n"
2243  "{\n"
2244  " if (!f || !src) return -1;\n"
2245  " if (dtype == CK_DT_FP32) {\n"
2246  " return write_floats_file(f, (const float *)src, n_elements);\n"
2247  " }\n"
2248  " return write_bytes_file(f, src, ck_dtype_row_bytes(dtype, n_elements));\n"
2249  "}\n\n"
2250  "static int skip_bump_header(FILE *f)\n"
2251  "{\n"
2252  " if (!f) return -1;\n"
2253  " char magic[8];\n"
2254  " if (fread(magic, 1, 8, f) != 8) return -1;\n"
2255  " if (memcmp(magic, \"BUMPWGT3\", 8) == 0) {\n"
2256  " if (fseek(f, 128, SEEK_SET) != 0) return -1;\n"
2257  " uint32_t dtype_len = 0;\n"
2258  " if (fread(&dtype_len, sizeof(uint32_t), 1, f) != 1) return -1;\n"
2259  " if (fseek(f, (long)dtype_len, SEEK_CUR) != 0) return -1;\n"
2260  " return 1;\n"
2261  " }\n"
2262  " if (memcmp(magic, \"BUMPWGT2\", 8) == 0) {\n"
2263  " if (fseek(f, 128, SEEK_SET) != 0) return -1;\n"
2264  " return 1;\n"
2265  " }\n"
2266  " if (fseek(f, 0, SEEK_SET) != 0) return -1;\n"
2267  " return 0;\n"
2268  "}\n\n"
2269  "static int load_model_weights(const char *path, TransformerModel *m)\n"
2270  "{\n"
2271  " if (!path || !m || !m->memory_base) return -1;\n"
2272  " FILE *f = fopen(path, \"rb\");\n"
2273  " if (!f) {\n"
2274  " perror(\"fopen\");\n"
2275  " return -1;\n"
2276  " }\n"
2277  " if (skip_bump_header(f) < 0) {\n"
2278  " fclose(f);\n"
2279  " return -1;\n"
2280  " }\n"
2281  " uint8_t *base = m->memory_base;\n"
2282  " size_t aligned_intermediate = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2283  " size_t tok_elems = (size_t)m->vocab_size * m->aligned_embed_dim;\n"
2284  " if (read_weight_file(f, m->token_emb_dtype, ptr_u8(base, m->token_emb_offset), tok_elems) != 0) goto fail;\n"
2285  " if (read_floats_file(f, ptr_f32(base, m->pos_emb_offset),\n"
2286  " (size_t)m->context_window * m->aligned_embed_dim) != 0) goto fail;\n"
2287  "\n"
2288  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
2289  " TrulyOptimalLayer *L = &m->layers[layer];\n"
2290  " size_t head_w_stride = m->aligned_head_dim * m->aligned_embed_dim;\n"
2291  " size_t q_w = (size_t)m->num_attention_heads * head_w_stride;\n"
2292  " size_t kv_w = (size_t)m->num_kv_heads * head_w_stride;\n"
2293  " size_t q_b = (size_t)m->num_attention_heads * m->aligned_head_dim;\n"
2294  " size_t kv_b = (size_t)m->num_kv_heads * m->aligned_head_dim;\n"
2295  " size_t wo_w = (size_t)m->num_attention_heads * m->aligned_embed_dim * m->aligned_head_dim;\n"
2296  " size_t w1_w = (size_t)(2 * aligned_intermediate) * m->aligned_embed_dim;\n"
2297  " size_t w2_w = m->aligned_embed_dim * aligned_intermediate;\n"
2298  "\n"
2299  " if (read_floats_file(f, ptr_f32(base, L->ln1_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2300  " if (read_floats_file(f, ptr_f32(base, L->ln2_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2301  " if (read_weight_file(f, L->wq_dtype, ptr_u8(base, L->wq_offset), q_w) != 0) goto fail;\n"
2302  " if (read_floats_file(f, ptr_f32(base, L->bq_offset), q_b) != 0) goto fail;\n"
2303  " if (read_weight_file(f, L->wk_dtype, ptr_u8(base, L->wk_offset), kv_w) != 0) goto fail;\n"
2304  " if (read_floats_file(f, ptr_f32(base, L->bk_offset), kv_b) != 0) goto fail;\n"
2305  " if (read_weight_file(f, L->wv_dtype, ptr_u8(base, L->wv_offset), kv_w) != 0) goto fail;\n"
2306  " if (read_floats_file(f, ptr_f32(base, L->bv_offset), kv_b) != 0) goto fail;\n"
2307  " if (read_weight_file(f, L->wo_dtype, ptr_u8(base, L->wo_offset), wo_w) != 0) goto fail;\n"
2308  " if (read_floats_file(f, ptr_f32(base, L->bo_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2309  " if (read_weight_file(f, L->w1_dtype, ptr_u8(base, L->w1_offset), w1_w) != 0) goto fail;\n"
2310  " if (read_floats_file(f, ptr_f32(base, L->b1_offset), (size_t)(2 * aligned_intermediate)) != 0) goto fail;\n"
2311  " if (read_weight_file(f, L->w2_dtype, ptr_u8(base, L->w2_offset), w2_w) != 0) goto fail;\n"
2312  " if (read_floats_file(f, ptr_f32(base, L->b2_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2313  " }\n"
2314  "\n"
2315  " if (read_floats_file(f, ptr_f32(base, m->final_ln_weight_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2316  " if (read_floats_file(f, ptr_f32(base, m->final_ln_bias_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2317  "\n"
2318  " fclose(f);\n"
2319  " return 0;\n"
2320  "fail:\n"
2321  " fclose(f);\n"
2322  " return -1;\n"
2323  "}\n\n"
2324  "static int save_model_weights(const char *path, const TransformerModel *m)\n"
2325  "{\n"
2326  " if (!path || !m || !m->memory_base) return -1;\n"
2327  " FILE *f = fopen(path, \"wb\");\n"
2328  " if (!f) {\n"
2329  " perror(\"fopen\");\n"
2330  " return -1;\n"
2331  " }\n"
2332  " uint8_t *base = m->memory_base;\n"
2333  " size_t aligned_intermediate = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2334  " size_t tok_elems = (size_t)m->vocab_size * m->aligned_embed_dim;\n"
2335  " if (write_weight_file(f, m->token_emb_dtype, cptr_void(base, m->token_emb_offset), tok_elems) != 0) goto fail;\n"
2336  " if (write_floats_file(f, ptr_f32(base, m->pos_emb_offset),\n"
2337  " (size_t)m->context_window * m->aligned_embed_dim) != 0) goto fail;\n"
2338  "\n"
2339  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
2340  " const TrulyOptimalLayer *L = &m->layers[layer];\n"
2341  " size_t head_w_stride = m->aligned_head_dim * m->aligned_embed_dim;\n"
2342  " size_t q_w = (size_t)m->num_attention_heads * head_w_stride;\n"
2343  " size_t kv_w = (size_t)m->num_kv_heads * head_w_stride;\n"
2344  " size_t q_b = (size_t)m->num_attention_heads * m->aligned_head_dim;\n"
2345  " size_t kv_b = (size_t)m->num_kv_heads * m->aligned_head_dim;\n"
2346  " size_t wo_w = (size_t)m->num_attention_heads * m->aligned_embed_dim * m->aligned_head_dim;\n"
2347  " size_t w1_w = (size_t)(2 * aligned_intermediate) * m->aligned_embed_dim;\n"
2348  " size_t w2_w = m->aligned_embed_dim * aligned_intermediate;\n"
2349  "\n"
2350  " if (write_floats_file(f, cptr_f32(base, L->ln1_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2351  " if (write_floats_file(f, cptr_f32(base, L->ln2_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2352  " if (write_weight_file(f, L->wq_dtype, cptr_void(base, L->wq_offset), q_w) != 0) goto fail;\n"
2353  " if (write_floats_file(f, cptr_f32(base, L->bq_offset), q_b) != 0) goto fail;\n"
2354  " if (write_weight_file(f, L->wk_dtype, cptr_void(base, L->wk_offset), kv_w) != 0) goto fail;\n"
2355  " if (write_floats_file(f, cptr_f32(base, L->bk_offset), kv_b) != 0) goto fail;\n"
2356  " if (write_weight_file(f, L->wv_dtype, cptr_void(base, L->wv_offset), kv_w) != 0) goto fail;\n"
2357  " if (write_floats_file(f, cptr_f32(base, L->bv_offset), kv_b) != 0) goto fail;\n"
2358  " if (write_weight_file(f, L->wo_dtype, cptr_void(base, L->wo_offset), wo_w) != 0) goto fail;\n"
2359  " if (write_floats_file(f, cptr_f32(base, L->bo_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2360  " if (write_weight_file(f, L->w1_dtype, cptr_void(base, L->w1_offset), w1_w) != 0) goto fail;\n"
2361  " if (write_floats_file(f, cptr_f32(base, L->b1_offset), (size_t)(2 * aligned_intermediate)) != 0) goto fail;\n"
2362  " if (write_weight_file(f, L->w2_dtype, cptr_void(base, L->w2_offset), w2_w) != 0) goto fail;\n"
2363  " if (write_floats_file(f, cptr_f32(base, L->b2_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2364  " }\n"
2365  "\n"
2366  " if (write_floats_file(f, cptr_f32(base, m->final_ln_weight_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2367  " if (write_floats_file(f, cptr_f32(base, m->final_ln_bias_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2368  "\n"
2369  " fclose(f);\n"
2370  " return 0;\n"
2371  "fail:\n"
2372  " fclose(f);\n"
2373  " return -1;\n"
2374  "}\n\n"
2375  "static void embed_tokens(const TransformerModel *m, const int32_t *tokens, int token_count)\n"
2376  "{\n"
2377  " if (!m || !m->memory_base || !tokens) return;\n"
2378  " const uint8_t *base = m->memory_base;\n"
2379  " float *out = ptr_f32((uint8_t *)base, m->embedded_input_offset);\n"
2380  " const float *tok_f32 = cptr_f32(base, m->token_emb_offset);\n"
2381  " const uint8_t *tok_q = (const uint8_t *)cptr_void(base, m->token_emb_offset);\n"
2382  " const float *pos = cptr_f32(base, m->pos_emb_offset);\n"
2383  " int T = m->context_window;\n"
2384  " int D = m->embed_dim;\n"
2385  " int aligned_D = (int)m->aligned_embed_dim;\n"
2386  " for (int t = 0; t < T; ++t) {\n"
2387  " float *dst = out + (size_t)t * aligned_D;\n"
2388  " if (t < token_count) {\n"
2389  " int id = tokens[t];\n"
2390  " if (id < 0 || id >= m->vocab_size) id = 0;\n"
2391  " if (m->token_emb_dtype == CK_DT_Q4_K) {\n"
2392  " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_D);\n"
2393  " const void *row = tok_q + (size_t)id * row_bytes;\n"
2394  " dequant_q4_k_row(row, dst, (size_t)aligned_D);\n"
2395  " } else if (m->token_emb_dtype == CK_DT_Q6_K) {\n"
2396  " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_D);\n"
2397  " const void *row = tok_q + (size_t)id * row_bytes;\n"
2398  " dequant_q6_k_row(row, dst, (size_t)aligned_D);\n"
2399  " } else {\n"
2400  " const float *src = tok_f32 + (size_t)id * aligned_D;\n"
2401  " memcpy(dst, src, (size_t)D * sizeof(float));\n"
2402  " }\n"
2403  " if (aligned_D > D) {\n"
2404  " memset(dst + D, 0, (size_t)(aligned_D - D) * sizeof(float));\n"
2405  " }\n"
2406  " if (m->rope_theta <= 0.0f) {\n"
2407  " const float *p = pos + (size_t)t * aligned_D;\n"
2408  " for (int d = 0; d < D; ++d) {\n"
2409  " dst[d] += p[d];\n"
2410  " }\n"
2411  " }\n"
2412  " } else {\n"
2413  " memset(dst, 0, (size_t)aligned_D * sizeof(float));\n"
2414  " }\n"
2415  " }\n"
2416  "}\n\n"
2417  "static void embed_token_at(const TransformerModel *m, int32_t token, int t)\n"
2418  "{\n"
2419  " if (!m || !m->memory_base) return;\n"
2420  " if (t < 0 || t >= m->context_window) return;\n"
2421  " const uint8_t *base = m->memory_base;\n"
2422  " float *out = ptr_f32((uint8_t *)base, m->embedded_input_offset);\n"
2423  " const float *tok_f32 = cptr_f32(base, m->token_emb_offset);\n"
2424  " const uint8_t *tok_q = (const uint8_t *)cptr_void(base, m->token_emb_offset);\n"
2425  " const float *pos = cptr_f32(base, m->pos_emb_offset);\n"
2426  " int D = m->embed_dim;\n"
2427  " int aligned_D = (int)m->aligned_embed_dim;\n"
2428  " int id = (int)token;\n"
2429  " if (id < 0 || id >= m->vocab_size) id = 0;\n"
2430  " float *dst = out + (size_t)t * aligned_D;\n"
2431  " if (m->token_emb_dtype == CK_DT_Q4_K) {\n"
2432  " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_D);\n"
2433  " const void *row = tok_q + (size_t)id * row_bytes;\n"
2434  " dequant_q4_k_row(row, dst, (size_t)aligned_D);\n"
2435  " } else if (m->token_emb_dtype == CK_DT_Q6_K) {\n"
2436  " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_D);\n"
2437  " const void *row = tok_q + (size_t)id * row_bytes;\n"
2438  " dequant_q6_k_row(row, dst, (size_t)aligned_D);\n"
2439  " } else {\n"
2440  " const float *src = tok_f32 + (size_t)id * aligned_D;\n"
2441  " memcpy(dst, src, (size_t)D * sizeof(float));\n"
2442  " }\n"
2443  " if (aligned_D > D) {\n"
2444  " memset(dst + D, 0, (size_t)(aligned_D - D) * sizeof(float));\n"
2445  " }\n"
2446  " if (m->rope_theta <= 0.0f) {\n"
2447  " const float *p = pos + (size_t)t * aligned_D;\n"
2448  " for (int d = 0; d < D; ++d) {\n"
2449  " dst[d] += p[d];\n"
2450  " }\n"
2451  " }\n"
2452  "}\n\n"
2453  "static int write_floats(const char *path, const float *src, size_t count)\n"
2454  "{\n"
2455  " if (!path || !src) return -1;\n"
2456  " FILE *f = fopen(path, \"wb\");\n"
2457  " if (!f) {\n"
2458  " perror(\"fopen\");\n"
2459  " return -1;\n"
2460  " }\n"
2461  " size_t wrote = fwrite(src, sizeof(float), count, f);\n"
2462  " fclose(f);\n"
2463  " return wrote == count ? 0 : -1;\n"
2464  "}\n\n"
2465  "static int write_float_scalar(const char *path, float v)\n"
2466  "{\n"
2467  " if (!path) return -1;\n"
2468  " FILE *f = fopen(path, \"wb\");\n"
2469  " if (!f) {\n"
2470  " perror(\"fopen\");\n"
2471  " return -1;\n"
2472  " }\n"
2473  " size_t wrote = fwrite(&v, sizeof(float), 1, f);\n"
2474  " fclose(f);\n"
2475  " return wrote == 1 ? 0 : -1;\n"
2476  "}\n\n"
2477  "static void lm_head_forward(const float *hidden,\n"
2478  " const float *weights,\n"
2479  " float *logits,\n"
2480  " int T, int V, int D, int aligned_D)\n"
2481  "{\n"
2482  " for (int t = 0; t < T; ++t) {\n"
2483  " const float *h = hidden + (size_t)t * aligned_D;\n"
2484  " float *out = logits + (size_t)t * V;\n"
2485  " for (int v = 0; v < V; ++v) {\n"
2486  " const float *w = weights + (size_t)v * aligned_D;\n"
2487  " float sum = 0.0f;\n"
2488  " for (int d = 0; d < D; ++d) {\n"
2489  " sum += h[d] * w[d];\n"
2490  " }\n"
2491  " out[v] = sum;\n"
2492  " }\n"
2493  " }\n"
2494  "}\n\n"
2495  "static void softmax_cross_entropy(const float *logits,\n"
2496  " const int32_t *targets,\n"
2497  " int T, int V,\n"
2498  " float *d_logits,\n"
2499  " float *loss_out)\n"
2500  "{\n"
2501  " double total = 0.0;\n"
2502  " for (int t = 0; t < T; ++t) {\n"
2503  " const float *row = logits + (size_t)t * V;\n"
2504  " float *drow = d_logits + (size_t)t * V;\n"
2505  " int target = targets[t];\n"
2506  " float max_logit = row[0];\n"
2507  " for (int v = 1; v < V; ++v) {\n"
2508  " if (row[v] > max_logit) max_logit = row[v];\n"
2509  " }\n"
2510  " double sum_exp = 0.0;\n"
2511  " for (int v = 0; v < V; ++v) {\n"
2512  " drow[v] = expf(row[v] - max_logit);\n"
2513  " sum_exp += drow[v];\n"
2514  " }\n"
2515  " float inv_sum = 1.0f / (float)sum_exp;\n"
2516  " for (int v = 0; v < V; ++v) {\n"
2517  " drow[v] *= inv_sum;\n"
2518  " }\n"
2519  " double logsum = (double)max_logit + log(sum_exp);\n"
2520  " total += logsum - (double)row[target];\n"
2521  " drow[target] -= 1.0f;\n"
2522  " float scale = 1.0f / (float)T;\n"
2523  " for (int v = 0; v < V; ++v) {\n"
2524  " drow[v] *= scale;\n"
2525  " }\n"
2526  " }\n"
2527  " if (loss_out) {\n"
2528  " *loss_out = (float)(total / (double)T);\n"
2529  " }\n"
2530  "}\n\n"
2531  "static void lm_head_backward(const float *hidden,\n"
2532  " const float *weights,\n"
2533  " const float *d_logits,\n"
2534  " float *d_hidden,\n"
2535  " float *d_weights,\n"
2536  " int T, int V, int D, int aligned_D)\n"
2537  "{\n"
2538  " size_t dh_count = (size_t)T * aligned_D;\n"
2539  " size_t dw_count = (size_t)V * aligned_D;\n"
2540  " for (size_t i = 0; i < dh_count; ++i) d_hidden[i] = 0.0f;\n"
2541  " for (size_t i = 0; i < dw_count; ++i) d_weights[i] = 0.0f;\n"
2542  " for (int t = 0; t < T; ++t) {\n"
2543  " const float *dlog = d_logits + (size_t)t * V;\n"
2544  " for (int d = 0; d < D; ++d) {\n"
2545  " double sum = 0.0;\n"
2546  " for (int v = 0; v < V; ++v) {\n"
2547  " sum += (double)dlog[v] * (double)weights[(size_t)v * aligned_D + d];\n"
2548  " }\n"
2549  " d_hidden[(size_t)t * aligned_D + d] = (float)sum;\n"
2550  " }\n"
2551  " }\n"
2552  " for (int v = 0; v < V; ++v) {\n"
2553  " float *dw = d_weights + (size_t)v * aligned_D;\n"
2554  " for (int d = 0; d < D; ++d) {\n"
2555  " double sum = 0.0;\n"
2556  " for (int t = 0; t < T; ++t) {\n"
2557  " sum += (double)d_logits[(size_t)t * V + v] * (double)hidden[(size_t)t * aligned_D + d];\n"
2558  " }\n"
2559  " dw[d] = (float)sum;\n"
2560  " }\n"
2561  " }\n"
2562  "}\n\n");
2563 
2564  fprintf(out,
2565  "static void dump_layer_offsets(const TransformerModel *m, int layer)\n"
2566  "{\n"
2567  " const TrulyOptimalLayer *L = &m->layers[layer];\n"
2568  " printf(\"Layer %%d offsets (bytes):\\n\", layer);\n"
2569  " printf(\" ln1_gamma=%%zu ln2_gamma=%%zu wq=%%zu wk=%%zu wv=%%zu wo=%%zu w1=%%zu w2=%%zu\\n\",\n"
2570  " L->ln1_gamma_offset, L->ln2_gamma_offset, L->wq_offset, L->wk_offset,\n"
2571  " L->wv_offset, L->wo_offset, L->w1_offset, L->w2_offset);\n"
2572  " printf(\" ln1_out=%%zu q=%%zu k=%%zu v=%%zu scores=%%zu attn_out=%%zu\\n\",\n"
2573  " L->ln1_out_offset, L->q_offset, L->k_offset, L->v_offset,\n"
2574  " L->scores_offset, L->attn_out_offset);\n"
2575  " printf(\" proj_tmp=%%zu residual1=%%zu ln2_out=%%zu fc1_out=%%zu swiglu_out=%%zu mlp_out=%%zu output=%%zu\\n\",\n"
2576  " L->proj_tmp_offset, L->residual1_offset, L->ln2_out_offset,\n"
2577  " L->fc1_out_offset, L->swiglu_out_offset, L->mlp_out_offset, L->output_offset);\n"
2578  "}\n\n"
2579  "static void dump_layout(const TransformerModel *m, int dump_all)\n"
2580  "{\n"
2581  " size_t bytes = m->total_bytes;\n"
2582  " printf(\"Model config:\\n\");\n"
2583  " printf(\" layers=%%d embed=%%d intermediate=%%d heads=%%d kv_heads=%%d\\n\",\n"
2584  " m->num_layers, m->embed_dim, m->intermediate_size, m->num_attention_heads, m->num_kv_heads);\n"
2585  " printf(\" head_dim=%%d vocab=%%d ctx=%%d cores=%%d\\n\",\n"
2586  " m->head_dim, m->vocab_size, m->context_window, m->num_cores);\n"
2587  " printf(\" eps=%%.6g rope_theta=%%.6g\\n\", m->rms_norm_eps, m->rope_theta);\n"
2588  " printf(\"Aligned dims (elements): embed=%%zu head=%%zu ctx=%%zu\\n\",\n"
2589  " m->aligned_embed_dim, m->aligned_head_dim, m->aligned_attn_context_window);\n"
2590  " printf(\"Memory: total_bytes=%%zu\\n\", bytes);\n"
2591  " printf(\"Global offsets (bytes): token=%%zu pos=%%zu embedded=%%zu layers_start=%%zu\\n\",\n"
2592  " m->token_emb_offset, m->pos_emb_offset, m->embedded_input_offset, m->layers_start_offset);\n"
2593  " printf(\"Final offsets (bytes): final_ln_w=%%zu final_ln_b=%%zu final_ln_mean=%%zu final_ln_rstd=%%zu\\n\",\n"
2594  " m->final_ln_weight_offset, m->final_ln_bias_offset,\n"
2595  " m->final_ln_mean_offset, m->final_ln_rstd_offset);\n"
2596  " printf(\"LM/logits offsets (bytes): lm_head=%%zu logits=%%zu\\n\",\n"
2597  " m->lm_head_weight_offset, m->logits_offset);\n"
2598  " if (m->num_layers > 0) {\n"
2599  " dump_layer_offsets(m, 0);\n"
2600  " if (dump_all) {\n"
2601  " for (int i = 1; i < m->num_layers; ++i) {\n"
2602  " dump_layer_offsets(m, i);\n"
2603  " }\n"
2604  " }\n"
2605  " }\n"
2606  "}\n\n");
2607 
2608  /* Emit either main() for standalone or API for library mode */
2609  if (mode == CK_EMIT_STANDALONE) {
2610  fprintf(out,
2611  "int main(int argc, char **argv)\n"
2612  "{\n"
2613  " int dump = 0;\n"
2614  " int dump_all = 0;\n"
2615  " int no_forward = 0;\n"
2616  " int run_litmus = 0;\n"
2617  " int run_backward = 0;\n"
2618  " const char *litmus_hidden = NULL;\n"
2619  " const char *litmus_weights = NULL;\n"
2620  " const char *litmus_targets = NULL;\n"
2621  " const char *model_weights = NULL;\n"
2622  " const char *tokens_path = NULL;\n"
2623  " const char *out_logits = NULL;\n"
2624  " const char *out_dlogits = NULL;\n"
2625  " const char *out_dhidden = NULL;\n"
2626  " const char *out_dweights = NULL;\n"
2627  " const char *out_loss = NULL;\n"
2628  " const char *out_weights = NULL;\n"
2629  " int steps = 1;\n"
2630  " int log_steps = 0;\n"
2631  " int strict = 0;\n"
2632  " int32_t *tokens = NULL;\n"
2633  " int32_t *targets = NULL;\n"
2634  " TransformerModel m = {0};\n"
2635  " memcpy(m.magic, \"BUMPWGT3\", 8);\n"
2636  " m.version = 3;\n"
2637  " m.model_type = 0;\n"
2638  " m.num_layers = %d;\n"
2639  " m.embed_dim = %d;\n"
2640  " m.intermediate_size = %d;\n"
2641  " m.num_attention_heads = %d;\n"
2642  " m.num_kv_heads = %d;\n"
2643  " m.vocab_size = %d;\n"
2644  " m.context_window = %d;\n"
2645  " m.rms_norm_eps = %.9g;\n"
2646  " m.rope_theta = %.9g;\n"
2647  " m.num_cores = 1;\n"
2648  " m.task_type = TASK_LM;\n"
2649  " m.optimizer = OPTIMIZER_SGD;\n"
2650  " m.learning_rate = 0.0f;\n"
2651  " for (int i = 1; i < argc; ++i) {\n"
2652  " if (strcmp(argv[i], \"--dump\") == 0) {\n"
2653  " dump = 1;\n"
2654  " continue;\n"
2655  " }\n"
2656  " if (strcmp(argv[i], \"--dump-all\") == 0) {\n"
2657  " dump = 1;\n"
2658  " dump_all = 1;\n"
2659  " continue;\n"
2660  " }\n"
2661  " if (strcmp(argv[i], \"--no-forward\") == 0) {\n"
2662  " no_forward = 1;\n"
2663  " continue;\n"
2664  " }\n"
2665  " if (strcmp(argv[i], \"--strict\") == 0) {\n"
2666  " strict = 1;\n"
2667  " continue;\n"
2668  " }\n"
2669  " if (strcmp(argv[i], \"--litmus\") == 0) {\n"
2670  " run_litmus = 1;\n"
2671  " continue;\n"
2672  " }\n"
2673  " if (strcmp(argv[i], \"--backward\") == 0) {\n"
2674  " run_backward = 1;\n"
2675  " continue;\n"
2676  " }\n"
2677  " if (strcmp(argv[i], \"--lr\") == 0 && i + 1 < argc) {\n"
2678  " parse_float_arg(argv[++i], &m.learning_rate);\n"
2679  " continue;\n"
2680  " }\n"
2681  " if (strcmp(argv[i], \"--help\") == 0) {\n"
2682  " print_usage(argv[0]);\n"
2683  " return 0;\n"
2684  " }\n"
2685  " if (strcmp(argv[i], \"--hidden\") == 0 && i + 1 < argc) {\n"
2686  " litmus_hidden = argv[++i];\n"
2687  " continue;\n"
2688  " }\n"
2689  " if (strcmp(argv[i], \"--weights\") == 0 && i + 1 < argc) {\n"
2690  " litmus_weights = argv[++i];\n"
2691  " continue;\n"
2692  " }\n"
2693  " if (strcmp(argv[i], \"--targets\") == 0 && i + 1 < argc) {\n"
2694  " litmus_targets = argv[++i];\n"
2695  " continue;\n"
2696  " }\n"
2697  " if (strcmp(argv[i], \"--model-weights\") == 0 && i + 1 < argc) {\n"
2698  " model_weights = argv[++i];\n"
2699  " continue;\n"
2700  " }\n"
2701  " if (strcmp(argv[i], \"--tokens\") == 0 && i + 1 < argc) {\n"
2702  " tokens_path = argv[++i];\n"
2703  " continue;\n"
2704  " }\n"
2705  " if (strcmp(argv[i], \"--out-logits\") == 0 && i + 1 < argc) {\n"
2706  " out_logits = argv[++i];\n"
2707  " continue;\n"
2708  " }\n"
2709  " if (strcmp(argv[i], \"--out-dlogits\") == 0 && i + 1 < argc) {\n"
2710  " out_dlogits = argv[++i];\n"
2711  " continue;\n"
2712  " }\n"
2713  " if (strcmp(argv[i], \"--out-dhidden\") == 0 && i + 1 < argc) {\n"
2714  " out_dhidden = argv[++i];\n"
2715  " continue;\n"
2716  " }\n"
2717  " if (strcmp(argv[i], \"--out-dweights\") == 0 && i + 1 < argc) {\n"
2718  " out_dweights = argv[++i];\n"
2719  " continue;\n"
2720  " }\n"
2721  " if (strcmp(argv[i], \"--out-loss\") == 0 && i + 1 < argc) {\n"
2722  " out_loss = argv[++i];\n"
2723  " continue;\n"
2724  " }\n"
2725  " if (strcmp(argv[i], \"--out-weights\") == 0 && i + 1 < argc) {\n"
2726  " out_weights = argv[++i];\n"
2727  " continue;\n"
2728  " }\n"
2729  " if (strcmp(argv[i], \"--steps\") == 0 && i + 1 < argc) {\n"
2730  " parse_int_arg(argv[++i], &steps);\n"
2731  " continue;\n"
2732  " }\n"
2733  " if (strcmp(argv[i], \"--log-steps\") == 0) {\n"
2734  " log_steps = 1;\n"
2735  " continue;\n"
2736  " }\n"
2737  " if (strcmp(argv[i], \"--layers\") == 0 && i + 1 < argc) {\n"
2738  " parse_int_arg(argv[++i], &m.num_layers);\n"
2739  " continue;\n"
2740  " }\n"
2741  " if (strcmp(argv[i], \"--embed\") == 0 && i + 1 < argc) {\n"
2742  " parse_int_arg(argv[++i], &m.embed_dim);\n"
2743  " continue;\n"
2744  " }\n"
2745  " if (strcmp(argv[i], \"--intermediate\") == 0 && i + 1 < argc) {\n"
2746  " parse_int_arg(argv[++i], &m.intermediate_size);\n"
2747  " continue;\n"
2748  " }\n"
2749  " if (strcmp(argv[i], \"--heads\") == 0 && i + 1 < argc) {\n"
2750  " parse_int_arg(argv[++i], &m.num_attention_heads);\n"
2751  " continue;\n"
2752  " }\n"
2753  " if (strcmp(argv[i], \"--kv-heads\") == 0 && i + 1 < argc) {\n"
2754  " parse_int_arg(argv[++i], &m.num_kv_heads);\n"
2755  " continue;\n"
2756  " }\n"
2757  " if (strcmp(argv[i], \"--vocab\") == 0 && i + 1 < argc) {\n"
2758  " parse_int_arg(argv[++i], &m.vocab_size);\n"
2759  " continue;\n"
2760  " }\n"
2761  " if (strcmp(argv[i], \"--ctx\") == 0 && i + 1 < argc) {\n"
2762  " parse_int_arg(argv[++i], &m.context_window);\n"
2763  " continue;\n"
2764  " }\n"
2765  " if (strcmp(argv[i], \"--cores\") == 0 && i + 1 < argc) {\n"
2766  " parse_int_arg(argv[++i], &m.num_cores);\n"
2767  " continue;\n"
2768  " }\n"
2769  " fprintf(stderr, \"Unknown or invalid arg: %%s\\n\", argv[i]);\n"
2770  " print_usage(argv[0]);\n"
2771  " return 1;\n"
2772  " }\n"
2773  " if (strict) {\n"
2774  " ck_set_strict_parity(1);\n"
2775  " }\n"
2776  " if (run_backward && m.learning_rate == 0.0f) {\n"
2777  " m.learning_rate = 1e-3f;\n"
2778  " }\n"
2779  " m.training_enabled = run_backward;\n"
2780  " m.weight_dtype = CK_DT_FP32;\n"
2781  " {\n"
2782  " const char *wd = getenv(\"CK_WEIGHT_DTYPE\");\n"
2783  " if (wd) {\n"
2784  " if (strcmp(wd, \"q4_k\") == 0 || strcmp(wd, \"q4_k_m\") == 0 ||\n"
2785  " strcmp(wd, \"Q4_K\") == 0 || strcmp(wd, \"Q4_K_M\") == 0) {\n"
2786  " m.weight_dtype = CK_DT_Q4_K;\n"
2787  " } else if (strcmp(wd, \"q6_k\") == 0 || strcmp(wd, \"q6_k_l\") == 0 ||\n"
2788  " strcmp(wd, \"Q6_K\") == 0 || strcmp(wd, \"Q6_K_L\") == 0) {\n"
2789  " m.weight_dtype = CK_DT_Q6_K;\n"
2790  " }\n"
2791  " }\n"
2792  " }\n"
2793  " init_weight_dtypes_uniform(&m, m.weight_dtype);\n"
2794  " refresh_weight_flags(&m);\n"
2795  " if (model_weights) {\n"
2796  " int dtype_rc = load_weight_dtypes(model_weights, &m);\n"
2797  " if (dtype_rc < 0) {\n"
2798  " fprintf(stderr, \"failed to read weight dtype table\\n\");\n"
2799  " return 1;\n"
2800  " }\n"
2801  " }\n"
2802  " if (m.training_enabled && m.weights_quantized) {\n"
2803  " fprintf(stderr, \"Quantized weights are inference-only; disable training\\n\");\n"
2804  " return 1;\n"
2805  " }\n"
2806  " if (layout_model(&m) != 0) {\n"
2807  " fprintf(stderr, \"layout_model failed\\n\");\n"
2808  " return 1;\n"
2809  " }\n"
2810  " if (model_weights) {\n"
2811  " if (load_model_weights(model_weights, &m) != 0) {\n"
2812  " fprintf(stderr, \"failed to load model weights\\n\");\n"
2813  " return 1;\n"
2814  " }\n"
2815  " }\n"
2816  " if (tokens_path) {\n"
2817  " int T = m.context_window;\n"
2818  " tokens = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2819  " if (!tokens) {\n"
2820  " fprintf(stderr, \"failed to alloc tokens\\n\");\n"
2821  " return 1;\n"
2822  " }\n"
2823  " if (read_ints(tokens_path, tokens, (size_t)T) != 0) {\n"
2824  " fprintf(stderr, \"failed to read tokens\\n\");\n"
2825  " free(tokens);\n"
2826  " tokens = NULL;\n"
2827  " return 1;\n"
2828  " }\n"
2829  " if (!run_backward) {\n"
2830  " embed_tokens(&m, tokens, T);\n"
2831  " free(tokens);\n"
2832  " tokens = NULL;\n"
2833  " }\n"
2834  " }\n"
2835  " if (run_backward) {\n"
2836  " if (!litmus_targets) {\n"
2837  " fprintf(stderr, \"backward requires --targets\\n\");\n"
2838  " return 1;\n"
2839  " }\n"
2840  " int T = m.context_window;\n"
2841  " targets = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2842  " if (!targets) {\n"
2843  " fprintf(stderr, \"failed to alloc targets\\n\");\n"
2844  " return 1;\n"
2845  " }\n"
2846  " if (read_ints(litmus_targets, targets, (size_t)T) != 0) {\n"
2847  " fprintf(stderr, \"failed to read targets\\n\");\n"
2848  " free(targets);\n"
2849  " targets = NULL;\n"
2850  " return 1;\n"
2851  " }\n"
2852  " }\n"
2853  " if (dump) {\n"
2854  " dump_layout(&m, dump_all);\n"
2855  " }\n"
2856  " if (run_litmus) {\n"
2857  " if (!litmus_hidden || !litmus_weights || !litmus_targets) {\n"
2858  " fprintf(stderr, \"litmus requires --hidden, --weights, and --targets\\n\");\n"
2859  " return 1;\n"
2860  " }\n"
2861  " int T = m.context_window;\n"
2862  " int V = m.vocab_size;\n"
2863  " int D = m.embed_dim;\n"
2864  " int aligned_D = (int)m.aligned_embed_dim;\n"
2865  " float *hidden = ptr_f32(m.memory_base, m.final_output_offset);\n"
2866  " float *weights = ptr_f32(m.memory_base, m.lm_head_weight_offset);\n"
2867  " float *logits = ptr_f32(m.memory_base, m.logits_offset);\n"
2868  " if (read_floats(litmus_hidden, hidden, (size_t)T * aligned_D) != 0) {\n"
2869  " fprintf(stderr, \"failed to read hidden\\n\");\n"
2870  " return 1;\n"
2871  " }\n"
2872  " if (read_floats(litmus_weights, weights, (size_t)V * aligned_D) != 0) {\n"
2873  " fprintf(stderr, \"failed to read weights\\n\");\n"
2874  " return 1;\n"
2875  " }\n"
2876  " int32_t *targets = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2877  " if (!targets) {\n"
2878  " fprintf(stderr, \"failed to alloc targets\\n\");\n"
2879  " return 1;\n"
2880  " }\n"
2881  " if (read_ints(litmus_targets, targets, (size_t)T) != 0) {\n"
2882  " fprintf(stderr, \"failed to read targets\\n\");\n"
2883  " free(targets);\n"
2884  " return 1;\n"
2885  " }\n"
2886  " float *d_logits = (float *)calloc((size_t)T * V, sizeof(float));\n"
2887  " float *d_hidden = (float *)calloc((size_t)T * aligned_D, sizeof(float));\n"
2888  " float *d_weights = (float *)calloc((size_t)V * aligned_D, sizeof(float));\n"
2889  " if (!d_logits || !d_hidden || !d_weights) {\n"
2890  " fprintf(stderr, \"failed to alloc grads\\n\");\n"
2891  " free(targets);\n"
2892  " free(d_logits);\n"
2893  " free(d_hidden);\n"
2894  " free(d_weights);\n"
2895  " return 1;\n"
2896  " }\n"
2897  " lm_head_forward(hidden, weights, logits, T, V, D, aligned_D);\n"
2898  " float loss = 0.0f;\n"
2899  " softmax_cross_entropy(logits, targets, T, V, d_logits, &loss);\n"
2900  " lm_head_backward(hidden, weights, d_logits, d_hidden, d_weights, T, V, D, aligned_D);\n"
2901  " if (out_logits) write_floats(out_logits, logits, (size_t)T * V);\n"
2902  " if (out_dlogits) write_floats(out_dlogits, d_logits, (size_t)T * V);\n"
2903  " if (out_dhidden) write_floats(out_dhidden, d_hidden, (size_t)T * aligned_D);\n"
2904  " if (out_dweights) write_floats(out_dweights, d_weights, (size_t)V * aligned_D);\n"
2905  " if (out_loss) write_float_scalar(out_loss, loss);\n"
2906  " if (!out_loss) printf(\"loss=%%.6f\\n\", loss);\n"
2907  " free(targets);\n"
2908  " free(d_logits);\n"
2909  " free(d_hidden);\n"
2910  " free(d_weights);\n"
2911  " ck_huge_free(m.memory_base, m.total_bytes);\n"
2912  " free(m.layers);\n"
2913  " return 0;\n"
2914  " }\n"
2915  " // TODO: load weights into m.memory_base using the offsets above.\n"
2916  " // TODO: write token/pos embeddings into embedded_input_offset.\n"
2917  " if (!run_backward) {\n"
2918  " if (!no_forward) {\n"
2919  " run_model_forward(&m);\n"
2920  " }\n"
2921  " } else {\n"
2922  " if (!tokens || !targets) {\n"
2923  " fprintf(stderr, \"backward requires --tokens and --targets\\n\");\n"
2924  " return 1;\n"
2925  " }\n"
2926  " if (steps < 1) steps = 1;\n"
2927  " float loss = 0.0f;\n"
2928  " for (int step = 0; step < steps; ++step) {\n"
2929  " embed_tokens(&m, tokens, m.context_window);\n"
2930  " run_model_forward(&m);\n"
2931  " if (run_model_backward(&m, tokens, targets, &loss) != 0) {\n"
2932  " fprintf(stderr, \"backward failed\\n\");\n"
2933  " return 1;\n"
2934  " }\n"
2935  " if (log_steps) {\n"
2936  " printf(\"step %%d loss=%%.6f\\n\", step, loss);\n"
2937  " }\n"
2938  " }\n"
2939  " if (out_loss) {\n"
2940  " write_float_scalar(out_loss, loss);\n"
2941  " }\n"
2942  " }\n"
2943  " if (out_logits) {\n"
2944  " write_floats(out_logits, ptr_f32(m.memory_base, m.logits_offset),\n"
2945  " (size_t)m.context_window * (size_t)m.vocab_size);\n"
2946  " }\n"
2947  " if (out_weights) {\n"
2948  " if (save_model_weights(out_weights, &m) != 0) {\n"
2949  " fprintf(stderr, \"failed to save model weights\\n\");\n"
2950  " return 1;\n"
2951  " }\n"
2952  " }\n"
2953  " ck_huge_free(m.memory_base, m.total_bytes);\n"
2954  " free(m.layers);\n"
2955  " free(tokens);\n"
2956  " free(targets);\n"
2957  " return 0;\n"
2958  "}\n",
2959  forward->config.num_layers,
2960  forward->config.hidden_size,
2961  forward->config.intermediate_size,
2962  forward->config.num_heads,
2963  forward->config.num_kv_heads,
2964  forward->config.vocab_size,
2965  forward->config.context_window,
2966  forward->config.rms_norm_eps,
2967  forward->config.rope_theta);
2968  } else {
2969  /* Library mode - emit API functions instead of main() */
2970  emit_library_api(out, forward);
2971  }
2972 
2973  fclose(out);
2974  if (emit_kernel_manifest(forward, path) != 0) {
2975  return -1;
2976  }
2977  return 0;
2978 }
static int emit_runtime_preamble(FILE *out)
static int emit_kernel_manifest(const CKIRGraph *forward, const char *runtime_path)
static const char * ck_first_layer_buffer_name(void)
static void emit_library_api(FILE *out, const CKIRGraph *forward)
static void emit_layer_allocations(FILE *out)
static void emit_sgd_update(FILE *out)
static void emit_model_struct(FILE *out)
static void emit_global_allocations(FILE *out)
static void emit_zero_grad(FILE *out)
static void emit_layer_offsets_struct(FILE *out)
static void emit_global_aliases_to_layer(FILE *out)
int ck_ir_validate_supported(const CKIRGraph *graph)

References CK_EMIT_STANDALONE, ck_first_layer_buffer_name(), ck_ir_validate_supported(), CKIRGraph::config, CKModelConfig::context_window, emit_global_aliases_to_layer(), emit_global_allocations(), emit_kernel_manifest(), emit_layer_allocations(), emit_layer_offsets_struct(), emit_library_api(), emit_model_struct(), emit_runtime_preamble(), emit_sgd_update(), emit_zero_grad(), CKModelConfig::hidden_size, CKModelConfig::intermediate_size, CKModelConfig::num_heads, CKModelConfig::num_kv_heads, CKModelConfig::num_layers, CKModelConfig::rms_norm_eps, CKModelConfig::rope_theta, and CKModelConfig::vocab_size.

Referenced by main().