Emit a C runtime file that stitches kernels for the given forward IR.
Parameters:
| forward | The forward IR graph |
| path | Output file path |
| mode | CK_EMIT_STANDALONE for executable with main(), CK_EMIT_LIBRARY for shared object with API functions |
Returns 0 on success, non-zero on failure.
Definition at line 1441 of file ckernel_codegen.c.
1443 if (!forward || !path) {
1450 FILE *out = fopen(path,
"wb");
1452 fprintf(stderr,
"ck_codegen_emit_runtime: failed to open %s: %s\n",
1453 path, strerror(errno));
1465 " TASK_SEQ_CLS = 1\n"
1468 " OPTIMIZER_SGD = 0,\n"
1469 " OPTIMIZER_ADAM = 1\n"
1470 "} OptimizerType;\n\n"
1471 "typedef struct {\n"
1472 " size_t total_gradient_floats;\n"
1473 "} GradientStorage;\n\n");
1479 "static int ensure_layers_allocated(TransformerModel *m)\n"
1481 " if (!m) return -1;\n"
1482 " if (!m->layers && m->num_layers > 0) {\n"
1483 " m->layers = (TrulyOptimalLayer *)calloc((size_t)m->num_layers, sizeof(TrulyOptimalLayer));\n"
1484 " if (!m->layers) return -1;\n"
1488 "static void init_weight_dtypes_uniform(TransformerModel *m, CKDataType dt)\n"
1490 " if (!m) return;\n"
1491 " m->token_emb_dtype = dt;\n"
1492 " m->lm_head_weight_dtype = dt;\n"
1493 " m->pos_emb_dtype = CK_DT_FP32;\n"
1494 " if (ensure_layers_allocated(m) != 0) return;\n"
1495 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1496 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1497 " L->wq_dtype = dt;\n"
1498 " L->wk_dtype = dt;\n"
1499 " L->wv_dtype = dt;\n"
1500 " L->wo_dtype = dt;\n"
1501 " L->w1_dtype = dt;\n"
1502 " L->w2_dtype = dt;\n"
1505 "static void refresh_weight_flags(TransformerModel *m)\n"
1507 " if (!m) return;\n"
1508 " CKDataType base = m->token_emb_dtype;\n"
1510 " int quant = ck_dtype_is_quantized(base);\n"
1511 " if (m->lm_head_weight_dtype != base) mixed = 1;\n"
1512 " if (ck_dtype_is_quantized(m->lm_head_weight_dtype)) quant = 1;\n"
1513 " if (m->layers) {\n"
1514 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1515 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1516 " if (L->wq_dtype != base || L->wk_dtype != base || L->wv_dtype != base ||\n"
1517 " L->wo_dtype != base || L->w1_dtype != base || L->w2_dtype != base) {\n"
1520 " if (ck_dtype_is_quantized(L->wq_dtype) || ck_dtype_is_quantized(L->wk_dtype) ||\n"
1521 " ck_dtype_is_quantized(L->wv_dtype) || ck_dtype_is_quantized(L->wo_dtype) ||\n"
1522 " ck_dtype_is_quantized(L->w1_dtype) || ck_dtype_is_quantized(L->w2_dtype)) {\n"
1527 " m->weights_mixed = mixed ? true : false;\n"
1528 " m->weights_quantized = quant ? true : false;\n"
1530 " m->weight_dtype = base;\n"
1533 "static int load_weight_dtypes(const char *path, TransformerModel *m)\n"
1535 " if (!path || !m) return -1;\n"
1536 " FILE *f = fopen(path, \"rb\");\n"
1537 " if (!f) return -1;\n"
1539 " if (fread(magic, 1, 8, f) != 8) {\n"
1543 " if (memcmp(magic, \"BUMPWGT3\", 8) != 0) {\n"
1547 " uint32_t version = 0;\n"
1548 " if (fread(&version, sizeof(uint32_t), 1, f) != 1) {\n"
1552 " if (version < 3) {\n"
1556 " if (fseek(f, 128, SEEK_SET) != 0) {\n"
1560 " uint32_t dtype_len = 0;\n"
1561 " if (fread(&dtype_len, sizeof(uint32_t), 1, f) != 1) {\n"
1565 " if (dtype_len == 0) {\n"
1569 " uint8_t *dtype_buf = (uint8_t *)malloc(dtype_len);\n"
1570 " if (!dtype_buf) {\n"
1574 " if (fread(dtype_buf, 1, dtype_len, f) != dtype_len) {\n"
1575 " free(dtype_buf);\n"
1581 " size_t expected = (size_t)m->num_layers * 14u + 4u;\n"
1582 " if (dtype_len != expected) {\n"
1583 " free(dtype_buf);\n"
1586 " if (ensure_layers_allocated(m) != 0) {\n"
1587 " free(dtype_buf);\n"
1591 " size_t idx = 0;\n"
1592 " CKDataType token_dt = (CKDataType)dtype_buf[idx++];\n"
1593 " CKDataType pos_dt = (CKDataType)dtype_buf[idx++];\n"
1594 " if (pos_dt != CK_DT_FP32) {\n"
1595 " free(dtype_buf);\n"
1598 " if (token_dt != CK_DT_FP32 && token_dt != CK_DT_Q4_K && token_dt != CK_DT_Q6_K) {\n"
1599 " free(dtype_buf);\n"
1602 " m->token_emb_dtype = token_dt;\n"
1603 " m->lm_head_weight_dtype = token_dt;\n"
1604 " m->pos_emb_dtype = pos_dt;\n"
1606 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1607 " CKDataType ln1_dt = (CKDataType)dtype_buf[idx++];\n"
1608 " CKDataType ln2_dt = (CKDataType)dtype_buf[idx++];\n"
1609 " CKDataType wq_dt = (CKDataType)dtype_buf[idx++];\n"
1610 " CKDataType bq_dt = (CKDataType)dtype_buf[idx++];\n"
1611 " CKDataType wk_dt = (CKDataType)dtype_buf[idx++];\n"
1612 " CKDataType bk_dt = (CKDataType)dtype_buf[idx++];\n"
1613 " CKDataType wv_dt = (CKDataType)dtype_buf[idx++];\n"
1614 " CKDataType bv_dt = (CKDataType)dtype_buf[idx++];\n"
1615 " CKDataType wo_dt = (CKDataType)dtype_buf[idx++];\n"
1616 " CKDataType bo_dt = (CKDataType)dtype_buf[idx++];\n"
1617 " CKDataType w1_dt = (CKDataType)dtype_buf[idx++];\n"
1618 " CKDataType b1_dt = (CKDataType)dtype_buf[idx++];\n"
1619 " CKDataType w2_dt = (CKDataType)dtype_buf[idx++];\n"
1620 " CKDataType b2_dt = (CKDataType)dtype_buf[idx++];\n"
1622 " if (ln1_dt != CK_DT_FP32 || ln2_dt != CK_DT_FP32 ||\n"
1623 " bq_dt != CK_DT_FP32 || bk_dt != CK_DT_FP32 ||\n"
1624 " bv_dt != CK_DT_FP32 || bo_dt != CK_DT_FP32 ||\n"
1625 " b1_dt != CK_DT_FP32 || b2_dt != CK_DT_FP32) {\n"
1626 " free(dtype_buf);\n"
1629 " if ((wq_dt != CK_DT_FP32 && wq_dt != CK_DT_Q4_K && wq_dt != CK_DT_Q6_K) ||\n"
1630 " (wk_dt != CK_DT_FP32 && wk_dt != CK_DT_Q4_K && wk_dt != CK_DT_Q6_K) ||\n"
1631 " (wv_dt != CK_DT_FP32 && wv_dt != CK_DT_Q4_K && wv_dt != CK_DT_Q6_K) ||\n"
1632 " (wo_dt != CK_DT_FP32 && wo_dt != CK_DT_Q4_K && wo_dt != CK_DT_Q6_K) ||\n"
1633 " (w1_dt != CK_DT_FP32 && w1_dt != CK_DT_Q4_K && w1_dt != CK_DT_Q6_K) ||\n"
1634 " (w2_dt != CK_DT_FP32 && w2_dt != CK_DT_Q4_K && w2_dt != CK_DT_Q6_K)) {\n"
1635 " free(dtype_buf);\n"
1639 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1640 " L->wq_dtype = wq_dt;\n"
1641 " L->wk_dtype = wk_dt;\n"
1642 " L->wv_dtype = wv_dt;\n"
1643 " L->wo_dtype = wo_dt;\n"
1644 " L->w1_dtype = w1_dt;\n"
1645 " L->w2_dtype = w2_dt;\n"
1648 " CKDataType final_norm_dt = (CKDataType)dtype_buf[idx++];\n"
1649 " CKDataType final_bias_dt = (CKDataType)dtype_buf[idx++];\n"
1650 " free(dtype_buf);\n"
1651 " if (final_norm_dt != CK_DT_FP32 || final_bias_dt != CK_DT_FP32) {\n"
1655 " refresh_weight_flags(m);\n"
1659 "static int layout_model(TransformerModel *m)\n"
1661 " if (!m) return -1;\n"
1662 " if (m->num_attention_heads <= 0 || m->embed_dim <= 0) return -1;\n"
1663 " if (m->num_kv_heads <= 0) m->num_kv_heads = m->num_attention_heads;\n"
1664 " if (m->num_attention_heads %% m->num_kv_heads != 0) return -1;\n"
1665 " if (m->context_window <= 0) m->context_window = 1;\n"
1666 " if (m->vocab_size <= 0) m->vocab_size = 1;\n"
1667 " if (m->intermediate_size <= 0) return -1;\n"
1668 " m->head_dim = m->embed_dim / m->num_attention_heads;\n"
1669 " if (m->rms_norm_eps <= 0.0f) m->rms_norm_eps = 1e-5f;\n"
1670 " if (m->rope_theta < 0.0f) m->rope_theta = 0.0f;\n"
1671 " if (m->rope_theta > 0.0f && (m->head_dim %% 2 != 0)) return -1;\n"
1672 " if (m->elem_bytes == 0) m->elem_bytes = sizeof(float);\n"
1673 " size_t elem_bytes = m->elem_bytes;\n"
1674 " m->aligned_embed_dim = align_up_elems((size_t)m->embed_dim, elem_bytes, CACHELINE_BYTES);\n"
1675 " m->aligned_head_dim = align_up_elems((size_t)m->head_dim, elem_bytes, CACHELINE_BYTES);\n"
1676 " m->aligned_attn_context_window = align_up_elems((size_t)m->context_window, elem_bytes, CACHELINE_BYTES);\n"
1677 " size_t aligned_intermediate_dim = align_up_elems((size_t)m->intermediate_size, elem_bytes, CACHELINE_BYTES);\n"
1678 " if (ensure_layers_allocated(m) != 0) return -1;\n"
1679 " if (m->weights_quantized) {\n"
1680 " /* K-quant weights require K dimension to be a multiple of 256. */\n"
1681 " if ((m->aligned_embed_dim %% 256) != 0) return -1;\n"
1682 " if ((aligned_intermediate_dim %% 256) != 0) return -1;\n"
1683 " int wo_quant = 0;\n"
1684 " if (m->layers) {\n"
1685 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1686 " if (ck_dtype_is_quantized(m->layers[layer].wo_dtype)) {\n"
1692 " if (wo_quant && (size_t)m->num_attention_heads * m->aligned_head_dim != m->aligned_embed_dim) return -1;\n"
1695 " if (m->num_cores <= 0) m->num_cores = 1;\n"
1696 " m->tokens_per_core = (m->context_window + m->num_cores - 1) / m->num_cores;\n"
1698 " size_t off = 0;\n");
1701 " m->layers_start_offset = off;\n"
1703 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1704 " TrulyOptimalLayer *L = &m->layers[layer];\n");
1712 " if (m->num_layers > 1) {\n"
1713 " m->layer_stride = m->layers[1].%s_offset - m->layers[0].%s_offset;\n"
1715 " m->layer_stride = 0;\n"
1717 stride_field, stride_field);
1721 " m->total_bytes = align_up_bytes(off, CACHELINE_BYTES);\n"
1722 " m->memory_base = (uint8_t *)ck_huge_alloc(m->total_bytes);\n"
1723 " if (!m->memory_base) return -1;\n"
1724 " if (m->rope_theta > 0.0f) {\n"
1725 " rope_precompute_cache(ptr_f32(m->memory_base, m->rope_cos_cache_offset),\n"
1726 " ptr_f32(m->memory_base, m->rope_sin_cache_offset),\n"
1727 " m->context_window,\n"
1729 " m->rope_theta);\n"
1735 "static void lm_head_forward(const float *hidden,\n"
1736 " const float *weights,\n"
1738 " int T, int V, int D, int aligned_D);\n"
1739 "static void lm_head_backward(const float *hidden,\n"
1740 " const float *weights,\n"
1741 " const float *d_logits,\n"
1742 " float *d_hidden,\n"
1743 " float *d_weights,\n"
1744 " int T, int V, int D, int aligned_D);\n"
1745 "static void softmax_cross_entropy(const float *logits,\n"
1746 " const int32_t *targets,\n"
1748 " float *d_logits,\n"
1749 " float *loss_out);\n\n");
1752 "static void run_model_forward(TransformerModel *m)\n"
1754 " uint8_t *base = m->memory_base;\n"
1755 " float *current = ptr_f32(base, m->embedded_input_offset);\n"
1756 " int aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
1757 " int T = m->active_tokens > 0 ? m->active_tokens : m->context_window;\n"
1758 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
1759 " TrulyOptimalLayer *L = &m->layers[layer];\n"
1760 " if (!m->weights_mixed && m->weight_dtype == CK_DT_Q4_K) {\n"
1761 " CKLayerForwardParamsQ4K p = {0};\n"
1763 " p.embed_dim = m->embed_dim;\n"
1764 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1765 " p.num_heads = m->num_attention_heads;\n"
1766 " p.num_kv_heads = m->num_kv_heads;\n"
1767 " p.head_dim = m->head_dim;\n"
1768 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1769 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1770 " p.intermediate_dim = m->intermediate_size;\n"
1771 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1772 " p.eps = m->rms_norm_eps;\n"
1773 " p.rope_pos_offset = 0;\n"
1774 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1775 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1776 " p.input = current;\n"
1777 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1778 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1779 " p.wq = cptr_void(base, L->wq_offset);\n"
1780 " p.bq = cptr_f32(base, L->bq_offset);\n"
1781 " p.wk = cptr_void(base, L->wk_offset);\n"
1782 " p.bk = cptr_f32(base, L->bk_offset);\n"
1783 " p.wv = cptr_void(base, L->wv_offset);\n"
1784 " p.bv = cptr_f32(base, L->bv_offset);\n"
1785 " p.wo = cptr_void(base, L->wo_offset);\n"
1786 " p.bo = cptr_f32(base, L->bo_offset);\n"
1787 " p.w1 = cptr_void(base, L->w1_offset);\n"
1788 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1789 " p.w2 = cptr_void(base, L->w2_offset);\n"
1790 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1791 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1792 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1793 " p.q = ptr_f32(base, L->q_offset);\n"
1794 " p.k = ptr_f32(base, L->k_offset);\n"
1795 " p.v = ptr_f32(base, L->v_offset);\n"
1796 " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1797 " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1798 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1799 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1800 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1801 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1802 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1803 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1804 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1805 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1806 " p.output = ptr_f32(base, L->output_offset);\n"
1807 " ck_layer_forward_rmsnorm_swiglu_q4_k(&p);\n"
1808 " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1809 " kv_cache_repack_head_major_inplace(p.k,\n"
1810 " p.num_kv_heads,\n"
1812 " m->kv_cache_capacity,\n"
1813 " p.aligned_head_dim);\n"
1814 " kv_cache_repack_head_major_inplace(p.v,\n"
1815 " p.num_kv_heads,\n"
1817 " m->kv_cache_capacity,\n"
1818 " p.aligned_head_dim);\n"
1820 " current = p.output;\n"
1821 " } else if (m->weights_quantized) {\n"
1822 " CKLayerForwardParamsQ4K p = {0};\n"
1824 " p.embed_dim = m->embed_dim;\n"
1825 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1826 " p.num_heads = m->num_attention_heads;\n"
1827 " p.num_kv_heads = m->num_kv_heads;\n"
1828 " p.head_dim = m->head_dim;\n"
1829 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1830 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1831 " p.intermediate_dim = m->intermediate_size;\n"
1832 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1833 " p.eps = m->rms_norm_eps;\n"
1834 " p.rope_pos_offset = 0;\n"
1835 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1836 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1837 " p.input = current;\n"
1838 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1839 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1840 " p.wq = cptr_void(base, L->wq_offset);\n"
1841 " p.bq = cptr_f32(base, L->bq_offset);\n"
1842 " p.wk = cptr_void(base, L->wk_offset);\n"
1843 " p.bk = cptr_f32(base, L->bk_offset);\n"
1844 " p.wv = cptr_void(base, L->wv_offset);\n"
1845 " p.bv = cptr_f32(base, L->bv_offset);\n"
1846 " p.wo = cptr_void(base, L->wo_offset);\n"
1847 " p.bo = cptr_f32(base, L->bo_offset);\n"
1848 " p.w1 = cptr_void(base, L->w1_offset);\n"
1849 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1850 " p.w2 = cptr_void(base, L->w2_offset);\n"
1851 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1852 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1853 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1854 " p.q = ptr_f32(base, L->q_offset);\n"
1855 " p.k = ptr_f32(base, L->k_offset);\n"
1856 " p.v = ptr_f32(base, L->v_offset);\n"
1857 " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1858 " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1859 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1860 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1861 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1862 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1863 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1864 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1865 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1866 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1867 " p.output = ptr_f32(base, L->output_offset);\n"
1868 " p.wq_dtype = L->wq_dtype;\n"
1869 " p.wk_dtype = L->wk_dtype;\n"
1870 " p.wv_dtype = L->wv_dtype;\n"
1871 " p.wo_dtype = L->wo_dtype;\n"
1872 " p.w1_dtype = L->w1_dtype;\n"
1873 " p.w2_dtype = L->w2_dtype;\n"
1874 " ck_layer_forward_rmsnorm_swiglu_quant(&p);\n"
1875 " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1876 " kv_cache_repack_head_major_inplace(p.k,\n"
1877 " p.num_kv_heads,\n"
1879 " m->kv_cache_capacity,\n"
1880 " p.aligned_head_dim);\n"
1881 " kv_cache_repack_head_major_inplace(p.v,\n"
1882 " p.num_kv_heads,\n"
1884 " m->kv_cache_capacity,\n"
1885 " p.aligned_head_dim);\n"
1887 " current = p.output;\n"
1889 " CKLayerForwardParams p = {0};\n"
1891 " p.embed_dim = m->embed_dim;\n"
1892 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
1893 " p.num_heads = m->num_attention_heads;\n"
1894 " p.num_kv_heads = m->num_kv_heads;\n"
1895 " p.head_dim = m->head_dim;\n"
1896 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
1897 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
1898 " p.intermediate_dim = m->intermediate_size;\n"
1899 " p.aligned_intermediate_dim = aligned_intermediate_dim;\n"
1900 " p.eps = m->rms_norm_eps;\n"
1901 " p.rope_pos_offset = 0;\n"
1902 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
1903 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
1904 " p.input = current;\n"
1905 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
1906 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
1907 " p.wq = cptr_f32(base, L->wq_offset);\n"
1908 " p.bq = cptr_f32(base, L->bq_offset);\n"
1909 " p.wk = cptr_f32(base, L->wk_offset);\n"
1910 " p.bk = cptr_f32(base, L->bk_offset);\n"
1911 " p.wv = cptr_f32(base, L->wv_offset);\n"
1912 " p.bv = cptr_f32(base, L->bv_offset);\n"
1913 " p.wo = cptr_f32(base, L->wo_offset);\n"
1914 " p.bo = cptr_f32(base, L->bo_offset);\n"
1915 " p.w1 = cptr_f32(base, L->w1_offset);\n"
1916 " p.b1 = cptr_f32(base, L->b1_offset);\n"
1917 " p.w2 = cptr_f32(base, L->w2_offset);\n"
1918 " p.b2 = cptr_f32(base, L->b2_offset);\n"
1919 " p.ln1_out = ptr_f32(base, L->ln1_out_offset);\n"
1920 " p.ln1_rstd = ptr_f32(base, L->ln1_rstd_offset);\n"
1921 " p.q = ptr_f32(base, L->q_offset);\n"
1922 " p.k = ptr_f32(base, L->k_offset);\n"
1923 " p.v = ptr_f32(base, L->v_offset);\n"
1924 " p.scores = L->scores_offset ? ptr_f32(base, L->scores_offset) : NULL;\n"
1925 " p.attn_out = ptr_f32(base, L->attn_out_offset);\n"
1926 " p.proj_tmp = ptr_f32(base, L->proj_tmp_offset);\n"
1927 " p.proj_scratch = ptr_f32(base, L->proj_scratch_offset);\n"
1928 " p.residual1 = ptr_f32(base, L->residual1_offset);\n"
1929 " p.ln2_out = ptr_f32(base, L->ln2_out_offset);\n"
1930 " p.ln2_rstd = ptr_f32(base, L->ln2_rstd_offset);\n"
1931 " p.fc1_out = ptr_f32(base, L->fc1_out_offset);\n"
1932 " p.swiglu_out = ptr_f32(base, L->swiglu_out_offset);\n"
1933 " p.mlp_out = ptr_f32(base, L->mlp_out_offset);\n"
1934 " p.output = ptr_f32(base, L->output_offset);\n"
1935 " ck_layer_forward_rmsnorm_swiglu(&p);\n"
1936 " if (m->kv_cache_enabled && !m->training_enabled) {\n"
1937 " kv_cache_repack_head_major_inplace(p.k,\n"
1938 " p.num_kv_heads,\n"
1940 " m->kv_cache_capacity,\n"
1941 " p.aligned_head_dim);\n"
1942 " kv_cache_repack_head_major_inplace(p.v,\n"
1943 " p.num_kv_heads,\n"
1945 " m->kv_cache_capacity,\n"
1946 " p.aligned_head_dim);\n"
1948 " current = p.output;\n"
1951 " float *final_out = ptr_f32(base, m->final_output_offset);\n"
1952 " rmsnorm_forward(current,\n"
1953 " cptr_f32(base, m->final_ln_weight_offset),\n"
1955 " ptr_f32(base, m->final_ln_rstd_offset),\n"
1958 " (int)m->aligned_embed_dim,\n"
1959 " m->rms_norm_eps);\n"
1960 " if (m->vocab_size > 0) {\n"
1961 " if (m->lm_head_weight_dtype == CK_DT_Q4_K) {\n"
1962 " gemm_nt_q4_k(final_out,\n"
1963 " cptr_void(base, m->lm_head_weight_offset),\n"
1965 " ptr_f32(base, m->logits_offset),\n"
1968 " (int)m->aligned_embed_dim);\n"
1969 " } else if (m->lm_head_weight_dtype == CK_DT_Q6_K) {\n"
1970 " gemm_nt_q6_k(final_out,\n"
1971 " cptr_void(base, m->lm_head_weight_offset),\n"
1973 " ptr_f32(base, m->logits_offset),\n"
1976 " (int)m->aligned_embed_dim);\n"
1978 " lm_head_forward(final_out,\n"
1979 " cptr_f32(base, m->lm_head_weight_offset),\n"
1980 " ptr_f32(base, m->logits_offset),\n"
1984 " (int)m->aligned_embed_dim);\n"
1993 "static int run_model_backward(TransformerModel *m,\n"
1994 " const int32_t *tokens,\n"
1995 " const int32_t *targets,\n"
1996 " float *loss_out)\n"
1998 " if (!m || !m->training_enabled) return 0;\n"
1999 " if (!tokens || !targets) return -1;\n"
2000 " if (m->num_layers <= 0) return -1;\n"
2001 " int T = m->active_tokens > 0 ? m->active_tokens : m->context_window;\n"
2002 " int V = m->vocab_size;\n"
2003 " int D = m->embed_dim;\n"
2004 " int aligned_D = (int)m->aligned_embed_dim;\n"
2005 " uint8_t *base = m->memory_base;\n"
2009 " float *final_out = ptr_f32(base, m->final_output_offset);\n"
2010 " float *logits = ptr_f32(base, m->logits_offset);\n"
2011 " float *d_logits = ptr_f32(base, m->d_logits_offset);\n"
2012 " float *d_final_out = ptr_f32(base, m->d_final_output_offset);\n"
2013 " float *d_final_in = ptr_f32(base, m->d_final_input_offset);\n"
2015 " float loss = 0.0f;\n"
2016 " softmax_cross_entropy(logits, targets, T, V, d_logits, &loss);\n"
2017 " if (loss_out) {\n"
2018 " *loss_out = loss;\n"
2020 " lm_head_backward(final_out,\n"
2021 " cptr_f32(base, m->lm_head_weight_offset),\n"
2024 " ptr_f32(base, m->d_token_emb_offset),\n"
2025 " T, V, D, aligned_D);\n"
2026 " rmsnorm_backward(d_final_out,\n"
2027 " ptr_f32(base, m->layers[m->num_layers - 1].output_offset),\n"
2028 " cptr_f32(base, m->final_ln_weight_offset),\n"
2029 " ptr_f32(base, m->final_ln_rstd_offset),\n"
2031 " ptr_f32(base, m->d_final_ln_weight_offset),\n"
2032 " T, D, aligned_D);\n"
2034 " for (int layer = m->num_layers - 1; layer >= 0; --layer) {\n"
2035 " TrulyOptimalLayer *L = &m->layers[layer];\n"
2036 " CKLayerBackwardParams p = {0};\n"
2038 " p.embed_dim = m->embed_dim;\n"
2039 " p.aligned_embed_dim = (int)m->aligned_embed_dim;\n"
2040 " p.num_heads = m->num_attention_heads;\n"
2041 " p.num_kv_heads = m->num_kv_heads;\n"
2042 " p.head_dim = m->head_dim;\n"
2043 " p.aligned_head_dim = (int)m->aligned_head_dim;\n"
2044 " p.aligned_context_window = (int)m->aligned_attn_context_window;\n"
2045 " p.intermediate_dim = m->intermediate_size;\n"
2046 " p.aligned_intermediate_dim = (int)align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2047 " p.eps = m->rms_norm_eps;\n"
2048 " p.rope_pos_offset = 0;\n"
2049 " p.rope_cos = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_cos_cache_offset) : NULL;\n"
2050 " p.rope_sin = (m->rope_theta > 0.0f) ? cptr_f32(base, m->rope_sin_cache_offset) : NULL;\n"
2051 " p.input = (layer == 0) ? ptr_f32(base, m->embedded_input_offset)\n"
2052 " : ptr_f32(base, m->layers[layer - 1].output_offset);\n"
2053 " p.ln1_gamma = cptr_f32(base, L->ln1_gamma_offset);\n"
2054 " p.ln2_gamma = cptr_f32(base, L->ln2_gamma_offset);\n"
2055 " p.ln1_out = cptr_f32(base, L->ln1_out_offset);\n"
2056 " p.ln1_rstd = cptr_f32(base, L->ln1_rstd_offset);\n"
2057 " p.ln2_out = cptr_f32(base, L->ln2_out_offset);\n"
2058 " p.ln2_rstd = cptr_f32(base, L->ln2_rstd_offset);\n"
2059 " p.wq = cptr_f32(base, L->wq_offset);\n"
2060 " p.bq = cptr_f32(base, L->bq_offset);\n"
2061 " p.wk = cptr_f32(base, L->wk_offset);\n"
2062 " p.bk = cptr_f32(base, L->bk_offset);\n"
2063 " p.wv = cptr_f32(base, L->wv_offset);\n"
2064 " p.bv = cptr_f32(base, L->bv_offset);\n"
2065 " p.wo = cptr_f32(base, L->wo_offset);\n"
2066 " p.bo = cptr_f32(base, L->bo_offset);\n"
2067 " p.w1 = cptr_f32(base, L->w1_offset);\n"
2068 " p.b1 = cptr_f32(base, L->b1_offset);\n"
2069 " p.w2 = cptr_f32(base, L->w2_offset);\n"
2070 " p.b2 = cptr_f32(base, L->b2_offset);\n"
2071 " p.q = cptr_f32(base, L->q_offset);\n"
2072 " p.k = cptr_f32(base, L->k_offset);\n"
2073 " p.v = cptr_f32(base, L->v_offset);\n"
2074 " p.scores = L->scores_offset ? cptr_f32(base, L->scores_offset) : NULL;\n"
2075 " p.attn_out = cptr_f32(base, L->attn_out_offset);\n"
2076 " p.residual1 = cptr_f32(base, L->residual1_offset);\n"
2077 " p.fc1_out = cptr_f32(base, L->fc1_out_offset);\n"
2078 " p.swiglu_out = cptr_f32(base, L->swiglu_out_offset);\n"
2079 " p.d_output = ptr_f32(base, L->d_output_offset);\n"
2080 " p.d_input = ptr_f32(base, L->d_input_offset);\n"
2081 " p.d_ln1_gamma = ptr_f32(base, L->d_ln1_gamma_offset);\n"
2082 " p.d_ln2_gamma = ptr_f32(base, L->d_ln2_gamma_offset);\n"
2083 " p.d_wq = ptr_f32(base, L->d_wq_offset);\n"
2084 " p.d_bq = ptr_f32(base, L->d_bq_offset);\n"
2085 " p.d_wk = ptr_f32(base, L->d_wk_offset);\n"
2086 " p.d_bk = ptr_f32(base, L->d_bk_offset);\n"
2087 " p.d_wv = ptr_f32(base, L->d_wv_offset);\n"
2088 " p.d_bv = ptr_f32(base, L->d_bv_offset);\n"
2089 " p.d_wo = ptr_f32(base, L->d_wo_offset);\n"
2090 " p.d_bo = ptr_f32(base, L->d_bo_offset);\n"
2091 " p.d_w1 = ptr_f32(base, L->d_w1_offset);\n"
2092 " p.d_b1 = ptr_f32(base, L->d_b1_offset);\n"
2093 " p.d_w2 = ptr_f32(base, L->d_w2_offset);\n"
2094 " p.d_b2 = ptr_f32(base, L->d_b2_offset);\n"
2095 " p.d_ln1_out = ptr_f32(base, L->d_ln1_out_offset);\n"
2096 " p.d_q = ptr_f32(base, L->d_q_offset);\n"
2097 " p.d_k = ptr_f32(base, L->d_k_offset);\n"
2098 " p.d_v = ptr_f32(base, L->d_v_offset);\n"
2099 " p.d_scores = ptr_f32(base, L->d_scores_offset);\n"
2100 " p.d_attn_out = ptr_f32(base, L->d_attn_out_offset);\n"
2101 " p.d_proj_tmp = ptr_f32(base, L->d_proj_tmp_offset);\n"
2102 " p.d_residual1 = ptr_f32(base, L->d_residual1_offset);\n"
2103 " p.d_ln2_out = ptr_f32(base, L->d_ln2_out_offset);\n"
2104 " p.d_fc1_out = ptr_f32(base, L->d_fc1_out_offset);\n"
2105 " p.d_swiglu_out = ptr_f32(base, L->d_swiglu_out_offset);\n"
2106 " p.d_mlp_out = ptr_f32(base, L->d_mlp_out_offset);\n"
2108 " const float *src = (layer == m->num_layers - 1)\n"
2110 " : ptr_f32(base, m->layers[layer + 1].d_input_offset);\n"
2111 " memcpy(p.d_output, src, (size_t)T * (size_t)aligned_D * sizeof(float));\n"
2113 " ck_layer_backward_rmsnorm_swiglu(&p);\n"
2117 " TrulyOptimalLayer *L0 = &m->layers[0];\n"
2118 " embedding_backward(tokens,\n"
2120 " ptr_f32(base, L0->d_input_offset),\n"
2121 " ptr_f32(base, m->d_token_emb_offset),\n"
2122 " ptr_f32(base, m->d_pos_emb_offset),\n"
2126 " m->context_window,\n"
2127 " m->rope_theta <= 0.0f);\n"
2130 " /* SGD update is now called separately via optimizer_step() */\n"
2135 "static int parse_int_arg(const char *s, int *out)\n"
2137 " if (!s || !out) return 0;\n"
2138 " char *end = NULL;\n"
2139 " long v = strtol(s, &end, 10);\n"
2140 " if (!end || *end != '\\0') return 0;\n"
2144 "static int parse_float_arg(const char *s, float *out)\n"
2146 " if (!s || !out) return 0;\n"
2147 " char *end = NULL;\n"
2148 " double v = strtod(s, &end);\n"
2149 " if (!end || *end != '\\0') return 0;\n"
2150 " *out = (float)v;\n"
2153 "static void print_usage(const char *prog)\n"
2155 " printf(\"Usage: %%s [options]\\n\", prog);\n"
2156 " printf(\" --dump Print layout summary (layer 0 only)\\n\");\n"
2157 " printf(\" --dump-all Print layout summary for all layers\\n\");\n"
2158 " printf(\" --no-forward Skip forward pass (layout + alloc only)\\n\");\n"
2159 " printf(\" --layers N Override num_layers\\n\");\n"
2160 " printf(\" --embed N Override embed_dim\\n\");\n"
2161 " printf(\" --intermediate N Override intermediate_size\\n\");\n"
2162 " printf(\" --heads N Override num_attention_heads\\n\");\n"
2163 " printf(\" --kv-heads N Override num_kv_heads\\n\");\n"
2164 " printf(\" --vocab N Override vocab_size\\n\");\n"
2165 " printf(\" --ctx N Override context_window\\n\");\n"
2166 " printf(\" --cores N Override num_cores\\n\");\n"
2167 " printf(\" --litmus Run LM head + CE + backward litmus\\n\");\n"
2168 " printf(\" --backward Run backward pass + SGD update (requires --tokens/--targets)\\n\");\n"
2169 " printf(\" --lr F SGD learning rate (default: 1e-3 when --backward)\\n\");\n"
2170 " printf(\" --steps N Training steps (default: 1)\\n\");\n"
2171 " printf(\" --log-steps Print loss per step during training\\n\");\n"
2172 " printf(\" --strict Enable strict parity mode (single-thread + double GEMM)\\n\");\n"
2173 " printf(\" --hidden PATH Load hidden activations [T x aligned_D] f32\\n\");\n"
2174 " printf(\" --weights PATH Load LM head weights [V x aligned_D] f32 (litmus)\\n\");\n"
2175 " printf(\" --targets PATH Load target tokens [T] int32\\n\");\n"
2176 " printf(\" --model-weights PATH Load full model weights (bump format)\\n\");\n"
2177 " printf(\" --tokens PATH Load token IDs [T] int32 and build embeddings\\n\");\n"
2178 " printf(\" --out-logits PATH Write logits [T x V] f32\\n\");\n"
2179 " printf(\" --out-dlogits PATH Write d_logits [T x V] f32\\n\");\n"
2180 " printf(\" --out-dhidden PATH Write d_hidden [T x aligned_D] f32\\n\");\n"
2181 " printf(\" --out-dweights PATH Write d_weights [V x aligned_D] f32\\n\");\n"
2182 " printf(\" --out-loss PATH Write loss (single f32)\\n\");\n"
2183 " printf(\" --out-weights PATH Write model weights (flat, no header)\\n\");\n"
2184 " printf(\" --help Show this help\\n\");\n"
2186 "static int read_floats(const char *path, float *dst, size_t count)\n"
2188 " if (!path || !dst) return -1;\n"
2189 " FILE *f = fopen(path, \"rb\");\n"
2191 " perror(\"fopen\");\n"
2194 " size_t got = fread(dst, sizeof(float), count, f);\n"
2196 " return got == count ? 0 : -1;\n"
2198 "static int read_ints(const char *path, int32_t *dst, size_t count)\n"
2200 " if (!path || !dst) return -1;\n"
2201 " FILE *f = fopen(path, \"rb\");\n"
2203 " perror(\"fopen\");\n"
2206 " size_t got = fread(dst, sizeof(int32_t), count, f);\n"
2208 " return got == count ? 0 : -1;\n"
2210 "static int read_floats_file(FILE *f, float *dst, size_t count)\n"
2212 " if (!f || !dst) return -1;\n"
2213 " size_t got = fread(dst, sizeof(float), count, f);\n"
2214 " return got == count ? 0 : -1;\n"
2216 "static int read_bytes_file(FILE *f, void *dst, size_t bytes)\n"
2218 " if (!f || !dst) return -1;\n"
2219 " size_t got = fread(dst, 1, bytes, f);\n"
2220 " return got == bytes ? 0 : -1;\n"
2222 "static int write_floats_file(FILE *f, const float *src, size_t count)\n"
2224 " if (!f || !src) return -1;\n"
2225 " size_t wrote = fwrite(src, sizeof(float), count, f);\n"
2226 " return wrote == count ? 0 : -1;\n"
2228 "static int write_bytes_file(FILE *f, const void *src, size_t bytes)\n"
2230 " if (!f || !src) return -1;\n"
2231 " size_t wrote = fwrite(src, 1, bytes, f);\n"
2232 " return wrote == bytes ? 0 : -1;\n"
2234 "static int read_weight_file(FILE *f, CKDataType dtype, void *dst, size_t n_elements)\n"
2236 " if (!f || !dst) return -1;\n"
2237 " if (dtype == CK_DT_FP32) {\n"
2238 " return read_floats_file(f, (float *)dst, n_elements);\n"
2240 " return read_bytes_file(f, dst, ck_dtype_row_bytes(dtype, n_elements));\n"
2242 "static int write_weight_file(FILE *f, CKDataType dtype, const void *src, size_t n_elements)\n"
2244 " if (!f || !src) return -1;\n"
2245 " if (dtype == CK_DT_FP32) {\n"
2246 " return write_floats_file(f, (const float *)src, n_elements);\n"
2248 " return write_bytes_file(f, src, ck_dtype_row_bytes(dtype, n_elements));\n"
2250 "static int skip_bump_header(FILE *f)\n"
2252 " if (!f) return -1;\n"
2254 " if (fread(magic, 1, 8, f) != 8) return -1;\n"
2255 " if (memcmp(magic, \"BUMPWGT3\", 8) == 0) {\n"
2256 " if (fseek(f, 128, SEEK_SET) != 0) return -1;\n"
2257 " uint32_t dtype_len = 0;\n"
2258 " if (fread(&dtype_len, sizeof(uint32_t), 1, f) != 1) return -1;\n"
2259 " if (fseek(f, (long)dtype_len, SEEK_CUR) != 0) return -1;\n"
2262 " if (memcmp(magic, \"BUMPWGT2\", 8) == 0) {\n"
2263 " if (fseek(f, 128, SEEK_SET) != 0) return -1;\n"
2266 " if (fseek(f, 0, SEEK_SET) != 0) return -1;\n"
2269 "static int load_model_weights(const char *path, TransformerModel *m)\n"
2271 " if (!path || !m || !m->memory_base) return -1;\n"
2272 " FILE *f = fopen(path, \"rb\");\n"
2274 " perror(\"fopen\");\n"
2277 " if (skip_bump_header(f) < 0) {\n"
2281 " uint8_t *base = m->memory_base;\n"
2282 " size_t aligned_intermediate = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2283 " size_t tok_elems = (size_t)m->vocab_size * m->aligned_embed_dim;\n"
2284 " if (read_weight_file(f, m->token_emb_dtype, ptr_u8(base, m->token_emb_offset), tok_elems) != 0) goto fail;\n"
2285 " if (read_floats_file(f, ptr_f32(base, m->pos_emb_offset),\n"
2286 " (size_t)m->context_window * m->aligned_embed_dim) != 0) goto fail;\n"
2288 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
2289 " TrulyOptimalLayer *L = &m->layers[layer];\n"
2290 " size_t head_w_stride = m->aligned_head_dim * m->aligned_embed_dim;\n"
2291 " size_t q_w = (size_t)m->num_attention_heads * head_w_stride;\n"
2292 " size_t kv_w = (size_t)m->num_kv_heads * head_w_stride;\n"
2293 " size_t q_b = (size_t)m->num_attention_heads * m->aligned_head_dim;\n"
2294 " size_t kv_b = (size_t)m->num_kv_heads * m->aligned_head_dim;\n"
2295 " size_t wo_w = (size_t)m->num_attention_heads * m->aligned_embed_dim * m->aligned_head_dim;\n"
2296 " size_t w1_w = (size_t)(2 * aligned_intermediate) * m->aligned_embed_dim;\n"
2297 " size_t w2_w = m->aligned_embed_dim * aligned_intermediate;\n"
2299 " if (read_floats_file(f, ptr_f32(base, L->ln1_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2300 " if (read_floats_file(f, ptr_f32(base, L->ln2_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2301 " if (read_weight_file(f, L->wq_dtype, ptr_u8(base, L->wq_offset), q_w) != 0) goto fail;\n"
2302 " if (read_floats_file(f, ptr_f32(base, L->bq_offset), q_b) != 0) goto fail;\n"
2303 " if (read_weight_file(f, L->wk_dtype, ptr_u8(base, L->wk_offset), kv_w) != 0) goto fail;\n"
2304 " if (read_floats_file(f, ptr_f32(base, L->bk_offset), kv_b) != 0) goto fail;\n"
2305 " if (read_weight_file(f, L->wv_dtype, ptr_u8(base, L->wv_offset), kv_w) != 0) goto fail;\n"
2306 " if (read_floats_file(f, ptr_f32(base, L->bv_offset), kv_b) != 0) goto fail;\n"
2307 " if (read_weight_file(f, L->wo_dtype, ptr_u8(base, L->wo_offset), wo_w) != 0) goto fail;\n"
2308 " if (read_floats_file(f, ptr_f32(base, L->bo_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2309 " if (read_weight_file(f, L->w1_dtype, ptr_u8(base, L->w1_offset), w1_w) != 0) goto fail;\n"
2310 " if (read_floats_file(f, ptr_f32(base, L->b1_offset), (size_t)(2 * aligned_intermediate)) != 0) goto fail;\n"
2311 " if (read_weight_file(f, L->w2_dtype, ptr_u8(base, L->w2_offset), w2_w) != 0) goto fail;\n"
2312 " if (read_floats_file(f, ptr_f32(base, L->b2_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2315 " if (read_floats_file(f, ptr_f32(base, m->final_ln_weight_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2316 " if (read_floats_file(f, ptr_f32(base, m->final_ln_bias_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2324 "static int save_model_weights(const char *path, const TransformerModel *m)\n"
2326 " if (!path || !m || !m->memory_base) return -1;\n"
2327 " FILE *f = fopen(path, \"wb\");\n"
2329 " perror(\"fopen\");\n"
2332 " uint8_t *base = m->memory_base;\n"
2333 " size_t aligned_intermediate = align_up_elems((size_t)m->intermediate_size, m->elem_bytes, CACHELINE_BYTES);\n"
2334 " size_t tok_elems = (size_t)m->vocab_size * m->aligned_embed_dim;\n"
2335 " if (write_weight_file(f, m->token_emb_dtype, cptr_void(base, m->token_emb_offset), tok_elems) != 0) goto fail;\n"
2336 " if (write_floats_file(f, ptr_f32(base, m->pos_emb_offset),\n"
2337 " (size_t)m->context_window * m->aligned_embed_dim) != 0) goto fail;\n"
2339 " for (int layer = 0; layer < m->num_layers; ++layer) {\n"
2340 " const TrulyOptimalLayer *L = &m->layers[layer];\n"
2341 " size_t head_w_stride = m->aligned_head_dim * m->aligned_embed_dim;\n"
2342 " size_t q_w = (size_t)m->num_attention_heads * head_w_stride;\n"
2343 " size_t kv_w = (size_t)m->num_kv_heads * head_w_stride;\n"
2344 " size_t q_b = (size_t)m->num_attention_heads * m->aligned_head_dim;\n"
2345 " size_t kv_b = (size_t)m->num_kv_heads * m->aligned_head_dim;\n"
2346 " size_t wo_w = (size_t)m->num_attention_heads * m->aligned_embed_dim * m->aligned_head_dim;\n"
2347 " size_t w1_w = (size_t)(2 * aligned_intermediate) * m->aligned_embed_dim;\n"
2348 " size_t w2_w = m->aligned_embed_dim * aligned_intermediate;\n"
2350 " if (write_floats_file(f, cptr_f32(base, L->ln1_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2351 " if (write_floats_file(f, cptr_f32(base, L->ln2_gamma_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2352 " if (write_weight_file(f, L->wq_dtype, cptr_void(base, L->wq_offset), q_w) != 0) goto fail;\n"
2353 " if (write_floats_file(f, cptr_f32(base, L->bq_offset), q_b) != 0) goto fail;\n"
2354 " if (write_weight_file(f, L->wk_dtype, cptr_void(base, L->wk_offset), kv_w) != 0) goto fail;\n"
2355 " if (write_floats_file(f, cptr_f32(base, L->bk_offset), kv_b) != 0) goto fail;\n"
2356 " if (write_weight_file(f, L->wv_dtype, cptr_void(base, L->wv_offset), kv_w) != 0) goto fail;\n"
2357 " if (write_floats_file(f, cptr_f32(base, L->bv_offset), kv_b) != 0) goto fail;\n"
2358 " if (write_weight_file(f, L->wo_dtype, cptr_void(base, L->wo_offset), wo_w) != 0) goto fail;\n"
2359 " if (write_floats_file(f, cptr_f32(base, L->bo_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2360 " if (write_weight_file(f, L->w1_dtype, cptr_void(base, L->w1_offset), w1_w) != 0) goto fail;\n"
2361 " if (write_floats_file(f, cptr_f32(base, L->b1_offset), (size_t)(2 * aligned_intermediate)) != 0) goto fail;\n"
2362 " if (write_weight_file(f, L->w2_dtype, cptr_void(base, L->w2_offset), w2_w) != 0) goto fail;\n"
2363 " if (write_floats_file(f, cptr_f32(base, L->b2_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2366 " if (write_floats_file(f, cptr_f32(base, m->final_ln_weight_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2367 " if (write_floats_file(f, cptr_f32(base, m->final_ln_bias_offset), m->aligned_embed_dim) != 0) goto fail;\n"
2375 "static void embed_tokens(const TransformerModel *m, const int32_t *tokens, int token_count)\n"
2377 " if (!m || !m->memory_base || !tokens) return;\n"
2378 " const uint8_t *base = m->memory_base;\n"
2379 " float *out = ptr_f32((uint8_t *)base, m->embedded_input_offset);\n"
2380 " const float *tok_f32 = cptr_f32(base, m->token_emb_offset);\n"
2381 " const uint8_t *tok_q = (const uint8_t *)cptr_void(base, m->token_emb_offset);\n"
2382 " const float *pos = cptr_f32(base, m->pos_emb_offset);\n"
2383 " int T = m->context_window;\n"
2384 " int D = m->embed_dim;\n"
2385 " int aligned_D = (int)m->aligned_embed_dim;\n"
2386 " for (int t = 0; t < T; ++t) {\n"
2387 " float *dst = out + (size_t)t * aligned_D;\n"
2388 " if (t < token_count) {\n"
2389 " int id = tokens[t];\n"
2390 " if (id < 0 || id >= m->vocab_size) id = 0;\n"
2391 " if (m->token_emb_dtype == CK_DT_Q4_K) {\n"
2392 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_D);\n"
2393 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2394 " dequant_q4_k_row(row, dst, (size_t)aligned_D);\n"
2395 " } else if (m->token_emb_dtype == CK_DT_Q6_K) {\n"
2396 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_D);\n"
2397 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2398 " dequant_q6_k_row(row, dst, (size_t)aligned_D);\n"
2400 " const float *src = tok_f32 + (size_t)id * aligned_D;\n"
2401 " memcpy(dst, src, (size_t)D * sizeof(float));\n"
2403 " if (aligned_D > D) {\n"
2404 " memset(dst + D, 0, (size_t)(aligned_D - D) * sizeof(float));\n"
2406 " if (m->rope_theta <= 0.0f) {\n"
2407 " const float *p = pos + (size_t)t * aligned_D;\n"
2408 " for (int d = 0; d < D; ++d) {\n"
2409 " dst[d] += p[d];\n"
2413 " memset(dst, 0, (size_t)aligned_D * sizeof(float));\n"
2417 "static void embed_token_at(const TransformerModel *m, int32_t token, int t)\n"
2419 " if (!m || !m->memory_base) return;\n"
2420 " if (t < 0 || t >= m->context_window) return;\n"
2421 " const uint8_t *base = m->memory_base;\n"
2422 " float *out = ptr_f32((uint8_t *)base, m->embedded_input_offset);\n"
2423 " const float *tok_f32 = cptr_f32(base, m->token_emb_offset);\n"
2424 " const uint8_t *tok_q = (const uint8_t *)cptr_void(base, m->token_emb_offset);\n"
2425 " const float *pos = cptr_f32(base, m->pos_emb_offset);\n"
2426 " int D = m->embed_dim;\n"
2427 " int aligned_D = (int)m->aligned_embed_dim;\n"
2428 " int id = (int)token;\n"
2429 " if (id < 0 || id >= m->vocab_size) id = 0;\n"
2430 " float *dst = out + (size_t)t * aligned_D;\n"
2431 " if (m->token_emb_dtype == CK_DT_Q4_K) {\n"
2432 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_D);\n"
2433 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2434 " dequant_q4_k_row(row, dst, (size_t)aligned_D);\n"
2435 " } else if (m->token_emb_dtype == CK_DT_Q6_K) {\n"
2436 " size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_D);\n"
2437 " const void *row = tok_q + (size_t)id * row_bytes;\n"
2438 " dequant_q6_k_row(row, dst, (size_t)aligned_D);\n"
2440 " const float *src = tok_f32 + (size_t)id * aligned_D;\n"
2441 " memcpy(dst, src, (size_t)D * sizeof(float));\n"
2443 " if (aligned_D > D) {\n"
2444 " memset(dst + D, 0, (size_t)(aligned_D - D) * sizeof(float));\n"
2446 " if (m->rope_theta <= 0.0f) {\n"
2447 " const float *p = pos + (size_t)t * aligned_D;\n"
2448 " for (int d = 0; d < D; ++d) {\n"
2449 " dst[d] += p[d];\n"
2453 "static int write_floats(const char *path, const float *src, size_t count)\n"
2455 " if (!path || !src) return -1;\n"
2456 " FILE *f = fopen(path, \"wb\");\n"
2458 " perror(\"fopen\");\n"
2461 " size_t wrote = fwrite(src, sizeof(float), count, f);\n"
2463 " return wrote == count ? 0 : -1;\n"
2465 "static int write_float_scalar(const char *path, float v)\n"
2467 " if (!path) return -1;\n"
2468 " FILE *f = fopen(path, \"wb\");\n"
2470 " perror(\"fopen\");\n"
2473 " size_t wrote = fwrite(&v, sizeof(float), 1, f);\n"
2475 " return wrote == 1 ? 0 : -1;\n"
2477 "static void lm_head_forward(const float *hidden,\n"
2478 " const float *weights,\n"
2480 " int T, int V, int D, int aligned_D)\n"
2482 " for (int t = 0; t < T; ++t) {\n"
2483 " const float *h = hidden + (size_t)t * aligned_D;\n"
2484 " float *out = logits + (size_t)t * V;\n"
2485 " for (int v = 0; v < V; ++v) {\n"
2486 " const float *w = weights + (size_t)v * aligned_D;\n"
2487 " float sum = 0.0f;\n"
2488 " for (int d = 0; d < D; ++d) {\n"
2489 " sum += h[d] * w[d];\n"
2495 "static void softmax_cross_entropy(const float *logits,\n"
2496 " const int32_t *targets,\n"
2498 " float *d_logits,\n"
2499 " float *loss_out)\n"
2501 " double total = 0.0;\n"
2502 " for (int t = 0; t < T; ++t) {\n"
2503 " const float *row = logits + (size_t)t * V;\n"
2504 " float *drow = d_logits + (size_t)t * V;\n"
2505 " int target = targets[t];\n"
2506 " float max_logit = row[0];\n"
2507 " for (int v = 1; v < V; ++v) {\n"
2508 " if (row[v] > max_logit) max_logit = row[v];\n"
2510 " double sum_exp = 0.0;\n"
2511 " for (int v = 0; v < V; ++v) {\n"
2512 " drow[v] = expf(row[v] - max_logit);\n"
2513 " sum_exp += drow[v];\n"
2515 " float inv_sum = 1.0f / (float)sum_exp;\n"
2516 " for (int v = 0; v < V; ++v) {\n"
2517 " drow[v] *= inv_sum;\n"
2519 " double logsum = (double)max_logit + log(sum_exp);\n"
2520 " total += logsum - (double)row[target];\n"
2521 " drow[target] -= 1.0f;\n"
2522 " float scale = 1.0f / (float)T;\n"
2523 " for (int v = 0; v < V; ++v) {\n"
2524 " drow[v] *= scale;\n"
2527 " if (loss_out) {\n"
2528 " *loss_out = (float)(total / (double)T);\n"
2531 "static void lm_head_backward(const float *hidden,\n"
2532 " const float *weights,\n"
2533 " const float *d_logits,\n"
2534 " float *d_hidden,\n"
2535 " float *d_weights,\n"
2536 " int T, int V, int D, int aligned_D)\n"
2538 " size_t dh_count = (size_t)T * aligned_D;\n"
2539 " size_t dw_count = (size_t)V * aligned_D;\n"
2540 " for (size_t i = 0; i < dh_count; ++i) d_hidden[i] = 0.0f;\n"
2541 " for (size_t i = 0; i < dw_count; ++i) d_weights[i] = 0.0f;\n"
2542 " for (int t = 0; t < T; ++t) {\n"
2543 " const float *dlog = d_logits + (size_t)t * V;\n"
2544 " for (int d = 0; d < D; ++d) {\n"
2545 " double sum = 0.0;\n"
2546 " for (int v = 0; v < V; ++v) {\n"
2547 " sum += (double)dlog[v] * (double)weights[(size_t)v * aligned_D + d];\n"
2549 " d_hidden[(size_t)t * aligned_D + d] = (float)sum;\n"
2552 " for (int v = 0; v < V; ++v) {\n"
2553 " float *dw = d_weights + (size_t)v * aligned_D;\n"
2554 " for (int d = 0; d < D; ++d) {\n"
2555 " double sum = 0.0;\n"
2556 " for (int t = 0; t < T; ++t) {\n"
2557 " sum += (double)d_logits[(size_t)t * V + v] * (double)hidden[(size_t)t * aligned_D + d];\n"
2559 " dw[d] = (float)sum;\n"
2565 "static void dump_layer_offsets(const TransformerModel *m, int layer)\n"
2567 " const TrulyOptimalLayer *L = &m->layers[layer];\n"
2568 " printf(\"Layer %%d offsets (bytes):\\n\", layer);\n"
2569 " printf(\" ln1_gamma=%%zu ln2_gamma=%%zu wq=%%zu wk=%%zu wv=%%zu wo=%%zu w1=%%zu w2=%%zu\\n\",\n"
2570 " L->ln1_gamma_offset, L->ln2_gamma_offset, L->wq_offset, L->wk_offset,\n"
2571 " L->wv_offset, L->wo_offset, L->w1_offset, L->w2_offset);\n"
2572 " printf(\" ln1_out=%%zu q=%%zu k=%%zu v=%%zu scores=%%zu attn_out=%%zu\\n\",\n"
2573 " L->ln1_out_offset, L->q_offset, L->k_offset, L->v_offset,\n"
2574 " L->scores_offset, L->attn_out_offset);\n"
2575 " printf(\" proj_tmp=%%zu residual1=%%zu ln2_out=%%zu fc1_out=%%zu swiglu_out=%%zu mlp_out=%%zu output=%%zu\\n\",\n"
2576 " L->proj_tmp_offset, L->residual1_offset, L->ln2_out_offset,\n"
2577 " L->fc1_out_offset, L->swiglu_out_offset, L->mlp_out_offset, L->output_offset);\n"
2579 "static void dump_layout(const TransformerModel *m, int dump_all)\n"
2581 " size_t bytes = m->total_bytes;\n"
2582 " printf(\"Model config:\\n\");\n"
2583 " printf(\" layers=%%d embed=%%d intermediate=%%d heads=%%d kv_heads=%%d\\n\",\n"
2584 " m->num_layers, m->embed_dim, m->intermediate_size, m->num_attention_heads, m->num_kv_heads);\n"
2585 " printf(\" head_dim=%%d vocab=%%d ctx=%%d cores=%%d\\n\",\n"
2586 " m->head_dim, m->vocab_size, m->context_window, m->num_cores);\n"
2587 " printf(\" eps=%%.6g rope_theta=%%.6g\\n\", m->rms_norm_eps, m->rope_theta);\n"
2588 " printf(\"Aligned dims (elements): embed=%%zu head=%%zu ctx=%%zu\\n\",\n"
2589 " m->aligned_embed_dim, m->aligned_head_dim, m->aligned_attn_context_window);\n"
2590 " printf(\"Memory: total_bytes=%%zu\\n\", bytes);\n"
2591 " printf(\"Global offsets (bytes): token=%%zu pos=%%zu embedded=%%zu layers_start=%%zu\\n\",\n"
2592 " m->token_emb_offset, m->pos_emb_offset, m->embedded_input_offset, m->layers_start_offset);\n"
2593 " printf(\"Final offsets (bytes): final_ln_w=%%zu final_ln_b=%%zu final_ln_mean=%%zu final_ln_rstd=%%zu\\n\",\n"
2594 " m->final_ln_weight_offset, m->final_ln_bias_offset,\n"
2595 " m->final_ln_mean_offset, m->final_ln_rstd_offset);\n"
2596 " printf(\"LM/logits offsets (bytes): lm_head=%%zu logits=%%zu\\n\",\n"
2597 " m->lm_head_weight_offset, m->logits_offset);\n"
2598 " if (m->num_layers > 0) {\n"
2599 " dump_layer_offsets(m, 0);\n"
2600 " if (dump_all) {\n"
2601 " for (int i = 1; i < m->num_layers; ++i) {\n"
2602 " dump_layer_offsets(m, i);\n"
2611 "int main(int argc, char **argv)\n"
2614 " int dump_all = 0;\n"
2615 " int no_forward = 0;\n"
2616 " int run_litmus = 0;\n"
2617 " int run_backward = 0;\n"
2618 " const char *litmus_hidden = NULL;\n"
2619 " const char *litmus_weights = NULL;\n"
2620 " const char *litmus_targets = NULL;\n"
2621 " const char *model_weights = NULL;\n"
2622 " const char *tokens_path = NULL;\n"
2623 " const char *out_logits = NULL;\n"
2624 " const char *out_dlogits = NULL;\n"
2625 " const char *out_dhidden = NULL;\n"
2626 " const char *out_dweights = NULL;\n"
2627 " const char *out_loss = NULL;\n"
2628 " const char *out_weights = NULL;\n"
2630 " int log_steps = 0;\n"
2631 " int strict = 0;\n"
2632 " int32_t *tokens = NULL;\n"
2633 " int32_t *targets = NULL;\n"
2634 " TransformerModel m = {0};\n"
2635 " memcpy(m.magic, \"BUMPWGT3\", 8);\n"
2637 " m.model_type = 0;\n"
2638 " m.num_layers = %d;\n"
2639 " m.embed_dim = %d;\n"
2640 " m.intermediate_size = %d;\n"
2641 " m.num_attention_heads = %d;\n"
2642 " m.num_kv_heads = %d;\n"
2643 " m.vocab_size = %d;\n"
2644 " m.context_window = %d;\n"
2645 " m.rms_norm_eps = %.9g;\n"
2646 " m.rope_theta = %.9g;\n"
2647 " m.num_cores = 1;\n"
2648 " m.task_type = TASK_LM;\n"
2649 " m.optimizer = OPTIMIZER_SGD;\n"
2650 " m.learning_rate = 0.0f;\n"
2651 " for (int i = 1; i < argc; ++i) {\n"
2652 " if (strcmp(argv[i], \"--dump\") == 0) {\n"
2656 " if (strcmp(argv[i], \"--dump-all\") == 0) {\n"
2661 " if (strcmp(argv[i], \"--no-forward\") == 0) {\n"
2662 " no_forward = 1;\n"
2665 " if (strcmp(argv[i], \"--strict\") == 0) {\n"
2669 " if (strcmp(argv[i], \"--litmus\") == 0) {\n"
2670 " run_litmus = 1;\n"
2673 " if (strcmp(argv[i], \"--backward\") == 0) {\n"
2674 " run_backward = 1;\n"
2677 " if (strcmp(argv[i], \"--lr\") == 0 && i + 1 < argc) {\n"
2678 " parse_float_arg(argv[++i], &m.learning_rate);\n"
2681 " if (strcmp(argv[i], \"--help\") == 0) {\n"
2682 " print_usage(argv[0]);\n"
2685 " if (strcmp(argv[i], \"--hidden\") == 0 && i + 1 < argc) {\n"
2686 " litmus_hidden = argv[++i];\n"
2689 " if (strcmp(argv[i], \"--weights\") == 0 && i + 1 < argc) {\n"
2690 " litmus_weights = argv[++i];\n"
2693 " if (strcmp(argv[i], \"--targets\") == 0 && i + 1 < argc) {\n"
2694 " litmus_targets = argv[++i];\n"
2697 " if (strcmp(argv[i], \"--model-weights\") == 0 && i + 1 < argc) {\n"
2698 " model_weights = argv[++i];\n"
2701 " if (strcmp(argv[i], \"--tokens\") == 0 && i + 1 < argc) {\n"
2702 " tokens_path = argv[++i];\n"
2705 " if (strcmp(argv[i], \"--out-logits\") == 0 && i + 1 < argc) {\n"
2706 " out_logits = argv[++i];\n"
2709 " if (strcmp(argv[i], \"--out-dlogits\") == 0 && i + 1 < argc) {\n"
2710 " out_dlogits = argv[++i];\n"
2713 " if (strcmp(argv[i], \"--out-dhidden\") == 0 && i + 1 < argc) {\n"
2714 " out_dhidden = argv[++i];\n"
2717 " if (strcmp(argv[i], \"--out-dweights\") == 0 && i + 1 < argc) {\n"
2718 " out_dweights = argv[++i];\n"
2721 " if (strcmp(argv[i], \"--out-loss\") == 0 && i + 1 < argc) {\n"
2722 " out_loss = argv[++i];\n"
2725 " if (strcmp(argv[i], \"--out-weights\") == 0 && i + 1 < argc) {\n"
2726 " out_weights = argv[++i];\n"
2729 " if (strcmp(argv[i], \"--steps\") == 0 && i + 1 < argc) {\n"
2730 " parse_int_arg(argv[++i], &steps);\n"
2733 " if (strcmp(argv[i], \"--log-steps\") == 0) {\n"
2737 " if (strcmp(argv[i], \"--layers\") == 0 && i + 1 < argc) {\n"
2738 " parse_int_arg(argv[++i], &m.num_layers);\n"
2741 " if (strcmp(argv[i], \"--embed\") == 0 && i + 1 < argc) {\n"
2742 " parse_int_arg(argv[++i], &m.embed_dim);\n"
2745 " if (strcmp(argv[i], \"--intermediate\") == 0 && i + 1 < argc) {\n"
2746 " parse_int_arg(argv[++i], &m.intermediate_size);\n"
2749 " if (strcmp(argv[i], \"--heads\") == 0 && i + 1 < argc) {\n"
2750 " parse_int_arg(argv[++i], &m.num_attention_heads);\n"
2753 " if (strcmp(argv[i], \"--kv-heads\") == 0 && i + 1 < argc) {\n"
2754 " parse_int_arg(argv[++i], &m.num_kv_heads);\n"
2757 " if (strcmp(argv[i], \"--vocab\") == 0 && i + 1 < argc) {\n"
2758 " parse_int_arg(argv[++i], &m.vocab_size);\n"
2761 " if (strcmp(argv[i], \"--ctx\") == 0 && i + 1 < argc) {\n"
2762 " parse_int_arg(argv[++i], &m.context_window);\n"
2765 " if (strcmp(argv[i], \"--cores\") == 0 && i + 1 < argc) {\n"
2766 " parse_int_arg(argv[++i], &m.num_cores);\n"
2769 " fprintf(stderr, \"Unknown or invalid arg: %%s\\n\", argv[i]);\n"
2770 " print_usage(argv[0]);\n"
2774 " ck_set_strict_parity(1);\n"
2776 " if (run_backward && m.learning_rate == 0.0f) {\n"
2777 " m.learning_rate = 1e-3f;\n"
2779 " m.training_enabled = run_backward;\n"
2780 " m.weight_dtype = CK_DT_FP32;\n"
2782 " const char *wd = getenv(\"CK_WEIGHT_DTYPE\");\n"
2784 " if (strcmp(wd, \"q4_k\") == 0 || strcmp(wd, \"q4_k_m\") == 0 ||\n"
2785 " strcmp(wd, \"Q4_K\") == 0 || strcmp(wd, \"Q4_K_M\") == 0) {\n"
2786 " m.weight_dtype = CK_DT_Q4_K;\n"
2787 " } else if (strcmp(wd, \"q6_k\") == 0 || strcmp(wd, \"q6_k_l\") == 0 ||\n"
2788 " strcmp(wd, \"Q6_K\") == 0 || strcmp(wd, \"Q6_K_L\") == 0) {\n"
2789 " m.weight_dtype = CK_DT_Q6_K;\n"
2793 " init_weight_dtypes_uniform(&m, m.weight_dtype);\n"
2794 " refresh_weight_flags(&m);\n"
2795 " if (model_weights) {\n"
2796 " int dtype_rc = load_weight_dtypes(model_weights, &m);\n"
2797 " if (dtype_rc < 0) {\n"
2798 " fprintf(stderr, \"failed to read weight dtype table\\n\");\n"
2802 " if (m.training_enabled && m.weights_quantized) {\n"
2803 " fprintf(stderr, \"Quantized weights are inference-only; disable training\\n\");\n"
2806 " if (layout_model(&m) != 0) {\n"
2807 " fprintf(stderr, \"layout_model failed\\n\");\n"
2810 " if (model_weights) {\n"
2811 " if (load_model_weights(model_weights, &m) != 0) {\n"
2812 " fprintf(stderr, \"failed to load model weights\\n\");\n"
2816 " if (tokens_path) {\n"
2817 " int T = m.context_window;\n"
2818 " tokens = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2820 " fprintf(stderr, \"failed to alloc tokens\\n\");\n"
2823 " if (read_ints(tokens_path, tokens, (size_t)T) != 0) {\n"
2824 " fprintf(stderr, \"failed to read tokens\\n\");\n"
2829 " if (!run_backward) {\n"
2830 " embed_tokens(&m, tokens, T);\n"
2835 " if (run_backward) {\n"
2836 " if (!litmus_targets) {\n"
2837 " fprintf(stderr, \"backward requires --targets\\n\");\n"
2840 " int T = m.context_window;\n"
2841 " targets = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2842 " if (!targets) {\n"
2843 " fprintf(stderr, \"failed to alloc targets\\n\");\n"
2846 " if (read_ints(litmus_targets, targets, (size_t)T) != 0) {\n"
2847 " fprintf(stderr, \"failed to read targets\\n\");\n"
2849 " targets = NULL;\n"
2854 " dump_layout(&m, dump_all);\n"
2856 " if (run_litmus) {\n"
2857 " if (!litmus_hidden || !litmus_weights || !litmus_targets) {\n"
2858 " fprintf(stderr, \"litmus requires --hidden, --weights, and --targets\\n\");\n"
2861 " int T = m.context_window;\n"
2862 " int V = m.vocab_size;\n"
2863 " int D = m.embed_dim;\n"
2864 " int aligned_D = (int)m.aligned_embed_dim;\n"
2865 " float *hidden = ptr_f32(m.memory_base, m.final_output_offset);\n"
2866 " float *weights = ptr_f32(m.memory_base, m.lm_head_weight_offset);\n"
2867 " float *logits = ptr_f32(m.memory_base, m.logits_offset);\n"
2868 " if (read_floats(litmus_hidden, hidden, (size_t)T * aligned_D) != 0) {\n"
2869 " fprintf(stderr, \"failed to read hidden\\n\");\n"
2872 " if (read_floats(litmus_weights, weights, (size_t)V * aligned_D) != 0) {\n"
2873 " fprintf(stderr, \"failed to read weights\\n\");\n"
2876 " int32_t *targets = (int32_t *)malloc((size_t)T * sizeof(int32_t));\n"
2877 " if (!targets) {\n"
2878 " fprintf(stderr, \"failed to alloc targets\\n\");\n"
2881 " if (read_ints(litmus_targets, targets, (size_t)T) != 0) {\n"
2882 " fprintf(stderr, \"failed to read targets\\n\");\n"
2886 " float *d_logits = (float *)calloc((size_t)T * V, sizeof(float));\n"
2887 " float *d_hidden = (float *)calloc((size_t)T * aligned_D, sizeof(float));\n"
2888 " float *d_weights = (float *)calloc((size_t)V * aligned_D, sizeof(float));\n"
2889 " if (!d_logits || !d_hidden || !d_weights) {\n"
2890 " fprintf(stderr, \"failed to alloc grads\\n\");\n"
2892 " free(d_logits);\n"
2893 " free(d_hidden);\n"
2894 " free(d_weights);\n"
2897 " lm_head_forward(hidden, weights, logits, T, V, D, aligned_D);\n"
2898 " float loss = 0.0f;\n"
2899 " softmax_cross_entropy(logits, targets, T, V, d_logits, &loss);\n"
2900 " lm_head_backward(hidden, weights, d_logits, d_hidden, d_weights, T, V, D, aligned_D);\n"
2901 " if (out_logits) write_floats(out_logits, logits, (size_t)T * V);\n"
2902 " if (out_dlogits) write_floats(out_dlogits, d_logits, (size_t)T * V);\n"
2903 " if (out_dhidden) write_floats(out_dhidden, d_hidden, (size_t)T * aligned_D);\n"
2904 " if (out_dweights) write_floats(out_dweights, d_weights, (size_t)V * aligned_D);\n"
2905 " if (out_loss) write_float_scalar(out_loss, loss);\n"
2906 " if (!out_loss) printf(\"loss=%%.6f\\n\", loss);\n"
2908 " free(d_logits);\n"
2909 " free(d_hidden);\n"
2910 " free(d_weights);\n"
2911 " ck_huge_free(m.memory_base, m.total_bytes);\n"
2912 " free(m.layers);\n"
2915 " // TODO: load weights into m.memory_base using the offsets above.\n"
2916 " // TODO: write token/pos embeddings into embedded_input_offset.\n"
2917 " if (!run_backward) {\n"
2918 " if (!no_forward) {\n"
2919 " run_model_forward(&m);\n"
2922 " if (!tokens || !targets) {\n"
2923 " fprintf(stderr, \"backward requires --tokens and --targets\\n\");\n"
2926 " if (steps < 1) steps = 1;\n"
2927 " float loss = 0.0f;\n"
2928 " for (int step = 0; step < steps; ++step) {\n"
2929 " embed_tokens(&m, tokens, m.context_window);\n"
2930 " run_model_forward(&m);\n"
2931 " if (run_model_backward(&m, tokens, targets, &loss) != 0) {\n"
2932 " fprintf(stderr, \"backward failed\\n\");\n"
2935 " if (log_steps) {\n"
2936 " printf(\"step %%d loss=%%.6f\\n\", step, loss);\n"
2939 " if (out_loss) {\n"
2940 " write_float_scalar(out_loss, loss);\n"
2943 " if (out_logits) {\n"
2944 " write_floats(out_logits, ptr_f32(m.memory_base, m.logits_offset),\n"
2945 " (size_t)m.context_window * (size_t)m.vocab_size);\n"
2947 " if (out_weights) {\n"
2948 " if (save_model_weights(out_weights, &m) != 0) {\n"
2949 " fprintf(stderr, \"failed to save model weights\\n\");\n"
2953 " ck_huge_free(m.memory_base, m.total_bytes);\n"
2954 " free(m.layers);\n"
static int emit_runtime_preamble(FILE *out)
static int emit_kernel_manifest(const CKIRGraph *forward, const char *runtime_path)
static const char * ck_first_layer_buffer_name(void)
static void emit_library_api(FILE *out, const CKIRGraph *forward)
static void emit_layer_allocations(FILE *out)
static void emit_sgd_update(FILE *out)
static void emit_model_struct(FILE *out)
static void emit_global_allocations(FILE *out)
static void emit_zero_grad(FILE *out)
static void emit_layer_offsets_struct(FILE *out)
static void emit_global_aliases_to_layer(FILE *out)
int ck_ir_validate_supported(const CKIRGraph *graph)
References CK_EMIT_STANDALONE, ck_first_layer_buffer_name(), ck_ir_validate_supported(), CKIRGraph::config, CKModelConfig::context_window, emit_global_aliases_to_layer(), emit_global_allocations(), emit_kernel_manifest(), emit_layer_allocations(), emit_layer_offsets_struct(), emit_library_api(), emit_model_struct(), emit_runtime_preamble(), emit_sgd_update(), emit_zero_grad(), CKModelConfig::hidden_size, CKModelConfig::intermediate_size, CKModelConfig::num_heads, CKModelConfig::num_kv_heads, CKModelConfig::num_layers, CKModelConfig::rms_norm_eps, CKModelConfig::rope_theta, and CKModelConfig::vocab_size.
Referenced by main().