← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_codegen_v6.5.c File Reference
#include "ckernel_codegen.h"
#include "ckernel_registry.h"
#include "ckernel_kernel_specs.h"
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

Go to the source code of this file.

Functions

static int ck_buffer_should_alloc (const CKBufferSpec *spec)
 
static int ck_buffer_uses_weight_dtype (const CKBufferSpec *spec)
 
void ck_codegen_c_skeleton (const CKIRGraph *forward, const CKIRGraph *backward, FILE *out)
 
int ck_codegen_emit_runtime (const CKIRGraph *forward, const char *path, CKEmitMode mode)
 
static const CKBufferSpec * ck_find_buffer_spec (const char *name)
 
static const CKKernelSpec * ck_find_kernel_spec (const char *name)
 
static const char * ck_first_layer_buffer_name (void)
 
static int ck_plan_step_enabled (const CKPlanStep *step, const CKIRGraph *cfg)
 
static const char * ck_weight_dtype_expr (const CKBufferSpec *spec)
 
static void emit_bump_bytes_assignment (FILE *out, const char *indent, const char *struct_prefix, const char *name, const CKDimToken *shape)
 
static void emit_bump_bytes_assignment_weight_dtype (FILE *out, const char *indent, const char *struct_prefix, const char *name, const CKDimToken *shape, const char *dtype_expr)
 
static void emit_dim_expr (FILE *out, CKDimKind dim)
 
static void emit_global_aliases_to_layer (FILE *out)
 
static void emit_global_allocations (FILE *out)
 
static void emit_global_offset_fields (FILE *out)
 
static int emit_kernel_manifest (const CKIRGraph *forward, const char *runtime_path)
 
static void emit_layer_allocations (FILE *out)
 
static void emit_layer_offsets_struct (FILE *out)
 
static void emit_library_api (FILE *out, const CKIRGraph *forward)
 
static void emit_model_struct (FILE *out)
 
static void emit_offset_field (FILE *out, const char *name)
 
static int emit_plan_sources (FILE *f, const CKPlanStep *plan, size_t plan_count, const CKIRGraph *cfg, const char **seen, size_t *seen_count, size_t seen_cap)
 
static int emit_runtime_preamble (FILE *out)
 
static void emit_sgd_update (FILE *out)
 
static void emit_shape_expr (FILE *out, const CKDimToken *shape)
 
static void emit_training_conditional_assignment (FILE *out, const char *indent, const char *struct_prefix, const char *name, const CKDimToken *shape)
 
static int emit_unique_source (FILE *f, const char *path, const char **seen, size_t *seen_count, size_t seen_cap)
 
static void emit_zero_grad (FILE *out)
 
static const char * op_name (CKOpType op)
 

Function Documentation

◆ ck_buffer_should_alloc()

static int ck_buffer_should_alloc ( const CKBufferSpec *spec)
static

◆ ck_buffer_uses_weight_dtype()

static int ck_buffer_uses_weight_dtype ( const CKBufferSpec *spec)
static

Definition at line 62 of file ckernel_codegen_v6.5.c.

63 {
64  if (!spec || spec->role != CK_ROLE_WEIGHT || !spec->name) {
65  return 0;
66  }
67 
68  /* Weight-only quantization targets: big GEMM weights and tied embeddings.
69  * Small vectors (norm scales, biases) remain fp32 for simplicity. */
70  return (strcmp(spec->name, "token_emb") == 0) ||
71  (strcmp(spec->name, "wq") == 0) ||
72  (strcmp(spec->name, "wk") == 0) ||
73  (strcmp(spec->name, "wv") == 0) ||
74  (strcmp(spec->name, "wo") == 0) ||
75  (strcmp(spec->name, "w1") == 0) ||
76  (strcmp(spec->name, "w2") == 0) ||
77  (strcmp(spec->name, "lm_head_weight") == 0);
78 }
@ CK_ROLE_WEIGHT
const char * name

References CK_ROLE_WEIGHT, CKBufferSpec::name, and CKBufferSpec::role.

Referenced by emit_global_allocations(), and emit_layer_allocations().

◆ ck_codegen_c_skeleton()

void ck_codegen_c_skeleton ( const CKIRGraph *forward,
const CKIRGraph *backward,
FILE *  out 
)

Emit a C skeleton for forward + backward execution based on the IR.

This does not yet generate full pointer arithmetic or memory planning. It is intended as a starting point that:

  • Defines a model config / runtime context
  • Shows a per-layer forward loop over IR nodes
  • Sketches a backward loop over the backward IR

Definition at line 613 of file ckernel_codegen_v6.5.c.

616 {
 /* Guard: nothing can be emitted without a forward graph and an output stream.
  * (backward is optional; see below.) */
617  if (!forward || !out) {
618  return;
619  }
620 
621  fprintf(out,
622  "/* Auto-generated skeleton from CKIRGraph.\n"
623  " * This file sketches the structure of the forward and backward\n"
624  " * execution for a decoder-only transformer. It is NOT yet a\n"
625  " * complete, runnable implementation. You can use it as a\n"
626  " * starting point to wire buffers, kernel calls, and memory layout.\n"
627  " */\n\n");
628 
629  fprintf(out, "#include \"ckernel_engine.h\"\n");
630  fprintf(out, "#include \"ckernel_model.h\"\n");
631  fprintf(out, "#include \"ckernel_alloc.h\"\n\n");
632 
633  /* Forward function */
634  fprintf(out,
635  "void run_decoder_forward(TransformerModel *model /*, inputs, etc. */)\n"
636  "{\n"
637  " for (int layer = 0; layer < model->cfg.num_layers; ++layer) {\n"
638  " /* Forward pass for layer */\n");
639 
 /* Count how many IR nodes belong to the first layer: the emitted loop body
  * is a per-layer template, so only the first layer's nodes are rendered.
  * Falls back to all nodes when the graph never changes layer id. */
640  int nodes_per_layer = 0;
641  if (forward->num_nodes > 0) {
642  int l0 = forward->nodes[0].id.layer;
643  for (int i = 0; i < forward->num_nodes; ++i) {
644  if (forward->nodes[i].id.layer != l0) {
645  break;
646  }
647  nodes_per_layer++;
648  }
649  }
650 
651  if (nodes_per_layer <= 0) {
652  nodes_per_layer = forward->num_nodes;
653  }
654 
655  fprintf(out, " /* This layer has %d IR nodes */\n", nodes_per_layer);
656 
657  for (int i = 0; i < nodes_per_layer; ++i) {
658  const CKIRNode *n = &forward->nodes[i];
 /* NOTE(review): "L%%d" emits a literal "L%d" placeholder into the generated
  * comment; the layer index is not substituted here — confirm intentional. */
659  fprintf(out, " // L%%d: %s\n", op_name(n->op));
660  fprintf(out,
661  " // outputs: [");
662  for (int o = 0; o < n->n_outputs; ++o) {
663  if (o > 0) fprintf(out, ", ");
664  fprintf(out, "L%%d:N%d:%d", n->id.node, o);
665  }
666  fprintf(out, "]\n");
667  fprintf(out, " // inputs : [");
668  for (int j = 0; j < n->n_inputs; ++j) {
669  const CKInputRef *inp = &n->inputs[j];
670  if (j > 0) fprintf(out, ", ");
 /* producer.node == 0xFFFF is the sentinel for an external/graph input. */
671  if (inp->producer.node == 0xFFFFu) {
672  fprintf(out, "IN");
673  } else {
674  fprintf(out, "L%%d:N%u:%u",
675  (unsigned)inp->producer.node,
676  (unsigned)inp->out_index);
677  }
678  }
679  fprintf(out, "]\n");
680  fprintf(out,
681  " // TODO: bind buffers/weights and call %s kernel here\n\n",
682  op_name(n->op));
683  }
684 
685  fprintf(out,
686  " } /* end for layer */\n"
687  "}\n\n");
688 
689  /* Backward skeleton */
 /* Backward emission is skipped entirely when no backward IR is supplied. */
690  if (backward && backward->nodes && backward->num_nodes > 0) {
691  fprintf(out,
692  "void run_decoder_backward(TransformerModel *model /*, grads, etc. */)\n"
693  "{\n"
694  " for (int layer = model->cfg.num_layers - 1; layer >= 0; --layer) {\n"
695  " /* Backward pass for layer */\n");
696 
 /* Same per-layer counting scheme as the forward template above. */
697  int bwd_per_layer = 0;
698  int l0 = backward->nodes[0].id.layer;
699  for (int i = 0; i < backward->num_nodes; ++i) {
700  if (backward->nodes[i].id.layer != l0) break;
701  bwd_per_layer++;
702  }
703  if (bwd_per_layer <= 0) bwd_per_layer = backward->num_nodes;
704 
705  fprintf(out, " /* This layer has %d backward IR nodes */\n", bwd_per_layer);
706 
707  for (int i = 0; i < bwd_per_layer; ++i) {
708  const CKIRNode *n = &backward->nodes[i];
709  fprintf(out, " // L%%d: %s\n", op_name(n->op));
710  fprintf(out,
711  " // TODO: wire gradient tensors and call %s kernel here\n\n",
712  op_name(n->op));
713  }
714 
715  fprintf(out,
716  " } /* end for layer */\n"
717  "}\n\n");
718  }
719 
 /* Emit a main() with the model configuration baked in as compile-time
  * constants taken from forward->config. */
720  fprintf(out,
721  "int main(int argc, char **argv)\n"
722  "{\n"
723  " (void)argc; (void)argv;\n"
724  " TransformerModel model = {0};\n"
725  " model.cfg.num_layers = %d;\n"
726  " model.cfg.hidden_size = %d;\n"
727  " model.cfg.intermediate_size = %d;\n"
728  " model.cfg.num_heads = %d;\n"
729  " model.cfg.num_kv_heads = %d;\n"
730  " model.cfg.vocab_size = %d;\n"
731  " model.cfg.context_window = %d;\n"
732  " model.cfg.rms_norm_eps = %.9g;\n"
733  " model.cfg.rope_theta = %.9g;\n"
734  " layout_transformer_from_ir(&model, NULL); /* TODO: pass IR if needed */\n"
735  " size_t bytes = model.total_bytes;\n"
736  " model.memory_base = (uint8_t *)ck_huge_alloc(bytes);\n"
737  " if (!model.memory_base) {\n"
738  " fprintf(stderr, \"Failed to allocate %%zu bytes for model\\n\", bytes);\n"
739  " return 1;\n"
740  " }\n"
741  " // TODO: load weights into model.memory_base based on offsets\n"
742  " run_decoder_forward(&model);\n"
743  " // TODO: run_decoder_backward(&model) when training\n"
744  " ck_huge_free(model.memory_base, bytes);\n"
745  " return 0;\n"
746  "}\n",
747  forward->config.num_layers,
748  forward->config.hidden_size,
749  forward->config.intermediate_size,
750  forward->config.num_heads,
751  forward->config.num_kv_heads,
752  forward->config.vocab_size,
753  forward->config.context_window,
754  forward->config.rms_norm_eps,
755  forward->config.rope_theta);
756 }
static const char * op_name(CKOpType op)
CKIRNode * nodes
Definition: ckernel_ir.h:75
int num_nodes
Definition: ckernel_ir.h:74
CKModelConfig config
Definition: ckernel_ir.h:73
CKOpType op
Definition: ckernel_ir.h:65
CKInputRef inputs[4]
Definition: ckernel_ir.h:66
CKKernelId id
Definition: ckernel_ir.h:64
uint8_t n_inputs
Definition: ckernel_ir.h:67
uint8_t n_outputs
Definition: ckernel_ir.h:68
CKKernelId producer
Definition: ckernel_ir.h:59
uint8_t out_index
Definition: ckernel_ir.h:60
uint16_t node
Definition: ckernel_ir.h:55
uint16_t layer
Definition: ckernel_ir.h:54
int context_window
Definition: ckernel_ir.h:30
int intermediate_size
Definition: ck_model_api.h:37
float rms_norm_eps
Definition: ckernel_ir.h:31
float rope_theta
Definition: ckernel_ir.h:32

References CKIRGraph::config, CKModelConfig::context_window, CKModelConfig::hidden_size, CKIRNode::id, CKIRNode::inputs, CKModelConfig::intermediate_size, CKKernelId::layer, CKIRNode::n_inputs, CKIRNode::n_outputs, CKKernelId::node, CKIRGraph::nodes, CKModelConfig::num_heads, CKModelConfig::num_kv_heads, CKModelConfig::num_layers, CKIRGraph::num_nodes, CKIRNode::op, op_name(), CKInputRef::out_index, CKInputRef::producer, CKModelConfig::rms_norm_eps, CKModelConfig::rope_theta, and CKModelConfig::vocab_size.

◆ ck_codegen_emit_runtime()

int ck_codegen_emit_runtime ( const CKIRGraph *forward,
const char *  path,
CKEmitMode  mode 
)

Emit a C runtime file that stitches kernels for the given forward IR.

Parameters
forward — The forward IR graph
path — Output file path
mode — CK_EMIT_STANDALONE for executable with main(), CK_EMIT_LIBRARY for shared object with API functions

Returns 0 on success, non-zero on failure.

Definition at line 1439 of file ckernel_codegen_v6.5.c.

1440 {
1441  if (!forward || !path) {
1442  return -1;
1443  }
1444  if (ck_ir_validate_supported(forward) != 0) {
1445  return -1;
1446  }
1447 
1448  FILE *out = fopen(path, "wb");
1449  if (!out) {
1450  fprintf(stderr, "ck_codegen_emit_runtime: failed to open %s: %s\n",
1451  path, strerror(errno));
1452  return -1;
1453  }
1454 
1455  if (emit_runtime_preamble(out) != 0) {
1456  fclose(out);
1457  return -1;
1458  }
1459 
1460  fprintf(out,
1461  "typedef enum {\n"
1462  " TASK_LM = 0,\n"
1463  " TASK_SEQ_CLS = 1\n"
1464  "} TaskType;\n\n"
1465  "typedef enum {\n"
1466  " OPTIMIZER_SGD = 0,\n"
1467  " OPTIMIZER_ADAM = 1\n"
1468  "} OptimizerType;\n\n"
1469  "typedef struct {\n"
1470  " size_t total_gradient_floats;\n"
1471  "} GradientStorage;\n\n");
1472 
1474  emit_model_struct(out);
1475 
1476  fprintf(out,
1477  "static int ensure_layers_allocated(TransformerModel *m)\n"
1478  "{\n"
1479  " if (!m) return -1;\n"
1480  " if (!m->layers && m->num_layers > 0) {\n"
1481  " m->layers = (TrulyOptimalLayer *)calloc((size_t)m->num_layers, sizeof(TrulyOptimalLayer));\n"
1482  " if (!m->layers) return -1;\n"
1483  " }\n"
1484  " return 0;\n"
1485  "}\n\n"
1486  "static void init_weight_dtypes_uniform(TransformerModel *m, CKDataType dt)\n"
1487  "{\n"
1488  " if (!m) return;\n"
1489  " m->token_emb_dtype = dt;\n"
1490  " m->lm_head_weight_dtype = dt;\n"
1491  " m->pos_emb_dtype = CK_DT_FP32;\n"
1492  " if (ensure_layers_allocated(m) != 0) return;\n"
1493  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1494  " TrulyOptimalLayer *L = &m->layers[layer];\n"
1495  " L->wq_dtype = dt;\n"
1496  " L->wk_dtype = dt;\n"
1497  " L->wv_dtype = dt;\n"
1498  " L->wo_dtype = dt;\n"
1499  " L->w1_dtype = dt;\n"
1500  " L->w2_dtype = dt;\n"
1501  " }\n"
1502  "}\n\n"
1503  "static void refresh_weight_flags(TransformerModel *m)\n"
1504  "{\n"
1505  " if (!m) return;\n"
1506  " CKDataType base = m->token_emb_dtype;\n"
1507  " int mixed = 0;\n"
1508  " int quant = ck_dtype_is_quantized(base);\n"
1509  " if (m->lm_head_weight_dtype != base) mixed = 1;\n"
1510  " if (ck_dtype_is_quantized(m->lm_head_weight_dtype)) quant = 1;\n"
1511  " if (m->layers) {\n"
1512  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1513  " TrulyOptimalLayer *L = &m->layers[layer];\n"
1514  " if (L->wq_dtype != base || L->wk_dtype != base || L->wv_dtype != base ||\n"
1515  " L->wo_dtype != base || L->w1_dtype != base || L->w2_dtype != base) {\n"
1516  " mixed = 1;\n"
1517  " }\n"
1518  " if (ck_dtype_is_quantized(L->wq_dtype) || ck_dtype_is_quantized(L->wk_dtype) ||\n"
1519  " ck_dtype_is_quantized(L->wv_dtype) || ck_dtype_is_quantized(L->wo_dtype) ||\n"
1520  " ck_dtype_is_quantized(L->w1_dtype) || ck_dtype_is_quantized(L->w2_dtype)) {\n"
1521  " quant = 1;\n"
1522  " }\n"
1523  " }\n"
1524  " }\n"
1525  " m->weights_mixed = mixed ? true : false;\n"
1526  " m->weights_quantized = quant ? true : false;\n"
1527  " if (!mixed) {\n"
1528  " m->weight_dtype = base;\n"
1529  " }\n"
1530  "}\n\n"
1531  "static int load_weight_dtypes(const char *path, TransformerModel *m)\n"
1532  "{\n"
1533  " if (!path || !m) return -1;\n"
1534  " FILE *f = fopen(path, \"rb\");\n"
1535  " if (!f) return -1;\n"
1536  " char magic[8];\n"
1537  " if (fread(magic, 1, 8, f) != 8) {\n"
1538  " fclose(f);\n"
1539  " return -1;\n"
1540  " }\n"
1541  " if (memcmp(magic, \"BUMPWGT3\", 8) != 0) {\n"
1542  " fclose(f);\n"
1543  " return 0;\n"
1544  " }\n"
1545  " uint32_t version = 0;\n"
1546  " if (fread(&version, sizeof(uint32_t), 1, f) != 1) {\n"
1547  " fclose(f);\n"
1548  " return -1;\n"
1549  " }\n"
1550  " if (version < 3) {\n"
1551  " fclose(f);\n"
1552  " return -1;\n"
1553  " }\n"
1554  " if (fseek(f, 128, SEEK_SET) != 0) {\n"
1555  " fclose(f);\n"
1556  " return -1;\n"
1557  " }\n"
1558  " uint32_t dtype_len = 0;\n"
1559  " if (fread(&dtype_len, sizeof(uint32_t), 1, f) != 1) {\n"
1560  " fclose(f);\n"
1561  " return -1;\n"
1562  " }\n"
1563  " if (dtype_len == 0) {\n"
1564  " fclose(f);\n"
1565  " return -1;\n"
1566  " }\n"
1567  " uint8_t *dtype_buf = (uint8_t *)malloc(dtype_len);\n"
1568  " if (!dtype_buf) {\n"
1569  " fclose(f);\n"
1570  " return -1;\n"
1571  " }\n"
1572  " if (fread(dtype_buf, 1, dtype_len, f) != dtype_len) {\n"
1573  " free(dtype_buf);\n"
1574  " fclose(f);\n"
1575  " return -1;\n"
1576  " }\n"
1577  " fclose(f);\n"
1578  "\n"
1579  " size_t expected = (size_t)m->num_layers * 14u + 4u;\n"
1580  " if (dtype_len != expected) {\n"
1581  " free(dtype_buf);\n"
1582  " return -1;\n"
1583  " }\n"
1584  " if (ensure_layers_allocated(m) != 0) {\n"
1585  " free(dtype_buf);\n"
1586  " return -1;\n"
1587  " }\n"
1588  "\n"
1589  " size_t idx = 0;\n"
1590  " CKDataType token_dt = (CKDataType)dtype_buf[idx++];\n"
1591  " CKDataType pos_dt = (CKDataType)dtype_buf[idx++];\n"
1592  " if (pos_dt != CK_DT_FP32) {\n"
1593  " free(dtype_buf);\n"
1594  " return -1;\n"
1595  " }\n"
1596  " if (token_dt != CK_DT_FP32 && token_dt != CK_DT_Q4_K && token_dt != CK_DT_Q6_K) {\n"
1597  " free(dtype_buf);\n"
1598  " return -1;\n"
1599  " }\n"
1600  " m->token_emb_dtype = token_dt;\n"
1601  " m->lm_head_weight_dtype = token_dt;\n"
1602  " m->pos_emb_dtype = pos_dt;\n"
1603  "\n"
1604  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1605  " CKDataType ln1_dt = (CKDataType)dtype_buf[idx++];\n"
1606  " CKDataType ln2_dt = (CKDataType)dtype_buf[idx++];\n"
1607  " CKDataType wq_dt = (CKDataType)dtype_buf[idx++];\n"
1608  " CKDataType bq_dt = (CKDataType)dtype_buf[idx++];\n"
1609  " CKDataType wk_dt = (CKDataType)dtype_buf[idx++];\n"
1610  " CKDataType bk_dt = (CKDataType)dtype_buf[idx++];\n"
1611  " CKDataType wv_dt = (CKDataType)dtype_buf[idx++];\n"
1612  " CKDataType bv_dt = (CKDataType)dtype_buf[idx++];\n"
1613  " CKDataType wo_dt = (CKDataType)dtype_buf[idx++];\n"
1614  " CKDataType bo_dt = (CKDataType)dtype_buf[idx++];\n"
1615  " CKDataType w1_dt = (CKDataType)dtype_buf[idx++];\n"
1616  " CKDataType b1_dt = (CKDataType)dtype_buf[idx++];\n"
1617  " CKDataType w2_dt = (CKDataType)dtype_buf[idx++];\n"
1618  " CKDataType b2_dt = (CKDataType)dtype_buf[idx++];\n"
1619  "\n"
1620  " if (ln1_dt != CK_DT_FP32 || ln2_dt != CK_DT_FP32 ||\n"
1621  " bq_dt != CK_DT_FP32 || bk_dt != CK_DT_FP32 ||\n"
1622  " bv_dt != CK_DT_FP32 || bo_dt != CK_DT_FP32 ||\n"
1623  " b1_dt != CK_DT_FP32 || b2_dt != CK_DT_FP32) {\n"
1624  " free(dtype_buf);\n"
1625  " return -1;\n"
1626  " }\n"
1627  " if ((wq_dt != CK_DT_FP32 && wq_dt != CK_DT_Q4_K && wq_dt != CK_DT_Q6_K) ||\n"
1628  " (wk_dt != CK_DT_FP32 && wk_dt != CK_DT_Q4_K && wk_dt != CK_DT_Q6_K) ||\n"
1629  " (wv_dt != CK_DT_FP32 && wv_dt != CK_DT_Q4_K && wv_dt != CK_DT_Q6_K) ||\n"
1630  " (wo_dt != CK_DT_FP32 && wo_dt != CK_DT_Q4_K && wo_dt != CK_DT_Q6_K) ||\n"
1631  " (w1_dt != CK_DT_FP32 && w1_dt != CK_DT_Q4_K && w1_dt != CK_DT_Q6_K) ||\n"
1632  " (w2_dt != CK_DT_FP32 && w2_dt != CK_DT_Q4_K && w2_dt != CK_DT_Q6_K)) {\n"
1633  " free(dtype_buf);\n"
1634  " return -1;\n"
1635  " }\n"
1636  "\n"
1637  " TrulyOptimalLayer *L = &m->layers[layer];\n"
1638  " L->wq_dtype = wq_dt;\n"
1639  " L->wk_dtype = wk_dt;\n"
1640  " L->wv_dtype = wv_dt;\n"
1641  " L->wo_dtype = wo_dt;\n"
1642  " L->w1_dtype = w1_dt;\n"
1643  " L->w2_dtype = w2_dt;\n"
1644  " }\n"
1645  "\n"
1646  " CKDataType final_norm_dt = (CKDataType)dtype_buf[idx++];\n"
1647  " CKDataType final_bias_dt = (CKDataType)dtype_buf[idx++];\n"
1648  " free(dtype_buf);\n"
1649  " if (final_norm_dt != CK_DT_FP32 || final_bias_dt != CK_DT_FP32) {\n"
1650  " return -1;\n"
1651  " }\n"
1652  "\n"
1653  " refresh_weight_flags(m);\n"
1654  " return 1;\n"
1655  "}\n\n"
1656  "\n"
1657  "static int layout_model(TransformerModel *m)\n"
1658  "{\n"
1659  " if (!m) return -1;\n"
1660  " if (m->num_attention_heads <= 0 || m->embed_dim <= 0) return -1;\n"
1661  " if (m->num_kv_heads <= 0) m->num_kv_heads = m->num_attention_heads;\n"
1662  " if (m->num_attention_heads %% m->num_kv_heads != 0) return -1;\n"
1663  " if (m->context_window <= 0) m->context_window = 1;\n"
1664  " if (m->vocab_size <= 0) m->vocab_size = 1;\n"
1665  " if (m->intermediate_size <= 0) return -1;\n"
1666  " m->head_dim = m->embed_dim / m->num_attention_heads;\n"
1667  " if (m->rms_norm_eps <= 0.0f) m->rms_norm_eps = 1e-5f;\n"
1668  " if (m->rope_theta < 0.0f) m->rope_theta = 0.0f;\n"
1669  " if (m->rope_theta > 0.0f && (m->head_dim %% 2 != 0)) return -1;\n"
1670  " if (m->elem_bytes == 0) m->elem_bytes = sizeof(float);\n"
1671  " size_t elem_bytes = m->elem_bytes;\n"
1672  " m->aligned_embed_dim = align_up_elems((size_t)m->embed_dim, elem_bytes, CACHELINE_BYTES);\n"
1673  " m->aligned_head_dim = align_up_elems((size_t)m->head_dim, elem_bytes, CACHELINE_BYTES);\n"
1674  " m->aligned_attn_context_window = align_up_elems((size_t)m->context_window, elem_bytes, CACHELINE_BYTES);\n"
1675  " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, elem_bytes, CACHELINE_BYTES);\n"
1676  " if (ensure_layers_allocated(m) != 0) return -1;\n"
1677  " if (m->weights_quantized) {\n"
1678  " /* K-quant weights require K dimension to be a multiple of 256. */\n"
1679  " if ((m->aligned_embed_dim %% 256) != 0) return -1;\n"
1680  " if ((aligned_intermediate_dim %% 256) != 0) return -1;\n"
1681  " int wo_quant = 0;\n"
1682  " if (m->layers) {\n"
1683  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1684  " if (ck_dtype_is_quantized(m->layers[layer].wo_dtype)) {\n"
1685  " wo_quant = 1;\n"
1686  " break;\n"
1687  " }\n"
1688  " }\n"
1689  " }\n"
1690  " if (wo_quant && (size_t)m->num_attention_heads * m->aligned_head_dim != m->aligned_embed_dim) return -1;\n"
1691  " }\n"
1692  "\n"
1693  " if (m->num_cores <= 0) m->num_cores = 1;\n"
1694  " m->tokens_per_core = (m->context_window + m->num_cores - 1) / m->num_cores;\n"
1695  "\n"
1696  " size_t off = 0;\n");
1698  fprintf(out,
1699  " m->layers_start_offset = off;\n"
1700  "\n"
1701  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1702  " TrulyOptimalLayer *L = &m->layers[layer];\n");
1704  fprintf(out,
1705  " }\n"
1706  "\n");
1707  {
1708  const char *stride_field = ck_first_layer_buffer_name();
1709  fprintf(out,
1710  " if (m->num_layers > 1) {\n"
1711  " m->layer_stride = m->layers[1].%s_offset - m->layers[0].%s_offset;\n"
1712  " } else {\n"
1713  " m->layer_stride = 0;\n"
1714  " }\n",
1715  stride_field, stride_field);
1716  }
1718  fprintf(out,
1719  " m->total_bytes = align_up_bytes(off, CACHELINE_BYTES);\n"
1720  " m->memory_base = (uint8_t *)ck_huge_alloc(m->total_bytes);\n"
1721  " if (!m->memory_base) return -1;\n"
1722  " if (m->rope_theta > 0.0f) {\n"
1723  " rope_precompute_cache(ptr_f32(m->memory_base, m->rope_cos_cache_offset),\n"
1724  " ptr_f32(m->memory_base, m->rope_sin_cache_offset),\n"
1725  " m->context_window,\n"
1726  " m->head_dim,\n"
1727  " m->rope_theta);\n"
1728  " }\n"
1729  " return 0;\n"
1730  "}\n\n");
1731 
1732  fprintf(out,
1733  "static void lm_head_forward(const float *hidden,\n"
1734  " const float *weights,\n"
1735  " float *logits,\n"
1736  " int T, int V, int D, int aligned_D);\n"
1737  "static void lm_head_backward(const float *hidden,\n"
1738  " const float *weights,\n"
1739  " const float *d_logits,\n"
1740  " float *d_hidden,\n"
1741  " float *d_weights,\n"
1742  " int T, int V, int D, int aligned_D);\n"
1743  "static void softmax_cross_entropy(const float *logits,\n"
1744  " const int32_t *targets,\n"
1745  " int T, int V,\n"
1746  " float *d_logits,\n"
1747  " float *loss_out);\n\n");
1748 
1749  fprintf(out,
1750  "static void run_model_forward(TransformerModel *m)\n"
1751  "{\n"
1752  " uint8_t *base = m->memory_base;\n"
1753  " float *current = ptr_f32(base, m->embedded_input_offset);\n"
1754  " int aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
1755  " int T = m->active_tokens > 0 ? m->active_tokens : m->context_window;\n"
1756  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1757  " TrulyOptimalLayer *L = &m->layers[layer];\n"
1758  " if (!m->weights_mixed && m->weight_dtype == CK_DT_Q4_K) {\n"
1759  " CKLayerForwardParamsQ4K p = {0};\n"
1760  " p.tokens = T;\n"
1761  " p.embed_dim = m->embed_dim;\n"
1762  " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1763  " p.num_heads = m->num_attention_heads;\n"
1764  " p.num_kv_heads = m->num_kv_heads;\n"
1765  " p.head_dim = m->head_dim;\n"
1766  " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1767  " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1768  " p.intermediate_dim = m->intermediate_size;\n"
1769  " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1770  " p.eps = m->rms_norm_eps;\n"
1771  " p.rope_pos_offset = 0;\n"
1772  " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1773  " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1774  " p.input = current;\n"
1775  " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1776  " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1777  " p.wq = cptr_void(base, L->wq_offset);\n"
1778  " p.bq = cptr_f32(base, L->bq_offset);\n"
1779  " p.wk = cptr_void(base, L->wk_offset);\n"
1780  " p.bk = cptr_f32(base, L->bk_offset);\n"
1781  " p.wv = cptr_void(base, L->wv_offset);\n"
1782  " p.bv = cptr_f32(base, L->bv_offset);\n"
1783  " p.wo = cptr_void(base, L->wo_offset);\n"
1784  " p.bo = cptr_f32(base, L->bo_offset);\n"
1785  " p.w1 = cptr_void(base, L->w1_offset);\n"
1786  " p.b1 = cptr_f32(base, L->b1_offset);\n"
1787  " p.w2 = cptr_void(base, L->w2_offset);\n"
1788  " p.b2 = cptr_f32(base, L->b2_offset);\n"
1789  " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1790  " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1791  " p.q = ptr_f32(base, L->q_offset);\n"
1792  " p.k = ptr_f32(base, L->k_offset);\n"
1793  " p.v = ptr_f32(base, L->v_offset);\n"
1794  " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1795  " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1796  " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1797  " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1798  " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1799  " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1800  " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1801  " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1802  " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1803  " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1804  " p.output = ptr_f32(base, L->output_offset);\n"
1805  " ck_layer_forward_rmsnorm_swiglu_q4_k(&p);\n"
1806  " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1807  " kv_cache_repack_head_major_inplace(p.k,\n"
1808  " p.num_kv_heads,\n"
1809  " T,\n"
1810  " m->kv_cache_capacity,\n"
1811  " p.aligned_head_dim);\n"
1812  " kv_cache_repack_head_major_inplace(p.v,\n"
1813  " p.num_kv_heads,\n"
1814  " T,\n"
1815  " m->kv_cache_capacity,\n"
1816  " p.aligned_head_dim);\n"
1817  " }\n"
1818  " current = p.output;\n"
1819  " } else if (m->weights_quantized) {\n"
1820  " CKLayerForwardParamsQ4K p = {0};\n"
1821  " p.tokens = T;\n"
1822  " p.embed_dim = m->embed_dim;\n"
1823  " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1824  " p.num_heads = m->num_attention_heads;\n"
1825  " p.num_kv_heads = m->num_kv_heads;\n"
1826  " p.head_dim = m->head_dim;\n"
1827  " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1828  " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1829  " p.intermediate_dim = m->intermediate_size;\n"
1830  " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1831  " p.eps = m->rms_norm_eps;\n"
1832  " p.rope_pos_offset = 0;\n"
1833  " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1834  " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1835  " p.input = current;\n"
1836  " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1837  " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1838  " p.wq = cptr_void(base, L->wq_offset);\n"
1839  " p.bq = cptr_f32(base, L->bq_offset);\n"
1840  " p.wk = cptr_void(base, L->wk_offset);\n"
1841  " p.bk = cptr_f32(base, L->bk_offset);\n"
1842  " p.wv = cptr_void(base, L->wv_offset);\n"
1843  " p.bv = cptr_f32(base, L->bv_offset);\n"
1844  " p.wo = cptr_void(base, L->wo_offset);\n"
1845  " p.bo = cptr_f32(base, L->bo_offset);\n"
1846  " p.w1 = cptr_void(base, L->w1_offset);\n"
1847  " p.b1 = cptr_f32(base, L->b1_offset);\n"
1848  " p.w2 = cptr_void(base, L->w2_offset);\n"
1849  " p.b2 = cptr_f32(base, L->b2_offset);\n"
1850  " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1851  " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1852  " p.q = ptr_f32(base, L->q_offset);\n"
1853  " p.k = ptr_f32(base, L->k_offset);\n"
1854  " p.v = ptr_f32(base, L->v_offset);\n"
1855  " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1856  " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1857  " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1858  " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1859  " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1860  " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1861  " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1862  " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1863  " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1864  " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1865  " p.output = ptr_f32(base, L->output_offset);\n"
1866  " p.wq_dtype = L->wq_dtype;\n"
1867  " p.wk_dtype = L->wk_dtype;\n"
1868  " p.wv_dtype = L->wv_dtype;\n"
1869  " p.wo_dtype = L->wo_dtype;\n"
1870  " p.w1_dtype = L->w1_dtype;\n"
1871  " p.w2_dtype = L->w2_dtype;\n"
1872  " ck_layer_forward_rmsnorm_swiglu_quant(&p);\n"
1873  " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1874  " kv_cache_repack_head_major_inplace(p.k,\n"
1875  " p.num_kv_heads,\n"
1876  " T,\n"
1877  " m->kv_cache_capacity,\n"
1878  " p.aligned_head_dim);\n"
1879  " kv_cache_repack_head_major_inplace(p.v,\n"
1880  " p.num_kv_heads,\n"
1881  " T,\n"
1882  " m->kv_cache_capacity,\n"
1883  " p.aligned_head_dim);\n"
1884  " }\n"
1885  " current = p.output;\n"
1886  " } else {\n"
1887  " CKLayerForwardParams p = {0};\n"
1888  " p.tokens = T;\n"
1889  " p.embed_dim = m->embed_dim;\n"
1890  " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1891  " p.num_heads = m->num_attention_heads;\n"
1892  " p.num_kv_heads = m->num_kv_heads;\n"
1893  " p.head_dim = m->head_dim;\n"
1894  " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1895  " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1896  " p.intermediate_dim = m->intermediate_size;\n"
1897  " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1898  " p.eps = m->rms_norm_eps;\n"
1899  " p.rope_pos_offset = 0;\n"
1900  " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1901  " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1902  " p.input = current;\n"
1903  " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1904  " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1905  " p.wq = cptr_f32(base, L->wq_offset);\n"
1906  " p.bq = cptr_f32(base, L->bq_offset);\n"
1907  " p.wk = cptr_f32(base, L->wk_offset);\n"
1908  " p.bk = cptr_f32(base, L->bk_offset);\n"
1909  " p.wv = cptr_f32(base, L->wv_offset);\n"
1910  " p.bv = cptr_f32(base, L->bv_offset);\n"
1911  " p.wo = cptr_f32(base, L->wo_offset);\n"
1912  " p.bo = cptr_f32(base, L->bo_offset);\n"
1913  " p.w1 = cptr_f32(base, L->w1_offset);\n"
1914  " p.b1 = cptr_f32(base, L->b1_offset);\n"
1915  " p.w2 = cptr_f32(base, L->w2_offset);\n"
1916  " p.b2 = cptr_f32(base, L->b2_offset);\n"
1917  " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1918  " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1919  " p.q = ptr_f32(base, L->q_offset);\n"
1920  " p.k = ptr_f32(base, L->k_offset);\n"
1921  " p.v = ptr_f32(base, L->v_offset);\n"
1922  " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1923  " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1924  " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1925  " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1926  " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1927  " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1928  " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1929  " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1930  " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1931  " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1932  " p.output = ptr_f32(base, L->output_offset);\n"
1933  " ck_layer_forward_rmsnorm_swiglu(&p);\n"
1934  " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1935  " kv_cache_repack_head_major_inplace(p.k,\n"
1936  " p.num_kv_heads,\n"
1937  " T,\n"
1938  " m->kv_cache_capacity,\n"
1939  " p.aligned_head_dim);\n"
1940  " kv_cache_repack_head_major_inplace(p.v,\n"
1941  " p.num_kv_heads,\n"
1942  " T,\n"
1943  " m->kv_cache_capacity,\n"
1944  " p.aligned_head_dim);\n"
1945  " }\n"
1946  " current = p.output;\n"
1947  " }\n"
1948  " }\n"
1949  " float *final_out = ptr_f32(base, m->final_output_offset);\n"
1950  " rmsnorm_forward(current,\n"
1951  " cptr_f32(base, m->final_ln_weight_offset),\n"
1952  " final_out,\n"
1953  " ptr_f32(base, m->final_ln_rstd_offset),\n"
1954  " T,\n"
1955  " m->embed_dim,\n"
1956  " (int)m->aligned_embed_dim,\n"
1957  " m->rms_norm_eps);\n"
1958  " if (m->vocab_size > 0) {\n"
1959  " if (m->lm_head_weight_dtype == CK_DT_Q4_K) {\n"
1960  " gemm_nt_q4_k(final_out,\n"
1961  " cptr_void(base, m->lm_head_weight_offset),\n"
1962  " NULL,\n"
1963  " ptr_f32(base, m->logits_offset),\n"
1964  " T,\n"
1965  " m->vocab_size,\n"
1966  " (int)m->aligned_embed_dim);\n"
1967  " } else if (m->lm_head_weight_dtype == CK_DT_Q6_K) {\n"
1968  " gemm_nt_q6_k(final_out,\n"
1969  " cptr_void(base, m->lm_head_weight_offset),\n"
1970  " NULL,\n"
1971  " ptr_f32(base, m->logits_offset),\n"
1972  " T,\n"
1973  " m->vocab_size,\n"
1974  " (int)m->aligned_embed_dim);\n"
1975  " } else {\n"
1976  " lm_head_forward(final_out,\n"
1977  " cptr_f32(base, m->lm_head_weight_offset),\n"
1978  " ptr_f32(base, m->logits_offset),\n"
1979  " T,\n"
1980  " m->vocab_size,\n"
1981  " m->embed_dim,\n"
1982  " (int)m->aligned_embed_dim);\n"
1983  " }\n"
1984  " }\n"
1985  "}\n\n");
1986 
1987  emit_zero_grad(out);
1988  emit_sgd_update(out);
1989 
1990  fprintf(out,
1991  "static int run_model_backward(TransformerModel *m,\n"
1992  " const int32_t *tokens,\n"
1993  " const int32_t *targets,\n"
1994  " float *loss_out)\n"
1995  "{\n"
1996  " if (!m || !m->training_enabled) return 0;\n"
1997  " if (!tokens || !targets) return -1;\n"
1998  " if (m->num_layers <= 0) return -1;\n"
1999  " int T = m->active_tokens > 0 ? m->active_tokens : m->context_window;\n"
2000  " int V = m->vocab_size;\n"
2001  " int D = m->embed_dim;\n"
2002  " int aligned_D = (int)m->aligned_embed_dim;\n"
2003  " uint8_t *base = m->memory_base;\n"
2004  "\n"
2005  " zero_grad(m);\n"
2006  "\n"
2007  " float *final_out = ptr_f32(base, m->final_output_offset);\n"
2008  " float *logits = ptr_f32(base, m->logits_offset);\n"
2009  " float *d_logits = ptr_f32(base, m->d_logits_offset);\n"
2010  " float *d_final_out = ptr_f32(base, m->d_final_output_offset);\n"
2011  " float *d_final_in = ptr_f32(base, m->d_final_input_offset);\n"
2012  "\n"
2013  " float loss = 0.0f;\n"
2014  " softmax_cross_entropy(logits, targets, T, V, d_logits, &loss);\n"
2015  " if (loss_out) {\n"
2016  " *loss_out = loss;\n"
2017  " }\n"
2018  " lm_head_backward(final_out,\n"
2019  " cptr_f32(base, m->lm_head_weight_offset),\n"
2020  " d_logits,\n"
2021  " d_final_out,\n"
2022  " ptr_f32(base, m->d_token_emb_offset),\n"
2023  " T, V, D, aligned_D);\n"
2024  " rmsnorm_backward(d_final_out,\n"
2025  " ptr_f32(base, m->layers[m->num_layers - 1].output_offset),\n"
2026  " cptr_f32(base, m->final_ln_weight_offset),\n"
2027  " ptr_f32(base, m->final_ln_rstd_offset),\n"
2028  " d_final_in,\n"
2029  " ptr_f32(base, m->d_final_ln_weight_offset),\n"
2030  " T, D, aligned_D);\n"
2031  "\n"
2032  " for (int layer = m->num_layers - 1; layer >= 0; --layer) {\n"
2033  " TrulyOptimalLayer *L = &m->layers[layer];\n"
2034  " CKLayerBackwardParams p = {0};\n"
2035  " p.tokens = T;\n"
2036  " p.embed_dim = m->embed_dim;\n"
2037  " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
2038  " p.num_heads = m->num_attention_heads;\n"
2039  " p.num_kv_heads = m->num_kv_heads;\n"
2040  " p.head_dim = m->head_dim;\n"
2041  " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
2042  " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
2043  " p.intermediate_dim = m->intermediate_size;\n"
2044  " p.aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2045  " p.eps = m->rms_norm_eps;\n"
2046  " p.rope_pos_offset = 0;\n"
2047  " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
2048  " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
2049  " p.input = (layer == 0) ? ptr_f32(base, m->embedded_input_offset)\n"
2050  " : ptr_f32(base, m->layers[layer - 1].output_offset);\n"
2051  " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
2052  " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
2053  " p.ln1_out = cptr_f32(base, L->ln1_out_offset);\n"
2054  " p.ln1_rstd = cptr_f32(base, L->ln1_rstd_offset);\n"
2055  " p.ln2_out = cptr_f32(base, L->ln2_out_offset);\n"
2056  " p.ln2_rstd = cptr_f32(base, L->ln2_rstd_offset);\n"
2057  " p.wq = cptr_f32(base, L->wq_offset);\n"
2058  " p.bq = cptr_f32(base, L->bq_offset);\n"
2059  " p.wk = cptr_f32(base, L->wk_offset);\n"
2060  " p.bk = cptr_f32(base, L->bk_offset);\n"
2061  " p.wv = cptr_f32(base, L->wv_offset);\n"
2062  " p.bv = cptr_f32(base, L->bv_offset);\n"
2063  " p.wo = cptr_f32(base, L->wo_offset);\n"
2064  " p.bo = cptr_f32(base, L->bo_offset);\n"
2065  " p.w1 = cptr_f32(base, L->w1_offset);\n"
2066  " p.b1 = cptr_f32(base, L->b1_offset);\n"
2067  " p.w2 = cptr_f32(base, L->w2_offset);\n"
2068  " p.b2 = cptr_f32(base, L->b2_offset);\n"
2069  " p.q = cptr_f32(base, L->q_offset);\n"
2070  " p.k = cptr_f32(base, L->k_offset);\n"
2071  " p.v = cptr_f32(base, L->v_offset);\n"
2072  " p.scores = L->scores_offset ? cptr_f32(base, L->scores_offset) : NULL;\n"
2073  " p.attn_out = cptr_f32(base, L->attn_out_offset);\n"
2074  " p.residual1 = cptr_f32(base, L->residual1_offset);\n"
2075  " p.fc1_out = cptr_f32(base, L->fc1_out_offset);\n"
2076  " p.swiglu_out = cptr_f32(base, L->swiglu_out_offset);\n"
2077  " p.d_output = ptr_f32(base, L->d_output_offset);\n"
2078  " p.d_input = ptr_f32(base, L->d_input_offset);\n"
2079  " p.d_ln1_gamma = ptr_f32(base, L->d_ln1_gamma_offset);\n"
2080  " p.d_ln2_gamma = ptr_f32(base, L->d_ln2_gamma_offset);\n"
2081  " p.d_wq = ptr_f32(base, L->d_wq_offset);\n"
2082  " p.d_bq = ptr_f32(base, L->d_bq_offset);\n"
2083  " p.d_wk = ptr_f32(base, L->d_wk_offset);\n"
2084  " p.d_bk = ptr_f32(base, L->d_bk_offset);\n"
2085  " p.d_wv = ptr_f32(base, L->d_wv_offset);\n"
2086  " p.d_bv = ptr_f32(base, L->d_bv_offset);\n"
2087  " p.d_wo = ptr_f32(base, L->d_wo_offset);\n"
2088  " p.d_bo = ptr_f32(base, L->d_bo_offset);\n"
2089  " p.d_w1 = ptr_f32(base, L->d_w1_offset);\n"
2090  " p.d_b1 = ptr_f32(base, L->d_b1_offset);\n"
2091  " p.d_w2 = ptr_f32(base, L->d_w2_offset);\n"
2092  " p.d_b2 = ptr_f32(base, L->d_b2_offset);\n"
2093  " p.d_ln1_out = ptr_f32(base, L->d_ln1_out_offset);\n"
2094  " p.d_q = ptr_f32(base, L->d_q_offset);\n"
2095  " p.d_k = ptr_f32(base, L->d_k_offset);\n"
2096  " p.d_v = ptr_f32(base, L->d_v_offset);\n"
2097  " p.d_scores = ptr_f32(base, L->d_scores_offset);\n"
2098  " p.d_attn_out = ptr_f32(base, L->d_attn_out_offset);\n"
2099  " p.d_proj_tmp = ptr_f32(base, L->d_proj_tmp_offset);\n"
2100  " p.d_residual1 = ptr_f32(base, L->d_residual1_offset);\n"
2101  " p.d_ln2_out = ptr_f32(base, L->d_ln2_out_offset);\n"
2102  " p.d_fc1_out = ptr_f32(base, L->d_fc1_out_offset);\n"
2103  " p.d_swiglu_out = ptr_f32(base, L->d_swiglu_out_offset);\n"
2104  " p.d_mlp_out = ptr_f32(base, L->d_mlp_out_offset);\n"
2105  "\n"
2106  " const float *src = (layer == m->num_layers - 1)\n"
2107  " ? d_final_in\n"
2108  " : ptr_f32(base, m->layers[layer + 1].d_input_offset);\n"
2109  " memcpy(p.d_output, src, (size_t)T * (size_t)aligned_D * sizeof(float));\n"
2110  "\n"
2111  " ck_layer_backward_rmsnorm_swiglu(&p);\n"
2112  " }\n"
2113  "\n"
2114  " {\n"
2115  " TrulyOptimalLayer *L0 = &m->layers[0];\n"
2116  " embedding_backward(tokens,\n"
2117  " T,\n"
2118  " ptr_f32(base, L0->d_input_offset),\n"
2119  " ptr_f32(base, m->d_token_emb_offset),\n"
2120  " ptr_f32(base, m->d_pos_emb_offset),\n"
2121  " m->vocab_size,\n"
2122  " m->embed_dim,\n"
2123  " aligned_D,\n"
2124  " m->context_window,\n"
2125  " m->rope_theta <= 0.0f);\n"
2126  " }\n"
2127  "\n"
2128  " /* SGD update is now called separately via optimizer_step() */\n"
2129  " return 0;\n"
2130  "}\n\n");
2131 
2132  fprintf(out,
2133  "static int parse_int_arg(const char *s, int *out)\n"
2134  "{\n"
2135  " if (!s || !out) return 0;\n"
2136  " char *end = NULL;\n"
2137  " long v = strtol(s, &end, 10);\n"
2138  " if (!end || *end != '\\0') return 0;\n"
2139  " *out = (int)v;\n"
2140  " return 1;\n"
2141  "}\n\n"
2142  "static int parse_float_arg(const char *s, float *out)\n"
2143  "{\n"
2144  " if (!s || !out) return 0;\n"
2145  " char *end = NULL;\n"
2146  " double v = strtod(s, &end);\n"
2147  " if (!end || *end != '\\0') return 0;\n"
2148  " *out = (float)v;\n"
2149  " return 1;\n"
2150  "}\n\n"
2151  "static void print_usage(const char *prog)\n"
2152  "{\n"
2153  " printf(\"Usage: %%s [options]\\n\", prog);\n"
2154  " printf(\" --dump Print layout summary (layer 0 only)\\n\");\n"
2155  " printf(\" --dump-all Print layout summary for all layers\\n\");\n"
2156  " printf(\" --no-forward Skip forward pass (layout + alloc only)\\n\");\n"
2157  " printf(\" --layers N Override num_layers\\n\");\n"
2158  " printf(\" --embed N Override embed_dim\\n\");\n"
2159  " printf(\" --intermediate N Override intermediate_size\\n\");\n"
2160  " printf(\" --heads N Override num_attention_heads\\n\");\n"
2161  " printf(\" --kv-heads N Override num_kv_heads\\n\");\n"
2162  " printf(\" --vocab N Override vocab_size\\n\");\n"
2163  " printf(\" --ctx N Override context_window\\n\");\n"
2164  " printf(\" --cores N Override num_cores\\n\");\n"
2165  " printf(\" --litmus Run LM head + CE + backward litmus\\n\");\n"
2166  " printf(\" --backward Run backward pass + SGD update (requires --tokens/--targets)\\n\");\n"
2167  " printf(\" --lr F SGD learning rate (default: 1e-3 when --backward)\\n\");\n"
2168  " printf(\" --steps N Training steps (default: 1)\\n\");\n"
2169  " printf(\" --log-steps Print loss per step during training\\n\");\n"
2170  " printf(\" --strict Enable strict parity mode (single-thread + double GEMM)\\n\");\n"
2171  " printf(\" --hidden PATH Load hidden activations [T x aligned_D] f32\\n\");\n"
2172  " printf(\" --weights PATH Load LM head weights [V x aligned_D] f32 (litmus)\\n\");\n"
2173  " printf(\" --targets PATH Load target tokens [T] int32\\n\");\n"
2174  " printf(\" --model-weights PATH Load full model weights (bump format)\\n\");\n"
2175  " printf(\" --tokens PATH Load token IDs [T] int32 and build embeddings\\n\");\n"
2176  " printf(\" --out-logits PATH Write logits [T x V] f32\\n\");\n"
2177  " printf(\" --out-dlogits PATH Write d_logits [T x V] f32\\n\");\n"
2178  " printf(\" --out-dhidden PATH Write d_hidden [T x aligned_D] f32\\n\");\n"
2179  " printf(\" --out-dweights PATH Write d_weights [V x aligned_D] f32\\n\");\n"
2180  " printf(\" --out-loss PATH Write loss (single f32)\\n\");\n"
2181  " printf(\" --out-weights PATH Write model weights (flat, no header)\\n\");\n"
2182  " printf(\" --help Show this help\\n\");\n"
2183  "}\n\n"
2184  "static int read_floats(const char *path, float *dst, size_t count)\n"
2185  "{\n"
2186  " if (!path || !dst) return -1;\n"
2187  " FILE *f = fopen(path, \"rb\");\n"
2188  " if (!f) {\n"
2189  " perror(\"fopen\");\n"
2190  " return -1;\n"
2191  " }\n"
2192  " size_t got = fread(dst, sizeof(float), count, f);\n"
2193  " fclose(f);\n"
2194  " return got == count ? 0 : -1;\n"
2195  "}\n\n"
2196  "static int read_ints(const char *path, int32_t *dst, size_t count)\n"
2197  "{\n"
2198  " if (!path || !dst) return -1;\n"
2199  " FILE *f = fopen(path, \"rb\");\n"
2200  " if (!f) {\n"
2201  " perror(\"fopen\");\n"
2202  " return -1;\n"
2203  " }\n"
2204  " size_t got = fread(dst, sizeof(int32_t), count, f);\n"
2205  " fclose(f);\n"
2206  " return got == count ? 0 : -1;\n"
2207  "}\n\n"
2208  "static int read_floats_file(FILE *f, float *dst, size_t count)\n"
2209  "{\n"
2210  " if (!f || !dst) return -1;\n"
2211  " size_t got = fread(dst, sizeof(float), count, f);\n"
2212  " return got == count ? 0 : -1;\n"
2213  "}\n\n"
2214  "static int read_bytes_file(FILE *f, void *dst, size_t bytes)\n"
2215  "{\n"
2216  " if (!f || !dst) return -1;\n"
2217  " size_t got = fread(dst, 1, bytes, f);\n"
2218  " return got == bytes ? 0 : -1;\n"
2219  "}\n\n"
2220  "static int write_floats_file(FILE *f, const float *src, size_t count)\n"
2221  "{\n"
2222  " if (!f || !src) return -1;\n"
2223  " size_t wrote = fwrite(src, sizeof(float), count, f);\n"
2224  " return wrote == count ? 0 : -1;\n"
2225  "}\n\n"
2226  "static int write_bytes_file(FILE *f, const void *src, size_t bytes)\n"
2227  "{\n"
2228  " if (!f || !src) return -1;\n"
2229  " size_t wrote = fwrite(src, 1, bytes, f);\n"
2230  " return wrote == bytes ? 0 : -1;\n"
2231  "}\n\n"
2232  "static int read_weight_file(FILE *f, CKDataType dtype, void *dst, size_t n_elements)\n"
2233  "{\n"
2234  " if (!f || !dst) return -1;\n"
2235  " if (dtype == CK_DT_FP32) {\n"
2236  " return read_floats_file(f, (float *)dst, n_elements);\n"
2237  " }\n"
2238  " return read_bytes_file(f, dst, ck_dtype_row_bytes(dtype, n_elements));\n"
2239  "}\n\n"
2240  "static int write_weight_file(FILE *f, CKDataType dtype, const void *src, size_t n_elements)\n"
2241  "{\n"
2242  " if (!f || !src) return -1;\n"
2243  " if (dtype == CK_DT_FP32) {\n"
2244  " return write_floats_file(f, (const float *)src, n_elements);\n"
2245  " }\n"
2246  " return write_bytes_file(f, src, ck_dtype_row_bytes(dtype, n_elements));\n"
2247  "}\n\n"
2248  "static int skip_bump_header(FILE *f)\n"
2249  "{\n"
2250  " if (!f) return -1;\n"
2251  " char magic[8];\n"
2252  " if (fread(magic, 1, 8, f) != 8) return -1;\n"
2253  " if (memcmp(magic, \"BUMPWGT3\", 8) == 0) {\n"
2254  " if (fseek(f, 128, SEEK_SET) != 0) return -1;\n"
2255  " uint32_t dtype_len = 0;\n"
2256  " if (fread(&dtype_len, sizeof(uint32_t), 1, f) != 1) return -1;\n"
2257  " if (fseek(f, (long)dtype_len, SEEK_CUR) != 0) return -1;\n"
2258  " return 1;\n"
2259  " }\n"
2260  " if (memcmp(magic, \"BUMPWGT2\", 8) == 0) {\n"
2261  " if (fseek(f, 128, SEEK_SET) != 0) return -1;\n"
2262  " return 1;\n"
2263  " }\n"
2264  " if (fseek(f, 0, SEEK_SET) != 0) return -1;\n"
2265  " return 0;\n"
2266  "}\n\n"
2267  "static int load_model_weights(const char *path, TransformerModel *m)\n"
2268  "{\n"
2269  " if (!path || !m || !m->memory_base) return -1;\n"
2270  " FILE *f = fopen(path, \"rb\");\n"
2271  " if (!f) {\n"
2272  " perror(\"fopen\");\n"
2273  " return -1;\n"
2274  " }\n"
2275  " if (skip_bump_header(f) < 0) {\n"
2276  " fclose(f);\n"
2277  " return -1;\n"
2278  " }\n"
2279  " uint8_t *base = m->memory_base;\n"
2280  " size_t aligned_intermediate = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2281  " size_t tok_elems = (size_t)m->vocab_size * m->aligned_embed_dim;\n"
2282  " if (read_weight_file(f, m->token_emb_dtype, ptr_u8(base, m->token_emb_offset), tok_elems) != 0) goto fail;\n"
2283  " if (read_floats_file(f, ptr_f32(base, m->pos_emb_offset),\n"
2284  " (size_t)m->context_window * m->aligned_embed_dim) != 0) goto fail;\n"
2285  "\n"
2286  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
2287  " TrulyOptimalLayer *L = &m->layers[layer];\n"
2288  " size_t head_w_stride = m->aligned_head_dim * m->aligned_embed_dim;\n"
2289  " size_t q_w = (size_t)m->num_attention_heads * head_w_stride;\n"
2290  " size_t kv_w = (size_t)m->num_kv_heads * head_w_stride;\n"
2291  " size_t q_b = (size_t)m->num_attention_heads * m->aligned_head_dim;\n"
2292  " size_t kv_b = (size_t)m->num_kv_heads * m->aligned_head_dim;\n"
2293  " size_t wo_w = (size_t)m->num_attention_heads * m->aligned_embed_dim * m->aligned_head_dim;\n"
2294  " size_t w1_w = (size_t)(2 * aligned_intermediate) * m->aligned_embed_dim;\n"
2295  " size_t w2_w = m->aligned_embed_dim * aligned_intermediate;\n"
2296  "\n"
2297  " if (read_floats_file(f, ptr_f32(base, L->ln1_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2298  " if (read_floats_file(f, ptr_f32(base, L->ln2_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2299  " if (read_weight_file(f, L->wq_dtype, ptr_u8(base, L->wq_offset), q_w) != 0) goto fail;\n"
2300  " if (read_floats_file(f, ptr_f32(base, L->bq_offset), q_b) != 0) goto fail;\n"
2301  " if (read_weight_file(f, L->wk_dtype, ptr_u8(base, L->wk_offset), kv_w) != 0) goto fail;\n"
2302  " if (read_floats_file(f, ptr_f32(base, L->bk_offset), kv_b) != 0) goto fail;\n"
2303  " if (read_weight_file(f, L->wv_dtype, ptr_u8(base, L->wv_offset), kv_w) != 0) goto fail;\n"
2304  " if (read_floats_file(f, ptr_f32(base, L->bv_offset), kv_b) != 0) goto fail;\n"
2305  " if (read_weight_file(f, L->wo_dtype, ptr_u8(base, L->wo_offset), wo_w) != 0) goto fail;\n"
2306  " if (read_floats_file(f, ptr_f32(base, L->bo_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2307  " if (read_weight_file(f, L->w1_dtype, ptr_u8(base, L->w1_offset), w1_w) != 0) goto fail;\n"
2308  " if (read_floats_file(f, ptr_f32(base, L->b1_offset), (size_t)(2 * aligned_intermediate)) != 0) goto fail;\n"
2309  " if (read_weight_file(f, L->w2_dtype, ptr_u8(base, L->w2_offset), w2_w) != 0) goto fail;\n"
2310  " if (read_floats_file(f, ptr_f32(base, L->b2_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2311  " }\n"
2312  "\n"
2313  " if (read_floats_file(f, ptr_f32(base, m->final_ln_weight_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2314  " if (read_floats_file(f, ptr_f32(base, m->final_ln_bias_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2315  "\n"
2316  " fclose(f);\n"
2317  " return 0;\n"
2318  "fail:\n"
2319  " fclose(f);\n"
2320  " return -1;\n"
2321  "}\n\n"
2322  "static int save_model_weights(const char *path, const TransformerModel *m)\n"
2323  "{\n"
2324  " if (!path || !m || !m->memory_base) return -1;\n"
2325  " FILE *f = fopen(path, \"wb\");\n"
2326  " if (!f) {\n"
2327  " perror(\"fopen\");\n"
2328  " return -1;\n"
2329  " }\n"
2330  " uint8_t *base = m->memory_base;\n"
2331  " size_t aligned_intermediate = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2332  " size_t tok_elems = (size_t)m->vocab_size * m->aligned_embed_dim;\n"
2333  " if (write_weight_file(f, m->token_emb_dtype, cptr_void(base, m->token_emb_offset), tok_elems) != 0) goto fail;\n"
2334  " if (write_floats_file(f, ptr_f32(base, m->pos_emb_offset),\n"
2335  " (size_t)m->context_window * m->aligned_embed_dim) != 0) goto fail;\n"
2336  "\n"
2337  " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
2338  " const TrulyOptimalLayer *L = &m->layers[layer];\n"
2339  " size_t head_w_stride = m->aligned_head_dim * m->aligned_embed_dim;\n"
2340  " size_t q_w = (size_t)m->num_attention_heads * head_w_stride;\n"
2341  " size_t kv_w = (size_t)m->num_kv_heads * head_w_stride;\n"
2342  " size_t q_b = (size_t)m->num_attention_heads * m->aligned_head_dim;\n"
2343  " size_t kv_b = (size_t)m->num_kv_heads * m->aligned_head_dim;\n"
2344  " size_t wo_w = (size_t)m->num_attention_heads * m->aligned_embed_dim * m->aligned_head_dim;\n"
2345  " size_t w1_w = (size_t)(2 * aligned_intermediate) * m->aligned_embed_dim;\n"
2346  " size_t w2_w = m->aligned_embed_dim * aligned_intermediate;\n"
2347  "\n"
2348  " if (write_floats_file(f, cptr_f32(base, L->ln1_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2349  " if (write_floats_file(f, cptr_f32(base, L->ln2_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2350  " if (write_weight_file(f, L->wq_dtype, cptr_void(base, L->wq_offset), q_w) != 0) goto fail;\n"
2351  " if (write_floats_file(f, cptr_f32(base, L->bq_offset), q_b) != 0) goto fail;\n"
2352  " if (write_weight_file(f, L->wk_dtype, cptr_void(base, L->wk_offset), kv_w) != 0) goto fail;\n"
2353  " if (write_floats_file(f, cptr_f32(base, L->bk_offset), kv_b) != 0) goto fail;\n"
2354  " if (write_weight_file(f, L->wv_dtype, cptr_void(base, L->wv_offset), kv_w) != 0) goto fail;\n"
2355  " if (write_floats_file(f, cptr_f32(base, L->bv_offset), kv_b) != 0) goto fail;\n"
2356  " if (write_weight_file(f, L->wo_dtype, cptr_void(base, L->wo_offset), wo_w) != 0) goto fail;\n"
2357  " if (write_floats_file(f, cptr_f32(base, L->bo_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2358  " if (write_weight_file(f, L->w1_dtype, cptr_void(base, L->w1_offset), w1_w) != 0) goto fail;\n"
2359  " if (write_floats_file(f, cptr_f32(base, L->b1_offset), (size_t)(2 * aligned_intermediate)) != 0) goto fail;\n"
2360  " if (write_weight_file(f, L->w2_dtype, cptr_void(base, L->w2_offset), w2_w) != 0) goto fail;\n"
2361  " if (write_floats_file(f, cptr_f32(base, L->b2_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2362  " }\n"
2363  "\n"
2364  " if (write_floats_file(f, cptr_f32(base, m->final_ln_weight_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2365  " if (write_floats_file(f, cptr_f32(base, m->final_ln_bias_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2366  "\n"
2367  " fclose(f);\n"
2368  " return 0;\n"
2369  "fail:\n"
2370  " fclose(f);\n"
2371  " return -1;\n"
2372  "}\n\n"
2373  "static void embed_tokens(const TransformerModel *m, const int32_t *tokens, int token_count)\n"
2374  "{\n"
2375  " if (!m || !m->memory_base || !tokens) return;\n"
2376  " const uint8_t *base = m->memory_base;\n"
2377  " float *out = ptr_f32((uint8_t *)base, m->embedded_input_offset);\n"
2378  " const float *tok_f32 = cptr_f32(base, m->token_emb_offset);\n"
2379  " const uint8_t *tok_q = (const uint8_t *)cptr_void(base, m->token_emb_offset);\n"
2380  " const float *pos = cptr_f32(base, m->pos_emb_offset);\n"
2381  " int T = m->context_window;\n"
2382  " int D = m->embed_dim;\n"
2383  " int aligned_D = (int)m->aligned_embed_dim;\n"
2384  " for (int t = 0; t < T; ++t) {\n"
2385  " float *dst = out + (size_t)t * aligned_D;\n"
2386  " if (t < token_count) {\n"
2387  " int id = tokens[t];\n"
2388  " if (id < 0 || id >= m->vocab_size) id = 0;\n"
2389  " if (m->token_emb_dtype == CK_DT_Q4_K) {\n"
2390  " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_D);\n"
2391  " const void *row = tok_q + (size_t)id * row_bytes;\n"
2392  " dequant_q4_k_row(row, dst, (size_t)aligned_D);\n"
2393  " } else if (m->token_emb_dtype == CK_DT_Q6_K) {\n"
2394  " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_D);\n"
2395  " const void *row = tok_q + (size_t)id * row_bytes;\n"
2396  " dequant_q6_k_row(row, dst, (size_t)aligned_D);\n"
2397  " } else {\n"
2398  " const float *src = tok_f32 + (size_t)id * aligned_D;\n"
2399  " memcpy(dst, src, (size_t)D * sizeof(float));\n"
2400  " }\n"
2401  " if (aligned_D > D) {\n"
2402  " memset(dst + D, 0, (size_t)(aligned_D - D) * sizeof(float));\n"
2403  " }\n"
2404  " if (m->rope_theta <= 0.0f) {\n"
2405  " const float *p = pos + (size_t)t * aligned_D;\n"
2406  " for (int d = 0; d < D; ++d) {\n"
2407  " dst[d] += p[d];\n"
2408  " }\n"
2409  " }\n"
2410  " } else {\n"
2411  " memset(dst, 0, (size_t)aligned_D * sizeof(float));\n"
2412  " }\n"
2413  " }\n"
2414  "}\n\n"
2415  "static void embed_token_at(const TransformerModel *m, int32_t token, int t)\n"
2416  "{\n"
2417  " if (!m || !m->memory_base) return;\n"
2418  " if (t < 0 || t >= m->context_window) return;\n"
2419  " const uint8_t *base = m->memory_base;\n"
2420  " float *out = ptr_f32((uint8_t *)base, m->embedded_input_offset);\n"
2421  " const float *tok_f32 = cptr_f32(base, m->token_emb_offset);\n"
2422  " const uint8_t *tok_q = (const uint8_t *)cptr_void(base, m->token_emb_offset);\n"
2423  " const float *pos = cptr_f32(base, m->pos_emb_offset);\n"
2424  " int D = m->embed_dim;\n"
2425  " int aligned_D = (int)m->aligned_embed_dim;\n"
2426  " int id = (int)token;\n"
2427  " if (id < 0 || id >= m->vocab_size) id = 0;\n"
2428  " float *dst = out + (size_t)t * aligned_D;\n"
2429  " if (m->token_emb_dtype == CK_DT_Q4_K) {\n"
2430  " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_D);\n"
2431  " const void *row = tok_q + (size_t)id * row_bytes;\n"
2432  " dequant_q4_k_row(row, dst, (size_t)aligned_D);\n"
2433  " } else if (m->token_emb_dtype == CK_DT_Q6_K) {\n"
2434  " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_D);\n"
2435  " const void *row = tok_q + (size_t)id * row_bytes;\n"
2436  " dequant_q6_k_row(row, dst, (size_t)aligned_D);\n"
2437  " } else {\n"
2438  " const float *src = tok_f32 + (size_t)id * aligned_D;\n"
2439  " memcpy(dst, src, (size_t)D * sizeof(float));\n"
2440  " }\n"
2441  " if (aligned_D > D) {\n"
2442  " memset(dst + D, 0, (size_t)(aligned_D - D) * sizeof(float));\n"
2443  " }\n"
2444  " if (m->rope_theta <= 0.0f) {\n"
2445  " const float *p = pos + (size_t)t * aligned_D;\n"
2446  " for (int d = 0; d < D; ++d) {\n"
2447  " dst[d] += p[d];\n"
2448  " }\n"
2449  " }\n"
2450  "}\n\n"
2451  "static int write_floats(const char *path, const float *src, size_t count)\n"
2452  "{\n"
2453  " if (!path || !src) return -1;\n"
2454  " FILE *f = fopen(path, \"wb\");\n"
2455  " if (!f) {\n"
2456  " perror(\"fopen\");\n"
2457  " return -1;\n"
2458  " }\n"
2459  " size_t wrote = fwrite(src, sizeof(float), count, f);\n"
2460  " fclose(f);\n"
2461  " return wrote == count ? 0 : -1;\n"
2462  "}\n\n"
2463  "static int write_float_scalar(const char *path, float v)\n"
2464  "{\n"
2465  " if (!path) return -1;\n"
2466  " FILE *f = fopen(path, \"wb\");\n"
2467  " if (!f) {\n"
2468  " perror(\"fopen\");\n"
2469  " return -1;\n"
2470  " }\n"
2471  " size_t wrote = fwrite(&v, sizeof(float), 1, f);\n"
2472  " fclose(f);\n"
2473  " return wrote == 1 ? 0 : -1;\n"
2474  "}\n\n"
2475  "static void lm_head_forward(const float *hidden,\n"
2476  " const float *weights,\n"
2477  " float *logits,\n"
2478  " int T, int V, int D, int aligned_D)\n"
2479  "{\n"
2480  " for (int t = 0; t < T; ++t) {\n"
2481  " const float *h = hidden + (size_t)t * aligned_D;\n"
2482  " float *out = logits + (size_t)t * V;\n"
2483  " for (int v = 0; v < V; ++v) {\n"
2484  " const float *w = weights + (size_t)v * aligned_D;\n"
2485  " float sum = 0.0f;\n"
2486  " for (int d = 0; d < D; ++d) {\n"
2487  " sum += h[d] * w[d];\n"
2488  " }\n"
2489  " out[v] = sum;\n"
2490  " }\n"
2491  " }\n"
2492  "}\n\n"
2493  "static void softmax_cross_entropy(const float *logits,\n"
2494  " const int32_t *targets,\n"
2495  " int T, int V,\n"
2496  " float *d_logits,\n"
2497  " float *loss_out)\n"
2498  "{\n"
2499  " double total = 0.0;\n"
2500  " for (int t = 0; t < T; ++t) {\n"
2501  " const float *row = logits + (size_t)t * V;\n"
2502  " float *drow = d_logits + (size_t)t * V;\n"
2503  " int target = targets[t];\n"
2504  " float max_logit = row[0];\n"
2505  " for (int v = 1; v < V; ++v) {\n"
2506  " if (row[v] > max_logit) max_logit = row[v];\n"
2507  " }\n"
2508  " double sum_exp = 0.0;\n"
2509  " for (int v = 0; v < V; ++v) {\n"
2510  " drow[v] = expf(row[v] - max_logit);\n"
2511  " sum_exp += drow[v];\n"
2512  " }\n"
2513  " float inv_sum = 1.0f / (float)sum_exp;\n"
2514  " for (int v = 0; v < V; ++v) {\n"
2515  " drow[v] *= inv_sum;\n"
2516  " }\n"
2517  " double logsum = (double)max_logit + log(sum_exp);\n"
2518  " total += logsum - (double)row[target];\n"
2519  " drow[target] -= 1.0f;\n"
2520  " float scale = 1.0f / (float)T;\n"
2521  " for (int v = 0; v < V; ++v) {\n"
2522  " drow[v] *= scale;\n"
2523  " }\n"
2524  " }\n"
2525  " if (loss_out) {\n"
2526  " *loss_out = (float)(total / (double)T);\n"
2527  " }\n"
2528  "}\n\n"
2529  "static void lm_head_backward(const float *hidden,\n"
2530  " const float *weights,\n"
2531  " const float *d_logits,\n"
2532  " float *d_hidden,\n"
2533  " float *d_weights,\n"
2534  " int T, int V, int D, int aligned_D)\n"
2535  "{\n"
2536  " size_t dh_count = (size_t)T * aligned_D;\n"
2537  " size_t dw_count = (size_t)V * aligned_D;\n"
2538  " for (size_t i = 0; i < dh_count; ++i) d_hidden[i] = 0.0f;\n"
2539  " for (size_t i = 0; i < dw_count; ++i) d_weights[i] = 0.0f;\n"
2540  " for (int t = 0; t < T; ++t) {\n"
2541  " const float *dlog = d_logits + (size_t)t * V;\n"
2542  " for (int d = 0; d < D; ++d) {\n"
2543  " double sum = 0.0;\n"
2544  " for (int v = 0; v < V; ++v) {\n"
2545  " sum += (double)dlog[v] * (double)weights[(size_t)v * aligned_D + d];\n"
2546  " }\n"
2547  " d_hidden[(size_t)t * aligned_D + d] = (float)sum;\n"
2548  " }\n"
2549  " }\n"
2550  " for (int v = 0; v < V; ++v) {\n"
2551  " float *dw = d_weights + (size_t)v * aligned_D;\n"
2552  " for (int d = 0; d < D; ++d) {\n"
2553  " double sum = 0.0;\n"
2554  " for (int t = 0; t < T; ++t) {\n"
2555  " sum += (double)d_logits[(size_t)t * V + v] * (double)hidden[(size_t)t * aligned_D + d];\n"
2556  " }\n"
2557  " dw[d] = (float)sum;\n"
2558  " }\n"
2559  " }\n"
2560  "}\n\n");
2561 
2562  fprintf(out,
2563  "static void dump_layer_offsets(const TransformerModel *m, int layer)\n"
2564  "{\n"
2565  " const TrulyOptimalLayer *L = &m->layers[layer];\n"
2566  " printf(\"Layer %%d offsets (bytes):\\n\", layer);\n"
2567  " printf(\" ln1_gamma=%%zu ln2_gamma=%%zu wq=%%zu wk=%%zu wv=%%zu wo=%%zu w1=%%zu w2=%%zu\\n\",\n"
2568  " L->ln1_gamma_offset, L->ln2_gamma_offset, L->wq_offset, L->wk_offset,\n"
2569  " L->wv_offset, L->wo_offset, L->w1_offset, L->w2_offset);\n"
2570  " printf(\" ln1_out=%%zu q=%%zu k=%%zu v=%%zu scores=%%zu attn_out=%%zu\\n\",\n"
2571  " L->ln1_out_offset, L->q_offset, L->k_offset, L->v_offset,\n"
2572  " L->scores_offset, L->attn_out_offset);\n"
2573  " printf(\" proj_tmp=%%zu residual1=%%zu ln2_out=%%zu fc1_out=%%zu swiglu_out=%%zu mlp_out=%%zu output=%%zu\\n\",\n"
2574  " L->proj_tmp_offset, L->residual1_offset, L->ln2_out_offset,\n"
2575  " L->fc1_out_offset, L->swiglu_out_offset, L->mlp_out_offset, L->output_offset);\n"
2576  "}\n\n"
2577  "static void dump_layout(const TransformerModel *m, int dump_all)\n"
2578  "{\n"
2579  " size_t bytes = m->total_bytes;\n"
2580  " printf(\"Model config:\\n\");\n"
2581  " printf(\" layers=%%d embed=%%d intermediate=%%d heads=%%d kv_heads=%%d\\n\",\n"
2582  " m->num_layers, m->embed_dim, m->intermediate_size, m->num_attention_heads, m->num_kv_heads);\n"
2583  " printf(\" head_dim=%%d vocab=%%d ctx=%%d cores=%%d\\n\",\n"
2584  " m->head_dim, m->vocab_size, m->context_window, m->num_cores);\n"
2585  " printf(\" eps=%%.6g rope_theta=%%.6g\\n\", m->rms_norm_eps, m->rope_theta);\n"
2586  " printf(\"Aligned dims (elements): embed=%%zu head=%%zu ctx=%%zu\\n\",\n"
2587  " m->aligned_embed_dim, m->aligned_head_dim, m->aligned_attn_context_window);\n"
2588  " printf(\"Memory: total_bytes=%%zu\\n\", bytes);\n"
2589  " printf(\"Global offsets (bytes): token=%%zu pos=%%zu embedded=%%zu layers_start=%%zu\\n\",\n"
2590  " m->token_emb_offset, m->pos_emb_offset, m->embedded_input_offset, m->layers_start_offset);\n"
2591  " printf(\"Final offsets (bytes): final_ln_w=%%zu final_ln_b=%%zu final_ln_mean=%%zu final_ln_rstd=%%zu\\n\",\n"
2592  " m->final_ln_weight_offset, m->final_ln_bias_offset,\n"
2593  " m->final_ln_mean_offset, m->final_ln_rstd_offset);\n"
2594  " printf(\"LM/logits offsets (bytes): lm_head=%%zu logits=%%zu\\n\",\n"
2595  " m->lm_head_weight_offset, m->logits_offset);\n"
2596  " if (m->num_layers > 0) {\n"
2597  " dump_layer_offsets(m, 0);\n"
2598  " if (dump_all) {\n"
2599  " for (int i = 1; i < m->num_layers; ++i) {\n"
2600  " dump_layer_offsets(m, i);\n"
2601  " }\n"
2602  " }\n"
2603  " }\n"
2604  "}\n\n");
2605 
2606  /* Emit either main() for standalone or API for library mode */
2607  if (mode == CK_EMIT_STANDALONE) {
2608  fprintf(out,
2609  "int main(int argc, char **argv)\n"
2610  "{\n"
2611  " int dump = 0;\n"
2612  " int dump_all = 0;\n"
2613  " int no_forward = 0;\n"
2614  " int run_litmus = 0;\n"
2615  " int run_backward = 0;\n"
2616  " const char *litmus_hidden = NULL;\n"
2617  " const char *litmus_weights = NULL;\n"
2618  " const char *litmus_targets = NULL;\n"
2619  " const char *model_weights = NULL;\n"
2620  " const char *tokens_path = NULL;\n"
2621  " const char *out_logits = NULL;\n"
2622  " const char *out_dlogits = NULL;\n"
2623  " const char *out_dhidden = NULL;\n"
2624  " const char *out_dweights = NULL;\n"
2625  " const char *out_loss = NULL;\n"
2626  " const char *out_weights = NULL;\n"
2627  " int steps = 1;\n"
2628  " int log_steps = 0;\n"
2629  " int strict = 0;\n"
2630  " int32_t *tokens = NULL;\n"
2631  " int32_t *targets = NULL;\n"
2632  " TransformerModel m = {0};\n"
2633  " memcpy(m.magic, \"BUMPWGT3\", 8);\n"
2634  " m.version = 3;\n"
2635  " m.model_type = 0;\n"
2636  " m.num_layers = %d;\n"
2637  " m.embed_dim = %d;\n"
2638  " m.intermediate_size = %d;\n"
2639  " m.num_attention_heads = %d;\n"
2640  " m.num_kv_heads = %d;\n"
2641  " m.vocab_size = %d;\n"
2642  " m.context_window = %d;\n"
2643  " m.rms_norm_eps = %.9g;\n"
2644  " m.rope_theta = %.9g;\n"
2645  " m.num_cores = 1;\n"
2646  " m.task_type = TASK_LM;\n"
2647  " m.optimizer = OPTIMIZER_SGD;\n"
2648  " m.learning_rate = 0.0f;\n"
2649  " for (int i = 1; i < argc; ++i) {\n"
2650  " if (strcmp(argv[i], \"--dump\") == 0) {\n"
2651  " dump = 1;\n"
2652  " continue;\n"
2653  " }\n"
2654  " if (strcmp(argv[i], \"--dump-all\") == 0) {\n"
2655  " dump = 1;\n"
2656  " dump_all = 1;\n"
2657  " continue;\n"
2658  " }\n"
2659  " if (strcmp(argv[i], \"--no-forward\") == 0) {\n"
2660  " no_forward = 1;\n"
2661  " continue;\n"
2662  " }\n"
2663  " if (strcmp(argv[i], \"--strict\") == 0) {\n"
2664  " strict = 1;\n"
2665  " continue;\n"
2666  " }\n"
2667  " if (strcmp(argv[i], \"--litmus\") == 0) {\n"
2668  " run_litmus = 1;\n"
2669  " continue;\n"
2670  " }\n"
2671  " if (strcmp(argv[i], \"--backward\") == 0) {\n"
2672  " run_backward = 1;\n"
2673  " continue;\n"
2674  " }\n"
2675  " if (strcmp(argv[i], \"--lr\") == 0 && i + 1 < argc) {\n"
2676  " parse_float_arg(argv[++i], &m.learning_rate);\n"
2677  " continue;\n"
2678  " }\n"
2679  " if (strcmp(argv[i], \"--help\") == 0) {\n"
2680  " print_usage(argv[0]);\n"
2681  " return 0;\n"
2682  " }\n"
2683  " if (strcmp(argv[i], \"--hidden\") == 0 && i + 1 < argc) {\n"
2684  " litmus_hidden = argv[++i];\n"
2685  " continue;\n"
2686  " }\n"
2687  " if (strcmp(argv[i], \"--weights\") == 0 && i + 1 < argc) {\n"
2688  " litmus_weights = argv[++i];\n"
2689  " continue;\n"
2690  " }\n"
2691  " if (strcmp(argv[i], \"--targets\") == 0 && i + 1 < argc) {\n"
2692  " litmus_targets = argv[++i];\n"
2693  " continue;\n"
2694  " }\n"
2695  " if (strcmp(argv[i], \"--model-weights\") == 0 && i + 1 < argc) {\n"
2696  " model_weights = argv[++i];\n"
2697  " continue;\n"
2698  " }\n"
2699  " if (strcmp(argv[i], \"--tokens\") == 0 && i + 1 < argc) {\n"
2700  " tokens_path = argv[++i];\n"
2701  " continue;\n"
2702  " }\n"
2703  " if (strcmp(argv[i], \"--out-logits\") == 0 && i + 1 < argc) {\n"
2704  " out_logits = argv[++i];\n"
2705  " continue;\n"
2706  " }\n"
2707  " if (strcmp(argv[i], \"--out-dlogits\") == 0 && i + 1 < argc) {\n"
2708  " out_dlogits = argv[++i];\n"
2709  " continue;\n"
2710  " }\n"
2711  " if (strcmp(argv[i], \"--out-dhidden\") == 0 && i + 1 < argc) {\n"
2712  " out_dhidden = argv[++i];\n"
2713  " continue;\n"
2714  " }\n"
2715  " if (strcmp(argv[i], \"--out-dweights\") == 0 && i + 1 < argc) {\n"
2716  " out_dweights = argv[++i];\n"
2717  " continue;\n"
2718  " }\n"
2719  " if (strcmp(argv[i], \"--out-loss\") == 0 && i + 1 < argc) {\n"
2720  " out_loss = argv[++i];\n"
2721  " continue;\n"
2722  " }\n"
2723  " if (strcmp(argv[i], \"--out-weights\") == 0 && i + 1 < argc) {\n"
2724  " out_weights = argv[++i];\n"
2725  " continue;\n"
2726  " }\n"
2727  " if (strcmp(argv[i], \"--steps\") == 0 && i + 1 < argc) {\n"
2728  " parse_int_arg(argv[++i], &steps);\n"
2729  " continue;\n"
2730  " }\n"
2731  " if (strcmp(argv[i], \"--log-steps\") == 0) {\n"
2732  " log_steps = 1;\n"
2733  " continue;\n"
2734  " }\n"
2735  " if (strcmp(argv[i], \"--layers\") == 0 && i + 1 < argc) {\n"
2736  " parse_int_arg(argv[++i], &m.num_layers);\n"
2737  " continue;\n"
2738  " }\n"
2739  " if (strcmp(argv[i], \"--embed\") == 0 && i + 1 < argc) {\n"
2740  " parse_int_arg(argv[++i], &m.embed_dim);\n"
2741  " continue;\n"
2742  " }\n"
2743  " if (strcmp(argv[i], \"--intermediate\") == 0 && i + 1 < argc) {\n"
2744  " parse_int_arg(argv[++i], &m.intermediate_size);\n"
2745  " continue;\n"
2746  " }\n"
2747  " if (strcmp(argv[i], \"--heads\") == 0 && i + 1 < argc) {\n"
2748  " parse_int_arg(argv[++i], &m.num_attention_heads);\n"
2749  " continue;\n"
2750  " }\n"
2751  " if (strcmp(argv[i], \"--kv-heads\") == 0 && i + 1 < argc) {\n"
2752  " parse_int_arg(argv[++i], &m.num_kv_heads);\n"
2753  " continue;\n"
2754  " }\n"
2755  " if (strcmp(argv[i], \"--vocab\") == 0 && i + 1 < argc) {\n"
2756  " parse_int_arg(argv[++i], &m.vocab_size);\n"
2757  " continue;\n"
2758  " }\n"
2759  " if (strcmp(argv[i], \"--ctx\") == 0 && i + 1 < argc) {\n"
2760  " parse_int_arg(argv[++i], &m.context_window);\n"
2761  " continue;\n"
2762  " }\n"
2763  " if (strcmp(argv[i], \"--cores\") == 0 && i + 1 < argc) {\n"
2764  " parse_int_arg(argv[++i], &m.num_cores);\n"
2765  " continue;\n"
2766  " }\n"
2767  " fprintf(stderr, \"Unknown or invalid arg: %%s\\n\", argv[i]);\n"
2768  " print_usage(argv[0]);\n"
2769  " return 1;\n"
2770  " }\n"
2771  " if (strict) {\n"
2772  " ck_set_strict_parity(1);\n"
2773  " }\n"
2774  " if (run_backward && m.learning_rate == 0.0f) {\n"
2775  " m.learning_rate = 1e-3f;\n"
2776  " }\n"
2777  " m.training_enabled = run_backward;\n"
2778  " m.weight_dtype = CK_DT_FP32;\n"
2779  " {\n"
2780  " const char *wd = getenv(\"CK_WEIGHT_DTYPE\");\n"
2781  " if (wd) {\n"
2782  " if (strcmp(wd, \"q4_k\") == 0 || strcmp(wd, \"q4_k_m\") == 0 ||\n"
2783  " strcmp(wd, \"Q4_K\") == 0 || strcmp(wd, \"Q4_K_M\") == 0) {\n"
2784  " m.weight_dtype = CK_DT_Q4_K;\n"
2785  " } else if (strcmp(wd, \"q6_k\") == 0 || strcmp(wd, \"q6_k_l\") == 0 ||\n"
2786  " strcmp(wd, \"Q6_K\") == 0 || strcmp(wd, \"Q6_K_L\") == 0) {\n"
2787  " m.weight_dtype = CK_DT_Q6_K;\n"
2788  " }\n"
2789  " }\n"
2790  " }\n"
2791  " init_weight_dtypes_uniform(&m, m.weight_dtype);\n"
2792  " refresh_weight_flags(&m);\n"
2793  " if (model_weights) {\n"
2794  " int dtype_rc = load_weight_dtypes(model_weights, &m);\n"
2795  " if (dtype_rc < 0) {\n"
2796  " fprintf(stderr, \"failed to read weight dtype table\\n\");\n"
2797  " return 1;\n"
2798  " }\n"
2799  " }\n"
2800  " if (m.training_enabled && m.weights_quantized) {\n"
2801  " fprintf(stderr, \"Quantized weights are inference-only; disable training\\n\");\n"
2802  " return 1;\n"
2803  " }\n"
2804  " if (layout_model(&m) != 0) {\n"
2805  " fprintf(stderr, \"layout_model failed\\n\");\n"
2806  " return 1;\n"
2807  " }\n"
2808  " if (model_weights) {\n"
2809  " if (load_model_weights(model_weights, &m) != 0) {\n"
2810  " fprintf(stderr, \"failed to load model weights\\n\");\n"
2811  " return 1;\n"
2812  " }\n"
2813  " }\n"
2814  " if (tokens_path) {\n"
2815  " int T = m.context_window;\n"
2816  " tokens = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2817  " if (!tokens) {\n"
2818  " fprintf(stderr, \"failed to alloc tokens\\n\");\n"
2819  " return 1;\n"
2820  " }\n"
2821  " if (read_ints(tokens_path, tokens, (size_t)T) != 0) {\n"
2822  " fprintf(stderr, \"failed to read tokens\\n\");\n"
2823  " free(tokens);\n"
2824  " tokens = NULL;\n"
2825  " return 1;\n"
2826  " }\n"
2827  " if (!run_backward) {\n"
2828  " embed_tokens(&m, tokens, T);\n"
2829  " free(tokens);\n"
2830  " tokens = NULL;\n"
2831  " }\n"
2832  " }\n"
2833  " if (run_backward) {\n"
2834  " if (!litmus_targets) {\n"
2835  " fprintf(stderr, \"backward requires --targets\\n\");\n"
2836  " return 1;\n"
2837  " }\n"
2838  " int T = m.context_window;\n"
2839  " targets = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2840  " if (!targets) {\n"
2841  " fprintf(stderr, \"failed to alloc targets\\n\");\n"
2842  " return 1;\n"
2843  " }\n"
2844  " if (read_ints(litmus_targets, targets, (size_t)T) != 0) {\n"
2845  " fprintf(stderr, \"failed to read targets\\n\");\n"
2846  " free(targets);\n"
2847  " targets = NULL;\n"
2848  " return 1;\n"
2849  " }\n"
2850  " }\n"
2851  " if (dump) {\n"
2852  " dump_layout(&m, dump_all);\n"
2853  " }\n"
2854  " if (run_litmus) {\n"
2855  " if (!litmus_hidden || !litmus_weights || !litmus_targets) {\n"
2856  " fprintf(stderr, \"litmus requires --hidden, --weights, and --targets\\n\");\n"
2857  " return 1;\n"
2858  " }\n"
2859  " int T = m.context_window;\n"
2860  " int V = m.vocab_size;\n"
2861  " int D = m.embed_dim;\n"
2862  " int aligned_D = (int)m.aligned_embed_dim;\n"
2863  " float *hidden = ptr_f32(m.memory_base, m.final_output_offset);\n"
2864  " float *weights = ptr_f32(m.memory_base, m.lm_head_weight_offset);\n"
2865  " float *logits = ptr_f32(m.memory_base, m.logits_offset);\n"
2866  " if (read_floats(litmus_hidden, hidden, (size_t)T * aligned_D) != 0) {\n"
2867  " fprintf(stderr, \"failed to read hidden\\n\");\n"
2868  " return 1;\n"
2869  " }\n"
2870  " if (read_floats(litmus_weights, weights, (size_t)V * aligned_D) != 0) {\n"
2871  " fprintf(stderr, \"failed to read weights\\n\");\n"
2872  " return 1;\n"
2873  " }\n"
2874  " int32_t *targets = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2875  " if (!targets) {\n"
2876  " fprintf(stderr, \"failed to alloc targets\\n\");\n"
2877  " return 1;\n"
2878  " }\n"
2879  " if (read_ints(litmus_targets, targets, (size_t)T) != 0) {\n"
2880  " fprintf(stderr, \"failed to read targets\\n\");\n"
2881  " free(targets);\n"
2882  " return 1;\n"
2883  " }\n"
2884  " float *d_logits = (float *)calloc((size_t)T * V, sizeof(float));\n"
2885  " float *d_hidden = (float *)calloc((size_t)T * aligned_D, sizeof(float));\n"
2886  " float *d_weights = (float *)calloc((size_t)V * aligned_D, sizeof(float));\n"
2887  " if (!d_logits || !d_hidden || !d_weights) {\n"
2888  " fprintf(stderr, \"failed to alloc grads\\n\");\n"
2889  " free(targets);\n"
2890  " free(d_logits);\n"
2891  " free(d_hidden);\n"
2892  " free(d_weights);\n"
2893  " return 1;\n"
2894  " }\n"
2895  " lm_head_forward(hidden, weights, logits, T, V, D, aligned_D);\n"
2896  " float loss = 0.0f;\n"
2897  " softmax_cross_entropy(logits, targets, T, V, d_logits, &loss);\n"
2898  " lm_head_backward(hidden, weights, d_logits, d_hidden, d_weights, T, V, D, aligned_D);\n"
2899  " if (out_logits) write_floats(out_logits, logits, (size_t)T * V);\n"
2900  " if (out_dlogits) write_floats(out_dlogits, d_logits, (size_t)T * V);\n"
2901  " if (out_dhidden) write_floats(out_dhidden, d_hidden, (size_t)T * aligned_D);\n"
2902  " if (out_dweights) write_floats(out_dweights, d_weights, (size_t)V * aligned_D);\n"
2903  " if (out_loss) write_float_scalar(out_loss, loss);\n"
2904  " if (!out_loss) printf(\"loss=%%.6f\\n\", loss);\n"
2905  " free(targets);\n"
2906  " free(d_logits);\n"
2907  " free(d_hidden);\n"
2908  " free(d_weights);\n"
2909  " ck_huge_free(m.memory_base, m.total_bytes);\n"
2910  " free(m.layers);\n"
2911  " return 0;\n"
2912  " }\n"
2913  " // TODO: load weights into m.memory_base using the offsets above.\n"
2914  " // TODO: write token/pos embeddings into embedded_input_offset.\n"
2915  " if (!run_backward) {\n"
2916  " if (!no_forward) {\n"
2917  " run_model_forward(&m);\n"
2918  " }\n"
2919  " } else {\n"
2920  " if (!tokens || !targets) {\n"
2921  " fprintf(stderr, \"backward requires --tokens and --targets\\n\");\n"
2922  " return 1;\n"
2923  " }\n"
2924  " if (steps < 1) steps = 1;\n"
2925  " float loss = 0.0f;\n"
2926  " for (int step = 0; step < steps; ++step) {\n"
2927  " embed_tokens(&m, tokens, m.context_window);\n"
2928  " run_model_forward(&m);\n"
2929  " if (run_model_backward(&m, tokens, targets, &loss) != 0) {\n"
2930  " fprintf(stderr, \"backward failed\\n\");\n"
2931  " return 1;\n"
2932  " }\n"
2933  " if (log_steps) {\n"
2934  " printf(\"step %%d loss=%%.6f\\n\", step, loss);\n"
2935  " }\n"
2936  " }\n"
2937  " if (out_loss) {\n"
2938  " write_float_scalar(out_loss, loss);\n"
2939  " }\n"
2940  " }\n"
2941  " if (out_logits) {\n"
2942  " write_floats(out_logits, ptr_f32(m.memory_base, m.logits_offset),\n"
2943  " (size_t)m.context_window * (size_t)m.vocab_size);\n"
2944  " }\n"
2945  " if (out_weights) {\n"
2946  " if (save_model_weights(out_weights, &m) != 0) {\n"
2947  " fprintf(stderr, \"failed to save model weights\\n\");\n"
2948  " return 1;\n"
2949  " }\n"
2950  " }\n"
2951  " ck_huge_free(m.memory_base, m.total_bytes);\n"
2952  " free(m.layers);\n"
2953  " free(tokens);\n"
2954  " free(targets);\n"
2955  " return 0;\n"
2956  "}\n",
2957  forward->config.num_layers,
2958  forward->config.hidden_size,
2959  forward->config.intermediate_size,
2960  forward->config.num_heads,
2961  forward->config.num_kv_heads,
2962  forward->config.vocab_size,
2963  forward->config.context_window,
2964  forward->config.rms_norm_eps,
2965  forward->config.rope_theta);
2966  } else {
2967  /* Library mode - emit API functions instead of main() */
2968  emit_library_api(out, forward);
2969  }
2970 
2971  fclose(out);
2972  if (emit_kernel_manifest(forward, path) != 0) {
2973  return -1;
2974  }
2975  return 0;
2976 }
@ CK_EMIT_STANDALONE
static int emit_runtime_preamble(FILE *out)
static int emit_kernel_manifest(const CKIRGraph *forward, const char *runtime_path)
static const char * ck_first_layer_buffer_name(void)
static void emit_library_api(FILE *out, const CKIRGraph *forward)
static void emit_layer_allocations(FILE *out)
static void emit_sgd_update(FILE *out)
static void emit_model_struct(FILE *out)
static void emit_global_allocations(FILE *out)
static void emit_zero_grad(FILE *out)
static void emit_layer_offsets_struct(FILE *out)
static void emit_global_aliases_to_layer(FILE *out)
int ck_ir_validate_supported(const CKIRGraph *graph)

References CK_EMIT_STANDALONE, ck_first_layer_buffer_name(), ck_ir_validate_supported(), CKIRGraph::config, CKModelConfig::context_window, emit_global_aliases_to_layer(), emit_global_allocations(), emit_kernel_manifest(), emit_layer_allocations(), emit_layer_offsets_struct(), emit_library_api(), emit_model_struct(), emit_runtime_preamble(), emit_sgd_update(), emit_zero_grad(), CKModelConfig::hidden_size, CKModelConfig::intermediate_size, CKModelConfig::num_heads, CKModelConfig::num_kv_heads, CKModelConfig::num_layers, CKModelConfig::rms_norm_eps, CKModelConfig::rope_theta, and CKModelConfig::vocab_size.

◆ ck_find_buffer_spec()

static const CKBufferSpec* ck_find_buffer_spec ( const char *  name)
static

Definition at line 31 of file ckernel_codegen_v6.5.c.

32 {
33  if (!name) {
34  return NULL;
35  }
36  for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
37  if (strcmp(ck_decoder_buffers[i].name, name) == 0) {
38  return &ck_decoder_buffers[i];
39  }
40  }
41  return NULL;
42 }
const CKBufferSpec ck_decoder_buffers[]
const size_t ck_decoder_buffer_count

References ck_decoder_buffer_count, and ck_decoder_buffers.

Referenced by emit_global_aliases_to_layer(), emit_global_allocations(), and emit_sgd_update().

◆ ck_find_kernel_spec()

static const CKKernelSpec* ck_find_kernel_spec ( const char *  name)
static

Definition at line 44 of file ckernel_codegen_v6.5.c.

45 {
46  if (!name) {
47  return NULL;
48  }
49  for (size_t i = 0; i < ck_kernel_spec_count; ++i) {
50  if (strcmp(ck_kernel_specs[i].name, name) == 0) {
51  return &ck_kernel_specs[i];
52  }
53  }
54  return NULL;
55 }
const CKKernelSpec ck_kernel_specs[]
const size_t ck_kernel_spec_count

References ck_kernel_spec_count, and ck_kernel_specs.

Referenced by emit_plan_sources().

◆ ck_first_layer_buffer_name()

static const char* ck_first_layer_buffer_name ( void  )
static

Definition at line 598 of file ckernel_codegen_v6.5.c.

599 {
600  for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
601  const CKBufferSpec *spec = &ck_decoder_buffers[i];
602  if (spec->scope != CK_SCOPE_LAYER) {
603  continue;
604  }
605  if (!ck_buffer_should_alloc(spec)) {
606  continue;
607  }
608  return spec->name;
609  }
610  return "ln1_gamma";
611 }
static int ck_buffer_should_alloc(const CKBufferSpec *spec)
@ CK_SCOPE_LAYER
CKBufferScope scope

References ck_buffer_should_alloc(), ck_decoder_buffer_count, ck_decoder_buffers, CK_SCOPE_LAYER, CKBufferSpec::name, and CKBufferSpec::scope.

Referenced by ck_codegen_emit_runtime().

◆ ck_plan_step_enabled()

static int ck_plan_step_enabled ( const CKPlanStep step,
const CKIRGraph cfg 
)
static

Definition at line 554 of file ckernel_codegen_v6.5.c.

555 {
556  if (!step || !step->condition || !cfg) {
557  return 1;
558  }
559  if (strcmp(step->condition, "rope_theta") == 0) {
560  return cfg->config.rope_theta > 0.0f;
561  }
562  if (strcmp(step->condition, "rope_theta>0") == 0) {
563  return cfg->config.rope_theta > 0.0f;
564  }
565  return 1;
566 }
const char * condition

References CKPlanStep::condition, CKIRGraph::config, and CKModelConfig::rope_theta.

Referenced by emit_plan_sources().

◆ ck_weight_dtype_expr()

static const char* ck_weight_dtype_expr ( const CKBufferSpec spec)
static

Definition at line 80 of file ckernel_codegen_v6.5.c.

81 {
82  if (!spec || spec->role != CK_ROLE_WEIGHT || !spec->name) {
83  return NULL;
84  }
85  if (spec->scope == CK_SCOPE_GLOBAL) {
86  if (strcmp(spec->name, "token_emb") == 0) {
87  return "m->token_emb_dtype";
88  }
89  if (strcmp(spec->name, "lm_head_weight") == 0) {
90  return "m->lm_head_weight_dtype";
91  }
92  return NULL;
93  }
94  if (spec->scope == CK_SCOPE_LAYER) {
95  if (strcmp(spec->name, "wq") == 0) return "L->wq_dtype";
96  if (strcmp(spec->name, "wk") == 0) return "L->wk_dtype";
97  if (strcmp(spec->name, "wv") == 0) return "L->wv_dtype";
98  if (strcmp(spec->name, "wo") == 0) return "L->wo_dtype";
99  if (strcmp(spec->name, "w1") == 0) return "L->w1_dtype";
100  if (strcmp(spec->name, "w2") == 0) return "L->w2_dtype";
101  return NULL;
102  }
103  return NULL;
104 }
@ CK_SCOPE_GLOBAL

References CK_ROLE_WEIGHT, CK_SCOPE_GLOBAL, CK_SCOPE_LAYER, CKBufferSpec::name, CKBufferSpec::role, and CKBufferSpec::scope.

Referenced by emit_global_allocations(), and emit_layer_allocations().

◆ emit_bump_bytes_assignment()

static void emit_bump_bytes_assignment ( FILE *  out,
const char *  indent,
const char *  struct_prefix,
const char *  name,
const CKDimToken shape 
)
static

Definition at line 275 of file ckernel_codegen_v6.5.c.

280 {
281  fprintf(out, "%s%s%s_offset = bump_bytes(&off, (", indent, struct_prefix, name);
282  emit_shape_expr(out, shape);
283  fprintf(out, ") * elem_bytes, CACHELINE_BYTES);\n");
284 }
static void emit_shape_expr(FILE *out, const CKDimToken *shape)

References emit_shape_expr().

Referenced by emit_global_allocations(), and emit_layer_allocations().

◆ emit_bump_bytes_assignment_weight_dtype()

static void emit_bump_bytes_assignment_weight_dtype ( FILE *  out,
const char *  indent,
const char *  struct_prefix,
const char *  name,
const CKDimToken shape,
const char *  dtype_expr 
)
static

Definition at line 286 of file ckernel_codegen_v6.5.c.

292 {
293  fprintf(out, "%s%s%s_offset = bump_bytes(&off, ck_dtype_row_bytes(%s, (",
294  indent, struct_prefix, name, dtype_expr);
295  emit_shape_expr(out, shape);
296  fprintf(out, ")), CACHELINE_BYTES);\n");
297 }

References emit_shape_expr().

Referenced by emit_global_allocations(), and emit_layer_allocations().

◆ emit_dim_expr()

static void emit_dim_expr ( FILE *  out,
CKDimKind  dim 
)
static

Definition at line 231 of file ckernel_codegen_v6.5.c.

232 {
233  switch (dim) {
234  case CK_DIM_TOKENS: fprintf(out, "(size_t)m->context_window"); break;
235  case CK_DIM_EMBED: fprintf(out, "(size_t)m->embed_dim"); break;
236  case CK_DIM_ALIGNED_EMBED: fprintf(out, "m->aligned_embed_dim"); break;
237  case CK_DIM_HEAD_DIM: fprintf(out, "(size_t)m->head_dim"); break;
238  case CK_DIM_ALIGNED_HEAD: fprintf(out, "m->aligned_head_dim"); break;
239  case CK_DIM_NUM_HEADS: fprintf(out, "(size_t)m->num_attention_heads"); break;
240  case CK_DIM_NUM_KV_HEADS: fprintf(out, "(size_t)m->num_kv_heads"); break;
241  case CK_DIM_ALIGNED_CTX: fprintf(out, "m->aligned_attn_context_window"); break;
242  case CK_DIM_INTERMEDIATE: fprintf(out, "(size_t)m->intermediate_size"); break;
243  case CK_DIM_ALIGNED_INTERMEDIATE:fprintf(out, "aligned_intermediate_dim"); break;
244  case CK_DIM_VOCAB: fprintf(out, "(size_t)m->vocab_size"); break;
245  case CK_DIM_END: fprintf(out, "0"); break;
246  }
247 }
@ CK_DIM_ALIGNED_INTERMEDIATE
@ CK_DIM_NUM_HEADS
@ CK_DIM_ALIGNED_EMBED
@ CK_DIM_TOKENS
@ CK_DIM_INTERMEDIATE
@ CK_DIM_ALIGNED_CTX
@ CK_DIM_END
@ CK_DIM_ALIGNED_HEAD
@ CK_DIM_HEAD_DIM
@ CK_DIM_NUM_KV_HEADS
@ CK_DIM_VOCAB
@ CK_DIM_EMBED

References CK_DIM_ALIGNED_CTX, CK_DIM_ALIGNED_EMBED, CK_DIM_ALIGNED_HEAD, CK_DIM_ALIGNED_INTERMEDIATE, CK_DIM_EMBED, CK_DIM_END, CK_DIM_HEAD_DIM, CK_DIM_INTERMEDIATE, CK_DIM_NUM_HEADS, CK_DIM_NUM_KV_HEADS, CK_DIM_TOKENS, and CK_DIM_VOCAB.

Referenced by emit_shape_expr().

◆ emit_global_aliases_to_layer()

static void emit_global_aliases_to_layer ( FILE *  out)
static

Definition at line 398 of file ckernel_codegen_v6.5.c.

399 {
400  for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
401  const CKBufferSpec *spec = &ck_decoder_buffers[i];
402  if (spec->scope != CK_SCOPE_GLOBAL || !spec->alias_of) {
403  continue;
404  }
405  const CKBufferSpec *alias = ck_find_buffer_spec(spec->alias_of);
406  if (!alias || alias->scope != CK_SCOPE_LAYER) {
407  continue;
408  }
409  fprintf(out,
410  " if (m->num_layers > 0) {\n"
411  " m->%s_offset = m->layers[m->num_layers - 1].%s_offset;\n"
412  " } else {\n"
413  " m->%s_offset = 0;\n"
414  " }\n",
415  spec->name, spec->alias_of, spec->name);
416  }
417 }
static const CKBufferSpec * ck_find_buffer_spec(const char *name)
const char * alias_of

References CKBufferSpec::alias_of, ck_decoder_buffer_count, ck_decoder_buffers, ck_find_buffer_spec(), CK_SCOPE_GLOBAL, CK_SCOPE_LAYER, CKBufferSpec::name, and CKBufferSpec::scope.

Referenced by ck_codegen_emit_runtime().

◆ emit_global_allocations()

static void emit_global_allocations ( FILE *  out)
static

Definition at line 311 of file ckernel_codegen_v6.5.c.

312 {
313  for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
314  const CKBufferSpec *spec = &ck_decoder_buffers[i];
315  if (spec->scope != CK_SCOPE_GLOBAL) {
316  continue;
317  }
318  if (!ck_buffer_should_alloc(spec)) {
319  continue;
320  }
321  if (spec->role == CK_ROLE_GRAD) {
322  emit_training_conditional_assignment(out, " ", "m->", spec->name, spec->shape);
323  continue;
324  }
325  if (spec->alias_of) {
326  const CKBufferSpec *alias = ck_find_buffer_spec(spec->alias_of);
327  if (alias && alias->scope == CK_SCOPE_GLOBAL) {
328  fprintf(out, " m->%s_offset = m->%s_offset;\n", spec->name, spec->alias_of);
329  }
330  continue;
331  }
332  if (spec->condition && strcmp(spec->condition, "rope_theta") == 0) {
333  fprintf(out, " if (m->rope_theta > 0.0f) {\n");
334  fprintf(out, " m->%s_offset = bump_bytes(&off, (", spec->name);
335  emit_shape_expr(out, spec->shape);
336  fprintf(out, ") * elem_bytes, CACHELINE_BYTES);\n");
337  fprintf(out, " } else {\n");
338  fprintf(out, " m->%s_offset = 0;\n", spec->name);
339  fprintf(out, " }\n");
340  continue;
341  }
342  if (spec->condition && strcmp(spec->condition, "training_enabled") == 0) {
343  fprintf(out, " if (m->training_enabled) {\n");
344  fprintf(out, " m->%s_offset = bump_bytes(&off, (", spec->name);
345  emit_shape_expr(out, spec->shape);
346  fprintf(out, ") * elem_bytes, CACHELINE_BYTES);\n");
347  fprintf(out, " } else {\n");
348  fprintf(out, " m->%s_offset = 0;\n", spec->name);
349  fprintf(out, " }\n");
350  continue;
351  }
352  if (ck_buffer_uses_weight_dtype(spec)) {
353  const char *dtype_expr = ck_weight_dtype_expr(spec);
354  if (dtype_expr) {
355  emit_bump_bytes_assignment_weight_dtype(out, " ", "m->", spec->name, spec->shape, dtype_expr);
356  continue;
357  }
358  }
359  emit_bump_bytes_assignment(out, " ", "m->", spec->name, spec->shape);
360  }
361 }
static int ck_buffer_uses_weight_dtype(const CKBufferSpec *spec)
static const char * ck_weight_dtype_expr(const CKBufferSpec *spec)
static void emit_bump_bytes_assignment(FILE *out, const char *indent, const char *struct_prefix, const char *name, const CKDimToken *shape)
static void emit_training_conditional_assignment(FILE *out, const char *indent, const char *struct_prefix, const char *name, const CKDimToken *shape)
static void emit_bump_bytes_assignment_weight_dtype(FILE *out, const char *indent, const char *struct_prefix, const char *name, const CKDimToken *shape, const char *dtype_expr)
@ CK_ROLE_GRAD
const char * condition
CKDimToken shape[4]

References CKBufferSpec::alias_of, ck_buffer_should_alloc(), ck_buffer_uses_weight_dtype(), ck_decoder_buffer_count, ck_decoder_buffers, ck_find_buffer_spec(), CK_ROLE_GRAD, CK_SCOPE_GLOBAL, ck_weight_dtype_expr(), CKBufferSpec::condition, emit_bump_bytes_assignment(), emit_bump_bytes_assignment_weight_dtype(), emit_shape_expr(), emit_training_conditional_assignment(), CKBufferSpec::name, CKBufferSpec::role, CKBufferSpec::scope, and CKBufferSpec::shape.

Referenced by ck_codegen_emit_runtime().

◆ emit_global_offset_fields()

static void emit_global_offset_fields ( FILE *  out)
static

Definition at line 134 of file ckernel_codegen_v6.5.c.

135 {
136  for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
137  const CKBufferSpec *spec = &ck_decoder_buffers[i];
138  if (spec->scope != CK_SCOPE_GLOBAL) {
139  continue;
140  }
141  if (!ck_buffer_should_alloc(spec)) {
142  continue;
143  }
144  emit_offset_field(out, spec->name);
145  }
146 }
static void emit_offset_field(FILE *out, const char *name)

References ck_buffer_should_alloc(), ck_decoder_buffer_count, ck_decoder_buffers, CK_SCOPE_GLOBAL, emit_offset_field(), CKBufferSpec::name, and CKBufferSpec::scope.

Referenced by emit_model_struct().

◆ emit_kernel_manifest()

static int emit_kernel_manifest ( const CKIRGraph forward,
const char *  runtime_path 
)
static

Definition at line 820 of file ckernel_codegen_v6.5.c.

821 {
822  if (!forward || !runtime_path) {
823  return -1;
824  }
825 
826  const char *suffix = ".kernels";
827  size_t len = strlen(runtime_path) + strlen(suffix) + 1;
828  char *path = (char *)malloc(len);
829  if (!path) {
830  return -1;
831  }
832  snprintf(path, len, "%s%s", runtime_path, suffix);
833 
834  FILE *f = fopen(path, "wb");
835  if (!f) {
836  fprintf(stderr, "ck_codegen_emit_runtime: failed to open %s: %s\n",
837  path, strerror(errno));
838  free(path);
839  return -1;
840  }
841 
842  size_t seen_cap = ck_kernel_spec_count * CKERNEL_MAX_KERNEL_SOURCES + 8;
843  const char **seen = (const char **)calloc(seen_cap, sizeof(char *));
844  if (!seen) {
845  fclose(f);
846  free(path);
847  return -1;
848  }
849  size_t seen_count = 0;
850 
851  emit_unique_source(f, "src/ckernel_alloc.c", seen, &seen_count, seen_cap);
852  emit_unique_source(f, "src/ckernel_strict.c", seen, &seen_count, seen_cap);
853  emit_unique_source(f, "src/cpu_features.c", seen, &seen_count, seen_cap);
854  emit_unique_source(f, "src/kernels/embedding_kernels.c", seen, &seen_count, seen_cap);
855  /* Quantized inference support (K-quants: Q4_K_M / Q4_K / Q6_K). */
856  emit_unique_source(f, "src/kernels/dequant_kernels.c", seen, &seen_count, seen_cap);
857  emit_unique_source(f, "src/kernels/gemm_kernels_q4k.c", seen, &seen_count, seen_cap);
858  emit_unique_source(f, "src/kernels/gemm_kernels_q4k_q8k.c", seen, &seen_count, seen_cap);
859  emit_unique_source(f, "src/kernels/gemm_kernels_q6k.c", seen, &seen_count, seen_cap);
860  /* Legacy quant support (Q4_0 / Q4_1 / Q5_0 / Q5_1 / Q8_0). */
861  emit_unique_source(f, "src/kernels/gemm_kernels_q4_0.c", seen, &seen_count, seen_cap);
862  emit_unique_source(f, "src/kernels/gemm_kernels_q4_1.c", seen, &seen_count, seen_cap);
863  emit_unique_source(f, "src/kernels/gemm_kernels_q5_0.c", seen, &seen_count, seen_cap);
864  emit_unique_source(f, "src/kernels/gemm_kernels_q5_1.c", seen, &seen_count, seen_cap);
865  emit_unique_source(f, "src/kernels/gemm_kernels_q8_0.c", seen, &seen_count, seen_cap);
866  /* SSE/AVX2/VNNI fallback implementations for quantized kernels. */
867  emit_unique_source(f, "src/kernels/gemm_kernels_q4k_sse.c", seen, &seen_count, seen_cap);
868  emit_unique_source(f, "src/kernels/gemm_kernels_q4k_q8k_avx2.c", seen, &seen_count, seen_cap);
869  emit_unique_source(f, "src/kernels/gemm_kernels_q4k_q8k_vnni.c", seen, &seen_count, seen_cap);
870  emit_unique_source(f, "src/kernels/gemm_kernels_q5_0_sse.c", seen, &seen_count, seen_cap);
871  emit_unique_source(f, "src/kernels/gemm_kernels_q5_0_sse_v2.c", seen, &seen_count, seen_cap);
872  emit_unique_source(f, "src/kernels/gemm_kernels_q6k_sse.c", seen, &seen_count, seen_cap);
873  emit_unique_source(f, "src/kernels/quantize_row_q8_k_sse.c", seen, &seen_count, seen_cap);
874  emit_unique_source(f, "src/kernels/rope_kernels.c", seen, &seen_count, seen_cap);
875  emit_unique_source(f, "src/kernels/loss_kernels.c", seen, &seen_count, seen_cap);
876  emit_unique_source(f, "src/kernels/kv_cache_kernels.c", seen, &seen_count, seen_cap);
877  /* Fused kernels used by orchestration layer. */
878  emit_unique_source(f, "src/kernels/gemm_fused_kernels.c", seen, &seen_count, seen_cap);
879  emit_unique_source(f, "src/kernels/mlp_fused_decode.c", seen, &seen_count, seen_cap);
880  emit_unique_source(f, "src/kernels/gemm_microkernel.c", seen, &seen_count, seen_cap);
881  emit_unique_source(f, "src/kernels/attention_decode_fused.c", seen, &seen_count, seen_cap);
882  emit_unique_source(f, "src/kernels/attention_flash_true.c", seen, &seen_count, seen_cap);
883  if (emit_plan_sources(f,
886  forward,
887  seen,
888  &seen_count,
889  seen_cap) != 0) {
890  free(seen);
891  fclose(f);
892  free(path);
893  return -1;
894  }
895  if (emit_plan_sources(f,
898  forward,
899  seen,
900  &seen_count,
901  seen_cap) != 0) {
902  free(seen);
903  fclose(f);
904  free(path);
905  return -1;
906  }
907  free(seen);
908 
909  fclose(f);
910  fprintf(stderr, "[ck_codegen] kernels manifest written to %s\n", path);
911  free(path);
912  return 0;
913 }
static int emit_plan_sources(FILE *f, const CKPlanStep *plan, size_t plan_count, const CKIRGraph *cfg, const char **seen, size_t *seen_count, size_t seen_cap)
static int emit_unique_source(FILE *f, const char *path, const char **seen, size_t *seen_count, size_t seen_cap)
#define CKERNEL_MAX_KERNEL_SOURCES
const CKPlanStep ck_decoder_forward_plan[]
const size_t ck_decoder_backward_plan_count
const size_t ck_decoder_forward_plan_count
const CKPlanStep ck_decoder_backward_plan[]

References ck_decoder_backward_plan, ck_decoder_backward_plan_count, ck_decoder_forward_plan, ck_decoder_forward_plan_count, ck_kernel_spec_count, CKERNEL_MAX_KERNEL_SOURCES, emit_plan_sources(), and emit_unique_source().

Referenced by ck_codegen_emit_runtime().

◆ emit_layer_allocations()

static void emit_layer_allocations ( FILE *  out)
static

Definition at line 363 of file ckernel_codegen_v6.5.c.

364 {
365  for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
366  const CKBufferSpec *spec = &ck_decoder_buffers[i];
367  if (spec->scope != CK_SCOPE_LAYER) {
368  continue;
369  }
370  if (!ck_buffer_should_alloc(spec)) {
371  continue;
372  }
373  if (spec->role == CK_ROLE_GRAD) {
374  emit_training_conditional_assignment(out, " ", "L->", spec->name, spec->shape);
375  continue;
376  }
377  if (spec->condition && strcmp(spec->condition, "training_enabled") == 0) {
378  fprintf(out, " if (m->training_enabled) {\n");
379  fprintf(out, " L->%s_offset = bump_bytes(&off, (", spec->name);
380  emit_shape_expr(out, spec->shape);
381  fprintf(out, ") * elem_bytes, CACHELINE_BYTES);\n");
382  fprintf(out, " } else {\n");
383  fprintf(out, " L->%s_offset = 0;\n", spec->name);
384  fprintf(out, " }\n");
385  continue;
386  }
387  if (ck_buffer_uses_weight_dtype(spec)) {
388  const char *dtype_expr = ck_weight_dtype_expr(spec);
389  if (dtype_expr) {
390  emit_bump_bytes_assignment_weight_dtype(out, " ", "L->", spec->name, spec->shape, dtype_expr);
391  continue;
392  }
393  }
394  emit_bump_bytes_assignment(out, " ", "L->", spec->name, spec->shape);
395  }
396 }

References ck_buffer_should_alloc(), ck_buffer_uses_weight_dtype(), ck_decoder_buffer_count, ck_decoder_buffers, CK_ROLE_GRAD, CK_SCOPE_LAYER, ck_weight_dtype_expr(), CKBufferSpec::condition, emit_bump_bytes_assignment(), emit_bump_bytes_assignment_weight_dtype(), emit_shape_expr(), emit_training_conditional_assignment(), CKBufferSpec::name, CKBufferSpec::role, CKBufferSpec::scope, and CKBufferSpec::shape.

Referenced by ck_codegen_emit_runtime().

◆ emit_layer_offsets_struct()

static void emit_layer_offsets_struct ( FILE *  out)
static

Definition at line 111 of file ckernel_codegen_v6.5.c.

112 {
113  fprintf(out, "typedef struct {\n");
114  for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
115  const CKBufferSpec *spec = &ck_decoder_buffers[i];
116  if (spec->scope != CK_SCOPE_LAYER) {
117  continue;
118  }
119  if (!ck_buffer_should_alloc(spec)) {
120  continue;
121  }
122  emit_offset_field(out, spec->name);
123  }
124  fprintf(out,
125  " CKDataType wq_dtype;\n"
126  " CKDataType wk_dtype;\n"
127  " CKDataType wv_dtype;\n"
128  " CKDataType wo_dtype;\n"
129  " CKDataType w1_dtype;\n"
130  " CKDataType w2_dtype;\n");
131  fprintf(out, "} LayerOffsets;\n\n");
132 }

References ck_buffer_should_alloc(), ck_decoder_buffer_count, ck_decoder_buffers, CK_SCOPE_LAYER, emit_offset_field(), CKBufferSpec::name, and CKBufferSpec::scope.

Referenced by ck_codegen_emit_runtime().

◆ emit_library_api()

static void emit_library_api ( FILE *  out,
const CKIRGraph forward 
)
static

Definition at line 916 of file ckernel_codegen_v6.5.c.

/*
 * Emit the dlopen-able library API into the generated runtime: export macro,
 * CKModelInfo, module-global model state, the KV-cache decode fast path
 * (run_model_decode), and the CK_EXPORT entry points (init / info / embed /
 * forward / decode / backward / free / training controls). Model hyper-
 * parameters are baked in from forward->config at codegen time.
 */
static void emit_library_api(FILE *out, const CKIRGraph *forward)
{
    /* Export macro, public info struct, global model state, and the
     * env-var-driven SwiGLU fusion mode helper for the decode fast path. */
    fprintf(out,
        "\n/* ═══════════════════════════════════════════════════════════════\n"
        " * C-Kernel-Engine Library API (for dlopen)\n"
        " * ═══════════════════════════════════════════════════════════════ */\n\n"
        "#ifdef _WIN32\n"
        "#define CK_EXPORT __declspec(dllexport)\n"
        "#else\n"
        "#define CK_EXPORT __attribute__((visibility(\"default\")))\n"
        "#endif\n\n"
        "typedef struct {\n"
        " int num_layers;\n"
        " int hidden_size;\n"
        " int intermediate_size;\n"
        " int num_heads;\n"
        " int num_kv_heads;\n"
        " int vocab_size;\n"
        " int context_window;\n"
        " float rms_norm_eps;\n"
        " float rope_theta;\n"
        "} CKModelInfo;\n\n"
        "static TransformerModel g_model = {0};\n"
        "static int g_initialized = 0;\n"
        "static int g_fuse_swiglu_decode = -2;\n"
        "static int g_fuse_attn_decode = -2;\n\n"
        "static int ck_fuse_swiglu_decode_mode(void)\n"
        "{\n"
        " if (g_fuse_swiglu_decode != -2) return g_fuse_swiglu_decode;\n"
        " const char *env = getenv(\"CK_FUSE_SWIGLU_DECODE\");\n"
        " if (!env || !env[0]) {\n"
        " g_fuse_swiglu_decode = -1; /* auto */\n"
        " return g_fuse_swiglu_decode;\n"
        " }\n"
        " if (env[0] == '0' || env[0] == 'n' || env[0] == 'N' || env[0] == 'f' || env[0] == 'F') {\n"
        " g_fuse_swiglu_decode = 0;\n"
        " } else {\n"
        " g_fuse_swiglu_decode = 1;\n"
        " }\n"
        " return g_fuse_swiglu_decode;\n"
        "}\n\n");

    /* Same -2 / -1 / 0 / 1 (unset / auto / off / on) scheme for fused attention. */
    fprintf(out,
        "static int ck_fuse_attn_decode_mode(void)\n"
        "{\n"
        " if (g_fuse_attn_decode != -2) return g_fuse_attn_decode;\n"
        " const char *env = getenv(\"CK_FUSE_ATTN_DECODE\");\n"
        " if (!env || !env[0]) {\n"
        " g_fuse_attn_decode = -1; /* auto */\n"
        " return g_fuse_attn_decode;\n"
        " }\n"
        " if (env[0] == '0' || env[0] == 'n' || env[0] == 'N' || env[0] == 'f' || env[0] == 'F') {\n"
        " g_fuse_attn_decode = 0;\n"
        " } else {\n"
        " g_fuse_attn_decode = 1;\n"
        " }\n"
        " return g_fuse_attn_decode;\n"
        "}\n\n");

    /* Single-token KV-cache decode: per-layer dispatch between uniform-Q4_K,
     * mixed-quantized, and fp32 (optionally fused) layer kernels, then the
     * final RMSNorm + LM head on the one new token row. */
    fprintf(out,
        "static int run_model_decode(TransformerModel *m, int32_t token)\n"
        "{\n"
        " if (!m || !m->memory_base) return -1;\n"
        " /* KV-cache decode is an inference-only fast path; training uses the full forward/backward graph. */\n"
        " if (m->training_enabled) return -4;\n"
        " if (!m->kv_cache_enabled) return -2;\n"
        "\n"
        " int cache_cap = m->kv_cache_capacity > 0 ? m->kv_cache_capacity : m->context_window;\n"
        " if (cache_cap > m->context_window) cache_cap = m->context_window;\n"
        " int t = m->kv_cache_tokens;\n"
        " if (t < 0) t = 0;\n"
        " if (t >= cache_cap) return -3;\n"
        "\n"
        " embed_token_at(m, token, t);\n"
        "\n"
        " uint8_t *base = m->memory_base;\n"
        " float *current = ptr_f32(base, m->embedded_input_offset);\n"
        " int aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
        " int fuse_swiglu_mode = ck_fuse_swiglu_decode_mode();\n"
        " int use_fused_swiglu = 0;\n"
        " if (fuse_swiglu_mode > 0) {\n"
        " use_fused_swiglu = 1;\n"
        " } else if (fuse_swiglu_mode == 0) {\n"
        " use_fused_swiglu = 0;\n"
        " } else if (!ck_strict_parity_enabled()) {\n"
        " use_fused_swiglu = 1;\n"
        " }\n"
        " int fuse_attn_mode = ck_fuse_attn_decode_mode();\n"
        " int use_fused_attn = 0;\n"
        " if (fuse_attn_mode > 0) {\n"
        " use_fused_attn = 1;\n"
        " } else if (fuse_attn_mode == 0) {\n"
        " use_fused_attn = 0;\n"
        " } else if (!ck_strict_parity_enabled()) {\n"
        " use_fused_attn = 1;\n"
        " }\n"
        " if (m->weights_quantized) {\n"
        " use_fused_swiglu = 0;\n"
        " use_fused_attn = 0;\n"
        " }\n"
        "\n"
        " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
        " TrulyOptimalLayer *L = &m->layers[layer];\n"
        " if (!m->weights_mixed && m->weight_dtype == CK_DT_Q4_K) {\n"
        " CKLayerForwardParamsQ4K p = {0};\n"
        " p.tokens = cache_cap;\n"
        " p.embed_dim = m->embed_dim;\n"
        " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
        " p.num_heads = m->num_attention_heads;\n"
        " p.num_kv_heads = m->num_kv_heads;\n"
        " p.head_dim = m->head_dim;\n"
        " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
        " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
        " p.intermediate_dim = m->intermediate_size;\n"
        " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
        " p.eps = m->rms_norm_eps;\n"
        " p.rope_pos_offset = t;\n"
        " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
        " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
        " p.input = current;\n"
        " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
        " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
        " p.wq = cptr_void(base, L->wq_offset);\n"
        " p.bq = cptr_f32(base, L->bq_offset);\n"
        " p.wk = cptr_void(base, L->wk_offset);\n"
        " p.bk = cptr_f32(base, L->bk_offset);\n"
        " p.wv = cptr_void(base, L->wv_offset);\n"
        " p.bv = cptr_f32(base, L->bv_offset);\n"
        " p.wo = cptr_void(base, L->wo_offset);\n"
        " p.bo = cptr_f32(base, L->bo_offset);\n"
        " p.w1 = cptr_void(base, L->w1_offset);\n"
        " p.b1 = cptr_f32(base, L->b1_offset);\n"
        " p.w2 = cptr_void(base, L->w2_offset);\n"
        " p.b2 = cptr_f32(base, L->b2_offset);\n"
        " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
        " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
        " p.k = ptr_f32(base, L->k_offset);\n"
        " p.v = ptr_f32(base, L->v_offset);\n"
        " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
        " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
        " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
        " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
        " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
        " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
        " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
        " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
        " p.output = ptr_f32(base, L->output_offset);\n"
        "\n"
        " ck_layer_forward_rmsnorm_swiglu_decode_q4_k(&p, t, cache_cap);\n"
        " } else if (m->weights_quantized) {\n"
        " CKLayerForwardParamsQ4K p = {0};\n"
        " p.tokens = cache_cap;\n"
        " p.embed_dim = m->embed_dim;\n"
        " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
        " p.num_heads = m->num_attention_heads;\n"
        " p.num_kv_heads = m->num_kv_heads;\n"
        " p.head_dim = m->head_dim;\n"
        " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
        " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
        " p.intermediate_dim = m->intermediate_size;\n"
        " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
        " p.eps = m->rms_norm_eps;\n"
        " p.rope_pos_offset = t;\n"
        " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
        " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
        " p.input = current;\n"
        " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
        " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
        " p.wq = cptr_void(base, L->wq_offset);\n"
        " p.bq = cptr_f32(base, L->bq_offset);\n"
        " p.wk = cptr_void(base, L->wk_offset);\n"
        " p.bk = cptr_f32(base, L->bk_offset);\n"
        " p.wv = cptr_void(base, L->wv_offset);\n"
        " p.bv = cptr_f32(base, L->bv_offset);\n"
        " p.wo = cptr_void(base, L->wo_offset);\n"
        " p.bo = cptr_f32(base, L->bo_offset);\n"
        " p.w1 = cptr_void(base, L->w1_offset);\n"
        " p.b1 = cptr_f32(base, L->b1_offset);\n"
        " p.w2 = cptr_void(base, L->w2_offset);\n"
        " p.b2 = cptr_f32(base, L->b2_offset);\n"
        " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
        " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
        " p.k = ptr_f32(base, L->k_offset);\n"
        " p.v = ptr_f32(base, L->v_offset);\n"
        " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
        " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
        " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
        " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
        " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
        " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
        " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
        " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
        " p.output = ptr_f32(base, L->output_offset);\n"
        " p.wq_dtype = L->wq_dtype;\n"
        " p.wk_dtype = L->wk_dtype;\n"
        " p.wv_dtype = L->wv_dtype;\n"
        " p.wo_dtype = L->wo_dtype;\n"
        " p.w1_dtype = L->w1_dtype;\n"
        " p.w2_dtype = L->w2_dtype;\n"
        "\n"
        " ck_layer_forward_rmsnorm_swiglu_decode_quant(&p, t, cache_cap);\n"
        " } else {\n"
        " CKLayerForwardParams p = {0};\n"
        " p.tokens = cache_cap;\n"
        " p.embed_dim = m->embed_dim;\n"
        " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
        " p.num_heads = m->num_attention_heads;\n"
        " p.num_kv_heads = m->num_kv_heads;\n"
        " p.head_dim = m->head_dim;\n"
        " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
        " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
        " p.intermediate_dim = m->intermediate_size;\n"
        " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
        " p.eps = m->rms_norm_eps;\n"
        " p.rope_pos_offset = t;\n"
        " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
        " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
        " p.input = current;\n"
        " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
        " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
        " p.wq = cptr_f32(base, L->wq_offset);\n"
        " p.bq = cptr_f32(base, L->bq_offset);\n"
        " p.wk = cptr_f32(base, L->wk_offset);\n"
        " p.bk = cptr_f32(base, L->bk_offset);\n"
        " p.wv = cptr_f32(base, L->wv_offset);\n"
        " p.bv = cptr_f32(base, L->bv_offset);\n"
        " p.wo = cptr_f32(base, L->wo_offset);\n"
        " p.bo = cptr_f32(base, L->bo_offset);\n"
        " p.w1 = cptr_f32(base, L->w1_offset);\n"
        " p.b1 = cptr_f32(base, L->b1_offset);\n"
        " p.w2 = cptr_f32(base, L->w2_offset);\n"
        " p.b2 = cptr_f32(base, L->b2_offset);\n"
        " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
        " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
        " p.k = ptr_f32(base, L->k_offset);\n"
        " p.v = ptr_f32(base, L->v_offset);\n"
        " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
        " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
        " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
        " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
        " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
        " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
        " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
        " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
        " p.output = ptr_f32(base, L->output_offset);\n"
        "\n"
        " if (use_fused_attn) {\n"
        " if (use_fused_swiglu) {\n"
        " ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp(&p, t, cache_cap);\n"
        " } else {\n"
        " ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(&p, t, cache_cap);\n"
        " }\n"
        " } else if (use_fused_swiglu) {\n"
        " ck_layer_forward_rmsnorm_swiglu_decode_fused(&p, t, cache_cap);\n"
        " } else {\n"
        " ck_layer_forward_rmsnorm_swiglu_decode(&p, t, cache_cap);\n"
        " }\n"
        " }\n"
        " current = ptr_f32(base, L->output_offset);\n"
        " }\n"
        "\n"
        " int V = m->vocab_size;\n"
        " int D = m->embed_dim;\n"
        " int aligned_D = (int)m->aligned_embed_dim;\n"
        " float *final_in = current + (size_t)t * aligned_D;\n"
        " float *final_out = ptr_f32(base, m->final_output_offset) + (size_t)t * aligned_D;\n"
        " float *final_rstd = ptr_f32(base, m->final_ln_rstd_offset) + (size_t)t;\n"
        "\n"
        " rmsnorm_forward(final_in,\n"
        " cptr_f32(base, m->final_ln_weight_offset),\n"
        " final_out,\n"
        " final_rstd,\n"
        " 1,\n"
        " D,\n"
        " aligned_D,\n"
        " m->rms_norm_eps);\n"
        " if (V > 0) {\n"
        " float *logits_row = ptr_f32(base, m->logits_offset) + (size_t)t * (size_t)V;\n"
        " if (m->lm_head_weight_dtype == CK_DT_Q4_K) {\n"
        " gemm_nt_q4_k(final_out,\n"
        " cptr_void(base, m->lm_head_weight_offset),\n"
        " NULL,\n"
        " logits_row,\n"
        " 1,\n"
        " V,\n"
        " aligned_D);\n"
        " } else if (m->lm_head_weight_dtype == CK_DT_Q6_K) {\n"
        " gemm_nt_q6_k(final_out,\n"
        " cptr_void(base, m->lm_head_weight_offset),\n"
        " NULL,\n"
        " logits_row,\n"
        " 1,\n"
        " V,\n"
        " aligned_D);\n"
        " } else {\n"
        " lm_head_forward(final_out,\n"
        " cptr_f32(base, m->lm_head_weight_offset),\n"
        " logits_row,\n"
        " 1,\n"
        " V,\n"
        " D,\n"
        " aligned_D);\n"
        " }\n"
        " }\n"
        "\n"
        " m->kv_cache_tokens = t + 1;\n"
        " m->active_tokens = m->kv_cache_tokens;\n"
        " return 0;\n"
        "}\n\n");

    /* ck_model_init — hyperparameters from forward->config are baked into the
     * generated source via the trailing fprintf arguments. */
    fprintf(out,
        "CK_EXPORT int ck_model_init(const char *weights_path)\n"
        "{\n"
        " if (g_initialized) return 0;\n"
        " memcpy(g_model.magic, \"BUMPWGT3\", 8);\n"
        " g_model.version = 3;\n"
        " g_model.model_type = 0;\n"
        " g_model.num_layers = %d;\n"
        " g_model.embed_dim = %d;\n"
        " g_model.intermediate_size = %d;\n"
        " g_model.num_attention_heads = %d;\n"
        " g_model.num_kv_heads = %d;\n"
        " g_model.vocab_size = %d;\n"
        " g_model.context_window = %d;\n"
        " g_model.rms_norm_eps = (float)%.9g;\n"
        " g_model.rope_theta = (float)%.9g;\n"
        " g_model.num_cores = 1;\n"
        " g_model.task_type = TASK_LM;\n"
        " g_model.weight_dtype = CK_DT_FP32;\n"
        " const char *wd = getenv(\"CK_WEIGHT_DTYPE\");\n"
        " if (wd) {\n"
        " if (strcmp(wd, \"q4_k\") == 0 || strcmp(wd, \"q4_k_m\") == 0 ||\n"
        " strcmp(wd, \"Q4_K\") == 0 || strcmp(wd, \"Q4_K_M\") == 0) {\n"
        " g_model.weight_dtype = CK_DT_Q4_K;\n"
        " } else if (strcmp(wd, \"q6_k\") == 0 || strcmp(wd, \"q6_k_l\") == 0 ||\n"
        " strcmp(wd, \"Q6_K\") == 0 || strcmp(wd, \"Q6_K_L\") == 0) {\n"
        " g_model.weight_dtype = CK_DT_Q6_K;\n"
        " }\n"
        " }\n"
        " init_weight_dtypes_uniform(&g_model, g_model.weight_dtype);\n"
        " refresh_weight_flags(&g_model);\n"
        " /* Check env var to pre-allocate gradient buffers for training */\n"
        " const char *train_env = getenv(\"CK_ENABLE_TRAINING\");\n"
        " if (train_env && (train_env[0] == '1' || train_env[0] == 'y' || train_env[0] == 'Y')) {\n"
        " g_model.training_enabled = true;\n"
        " g_model.learning_rate = 1e-4f;\n"
        " }\n"
        " if (weights_path) {\n"
        " int dtype_rc = load_weight_dtypes(weights_path, &g_model);\n"
        " if (dtype_rc < 0) {\n"
        " fprintf(stderr, \"Failed to read weight dtype table from %%s\\n\", weights_path);\n"
        " return -6;\n"
        " }\n"
        " }\n"
        " if (g_model.training_enabled && g_model.weights_quantized) {\n"
        " fprintf(stderr, \"Quantized weights are inference-only; disable training for this model\\n\");\n"
        " return -5;\n"
        " }\n"
        " g_model.kv_cache_enabled = false;\n"
        " g_model.kv_cache_capacity = g_model.context_window;\n"
        " g_model.kv_cache_tokens = 0;\n"
        " if (layout_model(&g_model) != 0) return -1;\n"
        " if (weights_path) {\n"
        " if (load_model_weights(weights_path, &g_model) != 0) return -2;\n"
        " }\n"
        " g_initialized = 1;\n"
        " return 0;\n"
        "}\n\n",
        forward->config.num_layers,
        forward->config.hidden_size,
        forward->config.intermediate_size,
        forward->config.num_heads,
        forward->config.num_kv_heads,
        forward->config.vocab_size,
        forward->config.context_window,
        forward->config.rms_norm_eps,
        forward->config.rope_theta);

    /* ck_model_get_info */
    fprintf(out,
        "CK_EXPORT void ck_model_get_info(CKModelInfo *info)\n"
        "{\n"
        " if (!info) return;\n"
        " info->num_layers = g_model.num_layers;\n"
        " info->hidden_size = g_model.embed_dim;\n"
        " info->intermediate_size = g_model.intermediate_size;\n"
        " info->num_heads = g_model.num_attention_heads;\n"
        " info->num_kv_heads = g_model.num_kv_heads;\n"
        " info->vocab_size = g_model.vocab_size;\n"
        " info->context_window = g_model.context_window;\n"
        " info->rms_norm_eps = g_model.rms_norm_eps;\n"
        " info->rope_theta = g_model.rope_theta;\n"
        "}\n\n");

    /* ck_model_embed_tokens */
    fprintf(out,
        "CK_EXPORT int ck_model_embed_tokens(const int32_t *tokens, int num_tokens)\n"
        "{\n"
        " if (!g_initialized) return -1;\n"
        " int cap = g_model.context_window;\n"
        " if (g_model.kv_cache_enabled && g_model.kv_cache_capacity > 0 && g_model.kv_cache_capacity < cap) {\n"
        " cap = g_model.kv_cache_capacity;\n"
        " }\n"
        " if (num_tokens > cap) num_tokens = cap;\n"
        " if (num_tokens < 1) num_tokens = 1;\n"
        " g_model.active_tokens = num_tokens;\n"
        " if (g_model.kv_cache_enabled && !g_model.training_enabled) {\n"
        " g_model.kv_cache_tokens = 0;\n"
        " }\n"
        " embed_tokens(&g_model, tokens, num_tokens);\n"
        " return 0;\n"
        "}\n\n");

    /* ck_model_forward */
    fprintf(out,
        "CK_EXPORT int ck_model_forward(float *logits_out)\n"
        "{\n"
        " if (!g_initialized) return -1;\n"
        " run_model_forward(&g_model);\n"
        " if (g_model.kv_cache_enabled && !g_model.training_enabled) {\n"
        " g_model.kv_cache_tokens = g_model.active_tokens;\n"
        " }\n"
        " if (logits_out && g_model.vocab_size > 0) {\n"
        " size_t n = (size_t)g_model.active_tokens * (size_t)g_model.vocab_size;\n"
        " memcpy(logits_out, ptr_f32(g_model.memory_base, g_model.logits_offset), n * sizeof(float));\n"
        " }\n"
        " return 0;\n"
        "}\n\n");

    /* KV-cache helpers + decode API */
    fprintf(out,
        "CK_EXPORT int ck_model_kv_cache_enable(int capacity)\n"
        "{\n"
        " if (!g_initialized) return -1;\n"
        " if (g_model.training_enabled) return -4;\n"
        " g_model.kv_cache_enabled = true;\n"
        " int cap = capacity;\n"
        " if (cap <= 0 || cap > g_model.context_window) cap = g_model.context_window;\n"
        " g_model.kv_cache_capacity = cap;\n"
        " g_model.kv_cache_tokens = 0;\n"
        " g_model.active_tokens = 0;\n"
        " return 0;\n"
        "}\n\n"
        "CK_EXPORT void ck_model_kv_cache_reset(void)\n"
        "{\n"
        " if (!g_initialized) return;\n"
        " g_model.kv_cache_tokens = 0;\n"
        " g_model.active_tokens = 0;\n"
        "}\n\n"
        "CK_EXPORT int ck_model_kv_cache_get_tokens(void)\n"
        "{\n"
        " return g_initialized ? g_model.kv_cache_tokens : 0;\n"
        "}\n\n"
        "CK_EXPORT int ck_model_decode(int32_t token, float *logits_out)\n"
        "{\n"
        " if (!g_initialized) return -1;\n"
        " if (g_model.training_enabled) return -4;\n"
        " int ret = run_model_decode(&g_model, token);\n"
        " if (ret != 0) return ret;\n"
        " if (logits_out && g_model.vocab_size > 0) {\n"
        " int t = g_model.active_tokens - 1;\n"
        " memcpy(logits_out,\n"
        " ptr_f32(g_model.memory_base, g_model.logits_offset) + (size_t)t * (size_t)g_model.vocab_size,\n"
        " (size_t)g_model.vocab_size * sizeof(float));\n"
        " }\n"
        " return 0;\n"
        "}\n\n");

    /* ck_model_get_logits - get pointer to internal logits buffer */
    fprintf(out,
        "CK_EXPORT float* ck_model_get_logits(void)\n"
        "{\n"
        " if (!g_initialized) return NULL;\n"
        " return ptr_f32(g_model.memory_base, g_model.logits_offset);\n"
        "}\n\n");

    /* ck_model_backward */
    fprintf(out,
        "CK_EXPORT int ck_model_backward(const int32_t *tokens, const int32_t *targets, float *loss_out)\n"
        "{\n"
        " if (!g_initialized) return -1;\n"
        " return run_model_backward(&g_model, tokens, targets, loss_out);\n"
        "}\n\n");

    /* ck_model_free */
    fprintf(out,
        "CK_EXPORT void ck_model_free(void)\n"
        "{\n"
        " if (!g_initialized) return;\n"
        " if (g_model.memory_base) ck_huge_free(g_model.memory_base, g_model.total_bytes);\n"
        " if (g_model.layers) free(g_model.layers);\n"
        " memset(&g_model, 0, sizeof(g_model));\n"
        " g_initialized = 0;\n"
        "}\n\n");

    /* ck_model_get_context_window and the remaining small getters/setters,
     * plus the runtime training enable/disable/step controls. */
    fprintf(out,
        "CK_EXPORT int ck_model_get_context_window(void) { return g_initialized ? g_model.context_window : 0; }\n"
        "CK_EXPORT int ck_model_get_vocab_size(void) { return g_initialized ? g_model.vocab_size : 0; }\n"
        "CK_EXPORT int ck_model_get_hidden_size(void) { return g_initialized ? g_model.embed_dim : 0; }\n"
        "CK_EXPORT int ck_model_get_active_tokens(void) { return g_initialized ? g_model.active_tokens : 0; }\n"
        "CK_EXPORT int ck_model_is_training_enabled(void) { return g_initialized ? g_model.training_enabled : 0; }\n"
        "CK_EXPORT void ck_model_set_learning_rate(float lr) { if (g_initialized) g_model.learning_rate = lr; }\n"
        "CK_EXPORT float ck_model_get_learning_rate(void) { return g_initialized ? g_model.learning_rate : 0.0f; }\n\n"
        "CK_EXPORT int ck_model_enable_training(float learning_rate)\n"
        "{\n"
        " if (!g_initialized) return -1;\n"
        " g_model.training_enabled = true;\n"
        " g_model.learning_rate = learning_rate;\n"
        " return 0;\n"
        "}\n\n"
        "CK_EXPORT void ck_model_disable_training(void)\n"
        "{\n"
        " if (g_initialized) g_model.training_enabled = false;\n"
        "}\n\n"
        "CK_EXPORT void ck_model_optimizer_step(void)\n"
        "{\n"
        " if (!g_initialized || !g_model.training_enabled) return;\n"
        " sgd_update(&g_model, g_model.learning_rate);\n"
        "}\n\n");
}

References CKIRGraph::config, CKModelConfig::context_window, CKModelConfig::hidden_size, CKModelConfig::intermediate_size, CKModelConfig::num_heads, CKModelConfig::num_kv_heads, CKModelConfig::num_layers, CKModelConfig::rms_norm_eps, CKModelConfig::rope_theta, and CKModelConfig::vocab_size.

Referenced by ck_codegen_emit_runtime().

◆ emit_model_struct()

static void emit_model_struct ( FILE *  out)
static

Definition at line 148 of file ckernel_codegen_v6.5.c.

/*
 * Emit the generated TransformerModel struct definition (plus the
 * TrulyOptimalLayer alias for LayerOffsets): header/config fields, memory
 * layout bookkeeping, the per-global-buffer offset fields, and the
 * training / sequence-classification / KV-cache state.
 */
static void emit_model_struct(FILE *out)
{
    fprintf(out,
        "typedef LayerOffsets TrulyOptimalLayer;\n\n"
        "typedef struct {\n"
        " char magic[8];\n"
        " uint32_t version;\n"
        " uint32_t model_type;\n"
        "\n"
        " int num_layers;\n"
        " int vocab_size;\n"
        " int embed_dim;\n"
        " int context_window;\n"
        " int intermediate_size;\n"
        "\n"
        " size_t aligned_embed_dim;\n"
        " size_t aligned_head_dim;\n"
        " size_t aligned_attn_context_window;\n"
        "\n"
        " int num_cores;\n"
        " int tokens_per_core;\n"
        " int num_attention_heads;\n"
        " int num_kv_heads;\n"
        " int head_dim;\n"
        " float rms_norm_eps;\n"
        " float rope_theta;\n"
        "\n"
        " uint8_t *memory_base;\n"
        " size_t total_bytes;\n"
        " size_t elem_bytes;\n"
        " CKDataType weight_dtype;\n"
        " CKDataType token_emb_dtype;\n"
        " CKDataType pos_emb_dtype;\n"
        " CKDataType lm_head_weight_dtype;\n"
        " bool weights_mixed;\n"
        " bool weights_quantized;\n"
        " size_t layer_stride;\n"
        "\n"
        " size_t layers_start_offset;\n");

    /* One size_t <name>_offset field per allocatable global-scope buffer.
     * NOTE(review): this call was dropped by the doc extraction (see the
     * "References emit_global_offset_fields()" cross-reference); restored
     * here — confirm against the original source. */
    emit_global_offset_fields(out);

    fprintf(out,
        "\n"
        " TrulyOptimalLayer *layers;\n"
        "\n"
        " GradientStorage gradients;\n"
        " bool training_enabled;\n"
        " float learning_rate;\n"
        " int lr_warmup_steps;\n"
        " float lr_warmup_init;\n"
        " float grad_clip;\n"
        " size_t training_cache_samples;\n"
        " int active_tokens;\n"
        " TaskType task_type;\n"
        " OptimizerType optimizer;\n"
        " uint64_t optimizer_step;\n"
        " float adam_beta1;\n"
        " float adam_beta2;\n"
        " float adam_eps;\n"
        " float weight_decay;\n"
        " bool ema_enabled;\n"
        " float ema_decay;\n"
        " bool optimizer_state_initialized;\n"
        "\n"
        " bool seq_cls_enabled;\n"
        " int seq_cls_num_classes;\n"
        " int seq_cls_pooling;\n"
        " size_t seq_cls_weight_offset;\n"
        " size_t seq_cls_bias_offset;\n"
        "\n"
        " bool kv_cache_enabled;\n"
        " int kv_cache_capacity;\n"
        " int kv_cache_tokens;\n"
        "\n"
        " long *training_data_buffer;\n"
        " long num_training_tokens;\n"
        "\n"
        " uint8_t checksum[32];\n"
        " uint8_t reserved[32];\n"
        "} TransformerModel;\n\n");
}
static void emit_global_offset_fields(FILE *out)

References emit_global_offset_fields().

Referenced by ck_codegen_emit_runtime().

◆ emit_offset_field()

static void emit_offset_field ( FILE *  out,
const char *  name 
)
static

Definition at line 106 of file ckernel_codegen_v6.5.c.

/* Emit one size_t `<name>_offset` member line into a generated struct body. */
static void emit_offset_field(FILE *out, const char *name)
{
    fputs(" size_t ", out);
    fputs(name, out);
    fputs("_offset;\n", out);
}

Referenced by emit_global_offset_fields(), and emit_layer_offsets_struct().

◆ emit_plan_sources()

static int emit_plan_sources ( FILE *  f,
const CKPlanStep plan,
size_t  plan_count,
const CKIRGraph cfg,
const char **  seen,
size_t *  seen_count,
size_t  seen_cap 
)
static

Definition at line 568 of file ckernel_codegen_v6.5.c.

575 {
576  for (size_t i = 0; i < plan_count; ++i) {
577  const CKPlanStep *step = &plan[i];
578  if (!ck_plan_step_enabled(step, cfg)) {
579  continue;
580  }
581  const CKKernelSpec *spec = ck_find_kernel_spec(step->kernel);
582  if (!spec) {
583  continue;
584  }
585  for (size_t s = 0; s < CKERNEL_MAX_KERNEL_SOURCES; ++s) {
586  const char *src = spec->sources[s];
587  if (!src) {
588  continue;
589  }
590  if (emit_unique_source(f, src, seen, seen_count, seen_cap) != 0) {
591  return -1;
592  }
593  }
594  }
595  return 0;
596 }
static const CKKernelSpec * ck_find_kernel_spec(const char *name)
static int ck_plan_step_enabled(const CKPlanStep *step, const CKIRGraph *cfg)
const char * sources[8]
const char * kernel

References ck_find_kernel_spec(), ck_plan_step_enabled(), CKERNEL_MAX_KERNEL_SOURCES, emit_unique_source(), CKPlanStep::kernel, and CKKernelSpec::sources.

Referenced by emit_kernel_manifest().

◆ emit_runtime_preamble()

static int emit_runtime_preamble ( FILE *  out)
static

Definition at line 758 of file ckernel_codegen_v6.5.c.

/*
 * Write the fixed preamble of the generated runtime: banner comment with
 * compile instructions, system/engine includes, and the small static helper
 * functions (alignment math, bump allocator, typed base+offset pointers).
 * None of the emitted text contains format specifiers, so fputs suffices.
 * Always returns 0.
 */
static int emit_runtime_preamble(FILE *out)
{
    /* Banner + build instructions. */
    fputs("/* Auto-generated runtime from CKIRGraph.\n"
          " * This file wires the existing C-Kernel-Engine kernels into a\n"
          " * decoder-only transformer forward pass.\n"
          " *\n"
          " * Compile (scalar): gcc -O2 generated_model.c $(cat generated_model.c.kernels) -Iinclude -lm -o generated_model\n"
          " * Compile (AVX-512): gcc -O3 -mavx512f -mfma generated_model.c $(cat generated_model.c.kernels) -Iinclude -lm -o generated_model\n"
          " */\n\n",
          out);

    /* System headers followed by the engine's own public headers. */
    fputs("#define _GNU_SOURCE\n"
          "#include <stddef.h>\n"
          "#include <stdint.h>\n"
          "#include <stdbool.h>\n"
          "#include <stdio.h>\n"
          "#include <stdlib.h>\n"
          "#include <string.h>\n"
          "#include <math.h>\n"
          "#include <errno.h>\n"
          "#include <sys/types.h>\n"
          "#include <unistd.h>\n"
          "#include \"ckernel_engine.h\"\n"
          "#include \"ckernel_dtype.h\"\n"
          "#include \"ckernel_orchestration.h\"\n"
          "#include \"ckernel_alloc.h\"\n\n",
          out);

    /* Alignment helpers, the bump allocator, and float pointer accessors. */
    fputs("#define CACHELINE_BYTES 64\n"
          "static size_t align_up_bytes(size_t n, size_t align) {\n"
          " if (align == 0) return n;\n"
          " return (n + align - 1) & ~(align - 1);\n"
          "}\n\n"
          "static size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align) {\n"
          " size_t bytes = elems * elem_bytes;\n"
          " bytes = align_up_bytes(bytes, align);\n"
          " return bytes / elem_bytes;\n"
          "}\n\n"
          "static size_t bump_bytes(size_t *off, size_t bytes, size_t align) {\n"
          " size_t start = align_up_bytes(*off, align);\n"
          " *off = start + bytes;\n"
          " return start;\n"
          "}\n\n"
          "static inline float *ptr_f32(uint8_t *base, size_t offset) {\n"
          " return (float *)(base + offset);\n"
          "}\n"
          "static inline const float *cptr_f32(const uint8_t *base, size_t offset) {\n"
          " return (const float *)(base + offset);\n"
          "}\n\n",
          out);

    /* Raw-byte and opaque (quantized-weight) pointer accessors. */
    fputs("static inline uint8_t *ptr_u8(uint8_t *base, size_t offset) {\n"
          " return base + offset;\n"
          "}\n"
          "static inline const void *cptr_void(const uint8_t *base, size_t offset) {\n"
          " return (const void *)(base + offset);\n"
          "}\n\n",
          out);

    return 0;
}

Referenced by ck_codegen_emit_runtime().

◆ emit_sgd_update()

static void emit_sgd_update ( FILE *  out)
static

Definition at line 459 of file ckernel_codegen_v6.5.c.

/*
 * Emit the generated sgd_update() function: for every non-aliased weight
 * buffer (global scope first, then per-layer) that has a matching "d_<name>"
 * gradient buffer in the same scope, emit a guarded elementwise
 * w[i] -= lr * g[i] loop over the buffer's shape expression.
 */
static void emit_sgd_update(FILE *out)
{
    /* Function prologue; aligned_intermediate_dim is referenced by the shape
     * expressions emitted below. */
    fprintf(out,
        "static void sgd_update(TransformerModel *m, float lr)\n"
        "{\n"
        " if (!m || !m->training_enabled || lr == 0.0f) return;\n"
        " uint8_t *base = m->memory_base;\n"
        " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n");

    /* Global-scope weights: skip aliases and weights without a matching
     * global "d_<name>" gradient buffer. */
    for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
        const CKBufferSpec *spec = &ck_decoder_buffers[i];
        if (spec->role != CK_ROLE_WEIGHT || spec->scope != CK_SCOPE_GLOBAL) {
            continue;
        }
        if (spec->alias_of) {
            /* Aliased weights are updated through their canonical buffer. */
            continue;
        }
        char grad_name[128];
        snprintf(grad_name, sizeof(grad_name), "d_%s", spec->name);
        const CKBufferSpec *grad = ck_find_buffer_spec(grad_name);
        if (!grad || grad->scope != CK_SCOPE_GLOBAL) {
            continue;
        }
        /* Guard on both offsets being nonzero: either may be unallocated. */
        fprintf(out,
            " if (m->%s_offset && m->%s_offset) {\n"
            " float *w = ptr_f32(base, m->%s_offset);\n"
            " float *g = ptr_f32(base, m->%s_offset);\n"
            " size_t count = (",
            spec->name, grad_name, spec->name, grad_name);
        emit_shape_expr(out, spec->shape);
        fprintf(out,
            ");\n"
            " for (size_t i = 0; i < count; ++i) {\n"
            " w[i] -= lr * g[i];\n"
            " }\n"
            " }\n");
    }

    /* Per-layer weights: same pattern, addressed through L-> inside the
     * generated per-layer loop. */
    fprintf(out,
        " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
        " TrulyOptimalLayer *L = &m->layers[layer];\n");
    for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
        const CKBufferSpec *spec = &ck_decoder_buffers[i];
        if (spec->role != CK_ROLE_WEIGHT || spec->scope != CK_SCOPE_LAYER) {
            continue;
        }
        char grad_name[128];
        snprintf(grad_name, sizeof(grad_name), "d_%s", spec->name);
        const CKBufferSpec *grad = ck_find_buffer_spec(grad_name);
        if (!grad || grad->scope != CK_SCOPE_LAYER) {
            continue;
        }
        fprintf(out,
            " if (L->%s_offset && L->%s_offset) {\n"
            " float *w = ptr_f32(base, L->%s_offset);\n"
            " float *g = ptr_f32(base, L->%s_offset);\n"
            " size_t count = (",
            spec->name, grad_name, spec->name, grad_name);
        emit_shape_expr(out, spec->shape);
        fprintf(out,
            ");\n"
            " for (size_t i = 0; i < count; ++i) {\n"
            " w[i] -= lr * g[i];\n"
            " }\n"
            " }\n");
    }
    fprintf(out,
        " }\n"
        "}\n\n");
}

References CKBufferSpec::alias_of, ck_decoder_buffer_count, ck_decoder_buffers, ck_find_buffer_spec(), CK_ROLE_WEIGHT, CK_SCOPE_GLOBAL, CK_SCOPE_LAYER, emit_shape_expr(), CKBufferSpec::name, CKBufferSpec::role, CKBufferSpec::scope, and CKBufferSpec::shape.

Referenced by ck_codegen_emit_runtime().

◆ emit_shape_expr()

static void emit_shape_expr ( FILE *  out,
const CKDimToken shape 
)
static

Definition at line 249 of file ckernel_codegen_v6.5.c.

250 {
251  int first = 1;
252  for (int i = 0; i < 4; ++i) {
253  if (shape[i].dim == CK_DIM_END) {
254  break;
255  }
256  if (!first) {
257  fprintf(out, " * ");
258  }
259  fprintf(out, "(");
260  emit_dim_expr(out, shape[i].dim);
261  if (shape[i].mult != 1) {
262  fprintf(out, " * %d", shape[i].mult);
263  }
264  if (shape[i].div != 1) {
265  fprintf(out, " / %d", shape[i].div);
266  }
267  fprintf(out, ")");
268  first = 0;
269  }
270  if (first) {
271  fprintf(out, "0");
272  }
273 }
static void emit_dim_expr(FILE *out, CKDimKind dim)

References CK_DIM_END, and emit_dim_expr().

Referenced by emit_bump_bytes_assignment(), emit_bump_bytes_assignment_weight_dtype(), emit_global_allocations(), emit_layer_allocations(), emit_sgd_update(), emit_training_conditional_assignment(), and emit_zero_grad().

◆ emit_training_conditional_assignment()

/*
 * Emit an offset assignment that is gated on runtime training mode.
 *
 * Gradient buffers are only carved out of the arena when the generated
 * model was initialized with training enabled; otherwise the offset is
 * written as 0, which the rest of the generated runtime treats as
 * "buffer absent".
 */
static void emit_training_conditional_assignment(FILE *out, const char *indent,
                                                 const char *struct_prefix,
                                                 const char *name,
                                                 const CKDimToken *shape)
{
    fprintf(out, "%s%s", indent, struct_prefix);
    fprintf(out, "%s_offset = m->training_enabled ? bump_bytes(&off, (", name);
    emit_shape_expr(out, shape);
    fprintf(out, ") * elem_bytes, CACHELINE_BYTES) : 0;\n");
}

References emit_shape_expr().

Referenced by emit_global_allocations(), and emit_layer_allocations().

◆ emit_unique_source()

/*
 * Append `path` (plus a newline) to the manifest stream `f`, unless the
 * same path has already been emitted.
 *
 * `seen`/`seen_count` form a linear dedup table of pointers to the paths
 * emitted so far; the strings are NOT copied, so callers must keep them
 * alive for the lifetime of the table.
 *
 * Returns 0 on success and for the no-op cases (NULL/empty/duplicate
 * path); returns -1 when the dedup table is full or a write fails.  On
 * write failure the path is NOT recorded, so a retry would attempt to
 * emit it again.
 */
static int emit_unique_source(FILE *f, const char *path,
                              const char **seen, size_t *seen_count,
                              size_t seen_cap)
{
    if (!path || !path[0]) {
        return 0; /* nothing to emit */
    }
    for (size_t i = 0; i < *seen_count; ++i) {
        if (strcmp(seen[i], path) == 0) {
            return 0; /* already emitted once */
        }
    }
    if (*seen_count >= seen_cap) {
        return -1; /* dedup table exhausted */
    }
    /* Check the writes: a silently truncated manifest would otherwise only
     * surface much later as a broken build of the generated runtime. */
    if (fputs(path, f) == EOF || fputc('\n', f) == EOF) {
        return -1;
    }
    seen[*seen_count] = path;
    (*seen_count)++;
    return 0;
}

Referenced by emit_kernel_manifest(), and emit_plan_sources().

◆ emit_zero_grad()

/*
 * Emit the generated runtime's zero_grad() function into `out`.
 *
 * The emitted function clears every gradient buffer (role CK_ROLE_GRAD)
 * in the model's flat memory arena: first the global-scope buffers off
 * the TransformerModel, then the per-layer buffers in a loop over
 * m->layers.  Each memset is guarded by a non-zero offset check, since
 * gradient buffers are only allocated when training is enabled (offset 0
 * means "absent").  The buffer set is driven entirely by the static
 * ck_decoder_buffers spec table.
 *
 * NOTE(review): `aligned_intermediate_dim` is declared unconditionally in
 * the emitted body — presumably some shape expressions produced by
 * emit_shape_expr() reference it; confirm, otherwise the generated code
 * carries an unused variable.
 */
static void emit_zero_grad(FILE *out)
{
    /* Emitted prologue: bail out when not training; cache the arena base. */
    fprintf(out,
            "static void zero_grad(TransformerModel *m)\n"
            "{\n"
            " if (!m || !m->training_enabled) return;\n"
            " uint8_t *base = m->memory_base;\n"
            " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n");

    /* Global-scope gradient buffers: one guarded memset each. */
    for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
        const CKBufferSpec *spec = &ck_decoder_buffers[i];
        if (spec->role != CK_ROLE_GRAD || spec->scope != CK_SCOPE_GLOBAL) {
            continue;
        }
        fprintf(out, " if (m->%s_offset) {\n", spec->name);
        fprintf(out, " memset(base + m->%s_offset, 0, (", spec->name);
        emit_shape_expr(out, spec->shape);
        fprintf(out, ") * m->elem_bytes);\n");
        fprintf(out, " }\n");
    }

    /* Per-layer gradient buffers: same pattern, inside a layer loop. */
    fprintf(out,
            " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
            " TrulyOptimalLayer *L = &m->layers[layer];\n");
    for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
        const CKBufferSpec *spec = &ck_decoder_buffers[i];
        if (spec->role != CK_ROLE_GRAD || spec->scope != CK_SCOPE_LAYER) {
            continue;
        }
        fprintf(out, " if (L->%s_offset) {\n", spec->name);
        fprintf(out, " memset(base + L->%s_offset, 0, (", spec->name);
        emit_shape_expr(out, spec->shape);
        fprintf(out, ") * m->elem_bytes);\n");
        fprintf(out, " }\n");
    }
    fprintf(out,
            " }\n"
            "}\n\n");
}

References ck_decoder_buffer_count, ck_decoder_buffers, CK_ROLE_GRAD, CK_SCOPE_GLOBAL, CK_SCOPE_LAYER, emit_shape_expr(), CKBufferSpec::name, CKBufferSpec::role, CKBufferSpec::scope, and CKBufferSpec::shape.

Referenced by ck_codegen_emit_runtime().

◆ op_name()

/*
 * Map an IR op code to its stable uppercase mnemonic.
 * Any value outside the known set yields "UNKNOWN".
 */
static const char *op_name(CKOpType op)
{
    static const struct {
        CKOpType op;
        const char *name;
    } table[] = {
        { CK_OP_RMSNORM,        "RMSNORM" },
        { CK_OP_LINEAR_QKV,     "LINEAR_QKV" },
        { CK_OP_ATTENTION,      "ATTENTION" },
        { CK_OP_ADD,            "ADD" },
        { CK_OP_LINEAR,         "LINEAR" },
        { CK_OP_SPLIT,          "SPLIT" },
        { CK_OP_SWIGLU,         "SWIGLU" },
        { CK_OP_RMSNORM_BWD,    "RMSNORM_BWD" },
        { CK_OP_LINEAR_QKV_BWD, "LINEAR_QKV_BWD" },
        { CK_OP_ATTENTION_BWD,  "ATTENTION_BWD" },
        { CK_OP_ADD_BWD,        "ADD_BWD" },
        { CK_OP_LINEAR_BWD,     "LINEAR_BWD" },
        { CK_OP_SPLIT_BWD,      "SPLIT_BWD" },
        { CK_OP_SWIGLU_BWD,     "SWIGLU_BWD" },
    };

    for (size_t i = 0; i < sizeof table / sizeof table[0]; ++i) {
        if (table[i].op == op) {
            return table[i].name;
        }
    }
    return "UNKNOWN";
}
@ CK_OP_LINEAR_BWD
Definition: ckernel_ir.h:48
@ CK_OP_SWIGLU
Definition: ckernel_ir.h:42
@ CK_OP_RMSNORM_BWD
Definition: ckernel_ir.h:44
@ CK_OP_SWIGLU_BWD
Definition: ckernel_ir.h:50
@ CK_OP_ADD
Definition: ckernel_ir.h:39
@ CK_OP_SPLIT
Definition: ckernel_ir.h:41
@ CK_OP_LINEAR_QKV_BWD
Definition: ckernel_ir.h:45
@ CK_OP_ATTENTION_BWD
Definition: ckernel_ir.h:46
@ CK_OP_SPLIT_BWD
Definition: ckernel_ir.h:49
@ CK_OP_LINEAR_QKV
Definition: ckernel_ir.h:37
@ CK_OP_LINEAR
Definition: ckernel_ir.h:40
@ CK_OP_RMSNORM
Definition: ckernel_ir.h:36
@ CK_OP_ADD_BWD
Definition: ckernel_ir.h:47
@ CK_OP_ATTENTION
Definition: ckernel_ir.h:38

References CK_OP_ADD, CK_OP_ADD_BWD, CK_OP_ATTENTION, CK_OP_ATTENTION_BWD, CK_OP_LINEAR, CK_OP_LINEAR_BWD, CK_OP_LINEAR_QKV, CK_OP_LINEAR_QKV_BWD, CK_OP_RMSNORM, CK_OP_RMSNORM_BWD, CK_OP_SPLIT, CK_OP_SPLIT_BWD, CK_OP_SWIGLU, and CK_OP_SWIGLU_BWD.

Referenced by ck_codegen_c_skeleton().