← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_ir_v6.c
Go to the documentation of this file.
#include "ckernel_ir.h"

#include <errno.h>
#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
7 
/**
 * Locate the first occurrence of `key` in the NUL-terminated text `json`
 * and parse the decimal integer that follows the next ':'.
 *
 * This is a substring scan, not a JSON parser: the first textual match of
 * `key` wins, wherever it appears.
 *
 * @param json       NUL-terminated text to search.
 * @param key        Substring to look for (callers pass it with its quotes).
 * @param out_value  Receives the parsed value; untouched on failure.
 * @return 0 on success; -1 if an argument is NULL, the key or ':' is not
 *         found, no integer follows, or the value does not fit in an int.
 */
static int parse_int_field(const char *json,
                           const char *key,
                           int *out_value)
{
    if (!json || !key || !out_value) {
        return -1;
    }

    const char *p = strstr(json, key);
    if (!p) {
        return -1;
    }

    // Move to after the key, look for the first digit or minus sign.
    p = strchr(p, ':');
    if (!p) {
        return -1;
    }
    while (*p == ':' || *p == ' ' || *p == '\t') {
        ++p;
    }

    // strtol (unlike sscanf "%d") has defined behavior on overflow, so an
    // oversized literal is reported as a failure instead of garbage.
    errno = 0;
    char *stop = NULL;
    const long parsed = strtol(p, &stop, 10);
    if (stop == p || errno == ERANGE || parsed < INT_MIN || parsed > INT_MAX) {
        return -1;
    }

    *out_value = (int)parsed;
    return 0;
}
34 
/**
 * Search the byte range [json, json + len) for `key` and parse the decimal
 * integer that follows the next ':' inside that range.
 *
 * Only the first occurrence of `key` is considered. The conversion never
 * reads bytes at or beyond json + len, so the range does not need to be
 * NUL-terminated and a number is cut off at the range boundary.
 *
 * @param json       Start of the text range.
 * @param len        Number of bytes in the range.
 * @param key        NUL-terminated substring to look for.
 * @param out_value  Receives the parsed value; untouched on failure.
 * @return 0 on success; -1 on NULL arguments, if the key or a following
 *         ':' is not in range, if no integer follows, or if the value does
 *         not fit in an int.
 */
static int parse_int_field_in_range(const char *json,
                                    size_t len,
                                    const char *key,
                                    int *out_value)
{
    if (!json || !key || !out_value) {
        return -1;
    }

    const size_t key_len = strlen(key);
    if (key_len > len) {
        return -1;
    }

    const char *const end = json + len;
    // Index-based scan: avoids forming pointers past the end of the range.
    for (size_t i = 0; i + key_len <= len; ++i) {
        const char *p = json + i;
        if (memcmp(p, key, key_len) != 0) {
            continue;
        }

        const char *colon = memchr(p + key_len, ':', (size_t)(end - (p + key_len)));
        if (!colon) {
            return -1;
        }
        const char *v = colon + 1;
        while (v < end && (*v == ' ' || *v == '\t' || *v == '\n' || *v == '\r')) {
            ++v;
        }

        // Convert from a bounded, NUL-terminated copy so strtol cannot run
        // past `end`. 31 characters is more than any value that fits a long.
        char digits[32];
        const size_t avail = (size_t)(end - v);
        const size_t n = avail < sizeof digits - 1 ? avail : sizeof digits - 1;
        memcpy(digits, v, n);
        digits[n] = '\0';

        errno = 0;
        char *stop = NULL;
        const long parsed = strtol(digits, &stop, 10);
        if (stop == digits || errno == ERANGE || parsed < INT_MIN || parsed > INT_MAX) {
            return -1;
        }
        *out_value = (int)parsed;
        return 0;
    }

    return -1;
}
70 
/**
 * Try each key of a NULL-terminated key list, in order, against the range
 * [json, json + len) and stop at the first one that yields an integer.
 *
 * @param json       Start of the text range.
 * @param len        Number of bytes in the range.
 * @param keys       NULL-terminated array of candidate key strings.
 * @param out_value  Receives the value found for the first matching key.
 * @return 0 if some key parsed successfully, -1 if `keys` is NULL or no
 *         key produced a value.
 */
static int parse_int_field_any(const char *json,
                               size_t len,
                               const char *const *keys,
                               int *out_value)
{
    if (keys == NULL) {
        return -1;
    }

    const char *const *candidate = keys;
    while (*candidate != NULL) {
        const int status = parse_int_field_in_range(json, len, *candidate, out_value);
        if (status == 0) {
            return 0;
        }
        ++candidate;
    }
    return -1;
}
86 
/**
 * Search the byte range [json, json + len) for `key` and parse the
 * floating-point number that follows the next ':' inside that range.
 *
 * Only the first occurrence of `key` is considered. The conversion never
 * reads bytes at or beyond json + len, so the range does not need to be
 * NUL-terminated and a number is cut off at the range boundary.
 *
 * @param json       Start of the text range.
 * @param len        Number of bytes in the range.
 * @param key        NUL-terminated substring to look for.
 * @param out_value  Receives the parsed value; untouched on failure.
 * @return 0 on success; -1 on NULL arguments, if the key or a following
 *         ':' is not in range, or if no number follows.
 */
static int parse_float_field_in_range(const char *json,
                                      size_t len,
                                      const char *key,
                                      float *out_value)
{
    if (!json || !key || !out_value) {
        return -1;
    }

    const size_t key_len = strlen(key);
    if (key_len > len) {
        return -1;
    }

    const char *const end = json + len;
    // Index-based scan: avoids forming pointers past the end of the range.
    for (size_t i = 0; i + key_len <= len; ++i) {
        const char *p = json + i;
        if (memcmp(p, key, key_len) != 0) {
            continue;
        }

        const char *colon = memchr(p + key_len, ':', (size_t)(end - (p + key_len)));
        if (!colon) {
            return -1;
        }
        const char *v = colon + 1;
        while (v < end && (*v == ' ' || *v == '\t' || *v == '\n' || *v == '\r')) {
            ++v;
        }

        // Convert from a bounded, NUL-terminated copy so strtof cannot run
        // past `end`.
        char text[64];
        const size_t avail = (size_t)(end - v);
        const size_t n = avail < sizeof text - 1 ? avail : sizeof text - 1;
        memcpy(text, v, n);
        text[n] = '\0';

        char *stop = NULL;
        const float parsed = strtof(text, &stop);
        if (stop == text) {
            return -1;
        }
        *out_value = parsed;
        return 0;
    }

    return -1;
}
122 
/**
 * Try each key of a NULL-terminated key list, in order, against the range
 * [json, json + len) and stop at the first one that yields a float.
 *
 * @param json       Start of the text range.
 * @param len        Number of bytes in the range.
 * @param keys       NULL-terminated array of candidate key strings.
 * @param out_value  Receives the value found for the first matching key.
 * @return 0 if some key parsed successfully, -1 if `keys` is NULL or no
 *         key produced a value.
 */
static int parse_float_field_any(const char *json,
                                 size_t len,
                                 const char *const *keys,
                                 float *out_value)
{
    if (keys == NULL) {
        return -1;
    }

    const char *const *candidate = keys;
    while (*candidate != NULL) {
        const int status = parse_float_field_in_range(json, len, *candidate, out_value);
        if (status == 0) {
            return 0;
        }
        ++candidate;
    }
    return -1;
}
138 
/**
 * Find the JSON object that is the value of `key` in the NUL-terminated
 * text `json` and report its extent, braces included.
 *
 * An occurrence of `key` only counts when it is followed (ignoring
 * whitespace) by ':' and then '{'. Occurrences whose value is not an
 * object (e.g. `"key": null`) or that are themselves string values are
 * skipped and the search continues. Braces inside string literals are
 * ignored while matching the closing brace.
 *
 * @param json       NUL-terminated text to search.
 * @param key        Non-empty substring to look for (callers pass it with
 *                   its quotes).
 * @param out_start  Receives a pointer to the opening '{'.
 * @param out_len    Receives the byte count from '{' through the matching
 *                   '}' inclusive.
 * @return 0 on success (outputs written); -1 on NULL/empty arguments, if no
 *         occurrence of the key has an object value, or if the object is
 *         not closed before the end of the text. Outputs are untouched on
 *         failure.
 */
static int find_object_range(const char *json,
                             const char *key,
                             const char **out_start,
                             size_t *out_len)
{
    if (!json || !key || !out_start || !out_len) {
        return -1;
    }

    const size_t key_len = strlen(key);
    if (key_len == 0) {
        return -1;
    }

    for (const char *hit = strstr(json, key); hit; hit = strstr(hit + 1, key)) {
        // Require `key <ws> : <ws> {` so a non-object value cannot make us
        // latch onto some unrelated object further down the document.
        const char *cur = hit + key_len;
        while (*cur == ' ' || *cur == '\t' || *cur == '\n' || *cur == '\r') {
            ++cur;
        }
        if (*cur != ':') {
            continue;
        }
        ++cur;
        while (*cur == ' ' || *cur == '\t' || *cur == '\n' || *cur == '\r') {
            ++cur;
        }
        if (*cur != '{') {
            continue;
        }

        const char *start = cur;
        bool in_string = false;
        bool escape = false;
        int depth = 0;

        for (; *cur; ++cur) {
            const char c = *cur;
            if (in_string) {
                if (escape) {
                    escape = false;      // character after a backslash is literal
                } else if (c == '\\') {
                    escape = true;
                } else if (c == '"') {
                    in_string = false;
                }
                continue;
            }

            if (c == '"') {
                in_string = true;
            } else if (c == '{') {
                depth++;
            } else if (c == '}') {
                depth--;
                if (depth == 0) {
                    *out_start = start;
                    *out_len = (size_t)(cur - start + 1);
                    return 0;
                }
            }
        }

        // Ran off the end of the text without closing the object.
        return -1;
    }

    return -1;
}
208 
209 int ck_model_config_from_hf_json(const char *path, CKModelConfig *cfg)
210 {
211  if (!path || !cfg) {
212  return -1;
213  }
214 
215  FILE *f = fopen(path, "rb");
216  if (!f) {
217  perror("ck_model_config_from_hf_json: fopen");
218  return -1;
219  }
220 
221  if (fseek(f, 0, SEEK_END) != 0) {
222  fclose(f);
223  return -1;
224  }
225  long len = ftell(f);
226  if (len < 0) {
227  fclose(f);
228  return -1;
229  }
230  if (fseek(f, 0, SEEK_SET) != 0) {
231  fclose(f);
232  return -1;
233  }
234 
235  char *buf = (char *)malloc((size_t)len + 1);
236  if (!buf) {
237  fclose(f);
238  return -1;
239  }
240  size_t nread = fread(buf, 1, (size_t)len, f);
241  fclose(f);
242  buf[nread] = '\0';
243 
244  CKModelConfig tmp;
245  memset(&tmp, 0, sizeof(tmp));
246  tmp.rms_norm_eps = 1e-5f;
247  tmp.rope_theta = 0.0f;
248 
249  const char *scope = buf;
250  size_t scope_len = nread;
251  if (find_object_range(buf, "\"text_config\"", &scope, &scope_len) != 0) {
252  scope = buf;
253  scope_len = nread;
254  }
255 
256  const char *num_layers_keys[] = { "\"num_hidden_layers\"", "\"n_layer\"", NULL };
257  const char *hidden_size_keys[] = { "\"hidden_size\"", "\"n_embd\"", "\"d_model\"", NULL };
258  const char *intermediate_keys[] = { "\"intermediate_size\"", "\"n_inner\"", "\"ffn_dim\"", "\"mlp_dim\"", NULL };
259  const char *num_heads_keys[] = { "\"num_attention_heads\"", "\"n_head\"", "\"num_heads\"", NULL };
260  const char *num_kv_heads_keys[] = { "\"num_key_value_heads\"", "\"num_kv_heads\"", NULL };
261  const char *vocab_keys[] = { "\"vocab_size\"", "\"n_vocab\"", NULL };
262  const char *context_keys[] = { "\"max_position_embeddings\"", "\"n_positions\"", "\"context_length\"", "\"seq_len\"", NULL };
263  const char *rms_eps_keys[] = { "\"rms_norm_eps\"", "\"layer_norm_eps\"", NULL };
264  const char *rope_theta_keys[] = { "\"rope_theta\"", "\"rope_base\"", NULL };
265 
266  if (parse_int_field_any(scope, scope_len, num_layers_keys, &tmp.num_layers) != 0) {
267  fprintf(stderr, "Warning: num_hidden_layers not found in %s\n", path);
268  }
269  if (parse_int_field_any(scope, scope_len, hidden_size_keys, &tmp.hidden_size) != 0) {
270  fprintf(stderr, "Warning: hidden_size not found in %s\n", path);
271  }
272  if (parse_int_field_any(scope, scope_len, intermediate_keys, &tmp.intermediate_size) != 0) {
273  fprintf(stderr, "Warning: intermediate_size not found in %s\n", path);
274  }
275  if (parse_int_field_any(scope, scope_len, num_heads_keys, &tmp.num_heads) != 0) {
276  fprintf(stderr, "Warning: num_attention_heads not found in %s\n", path);
277  }
278 
279  // num_key_value_heads is optional; default to num_heads if missing.
280  if (parse_int_field_any(scope, scope_len, num_kv_heads_keys, &tmp.num_kv_heads) != 0) {
281  tmp.num_kv_heads = tmp.num_heads;
282  }
283 
284  // Optional: vocab_size
285  if (parse_int_field_any(scope, scope_len, vocab_keys, &tmp.vocab_size) != 0) {
286  tmp.vocab_size = 0;
287  }
288 
289  // Optional: context length (try max_position_embeddings, then n_positions)
290  if (parse_int_field_any(scope, scope_len, context_keys, &tmp.context_window) != 0) {
291  tmp.context_window = 0;
292  }
293  if (parse_float_field_any(scope, scope_len, rms_eps_keys, &tmp.rms_norm_eps) != 0) {
294  tmp.rms_norm_eps = 1e-5f;
295  }
296  if (parse_float_field_any(scope, scope_len, rope_theta_keys, &tmp.rope_theta) != 0) {
297  tmp.rope_theta = 0.0f;
298  }
299 
300  free(buf);
301  *cfg = tmp;
302  return 0;
303 }
304 
306 {
307  if (!cfg || !graph) {
308  return -1;
309  }
310 
311  const int L = cfg->num_layers > 0 ? cfg->num_layers : 1;
312  const int nodes_per_layer = 10; // LN1, QKV, ATT, ADD, LN2, W1, SPLIT, SWIGLU, W2, ADD
313  const int total_nodes = L * nodes_per_layer;
314 
315  CKIRNode *nodes = (CKIRNode *)calloc((size_t)total_nodes, sizeof(CKIRNode));
316  if (!nodes) {
317  return -1;
318  }
319 
320  // Sentinel producer for block inputs (e.g., h_in, past_kv) per layer.
321  const uint16_t INPUT_NODE_SENTINEL = 0xFFFF;
322 
323  for (int layer = 0; layer < L; ++layer) {
324  const int base = layer * nodes_per_layer;
325 
326  // Node 0: LN1 = RMSNorm(h_in)
327  nodes[base + 0].id.layer = (uint16_t)layer;
328  nodes[base + 0].id.node = 0;
329  nodes[base + 0].op = CK_OP_RMSNORM;
330  nodes[base + 0].inputs[0].producer.layer = (uint16_t)layer;
331  nodes[base + 0].inputs[0].producer.node = INPUT_NODE_SENTINEL; // h_in
332  nodes[base + 0].inputs[0].out_index = 0;
333  nodes[base + 0].n_inputs = 1;
334  nodes[base + 0].n_outputs = 1;
335 
336  // Node 1: QKV Linear
337  nodes[base + 1].id.layer = (uint16_t)layer;
338  nodes[base + 1].id.node = 1;
339  nodes[base + 1].op = CK_OP_LINEAR_QKV;
340  nodes[base + 1].inputs[0].producer.layer = (uint16_t)layer;
341  nodes[base + 1].inputs[0].producer.node = 0;
342  nodes[base + 1].inputs[0].out_index = 0;
343  nodes[base + 1].n_inputs = 1;
344  nodes[base + 1].n_outputs = 1;
345 
346  // Node 2: Attention
347  nodes[base + 2].id.layer = (uint16_t)layer;
348  nodes[base + 2].id.node = 2;
349  nodes[base + 2].op = CK_OP_ATTENTION;
350  nodes[base + 2].inputs[0].producer.layer = (uint16_t)layer;
351  nodes[base + 2].inputs[0].producer.node = 1; // qkv
352  nodes[base + 2].inputs[0].out_index = 0;
353  nodes[base + 2].n_inputs = 1; // past_kv omitted for now
354  nodes[base + 2].n_outputs = 1;
355 
356  // Node 3: Add residual (h_in + attn_out)
357  nodes[base + 3].id.layer = (uint16_t)layer;
358  nodes[base + 3].id.node = 3;
359  nodes[base + 3].op = CK_OP_ADD;
360  nodes[base + 3].inputs[0].producer.layer = (uint16_t)layer;
361  nodes[base + 3].inputs[0].producer.node = INPUT_NODE_SENTINEL; // h_in
362  nodes[base + 3].inputs[0].out_index = 0;
363  nodes[base + 3].inputs[1].producer.layer = (uint16_t)layer;
364  nodes[base + 3].inputs[1].producer.node = 2; // attn_out
365  nodes[base + 3].inputs[1].out_index = 0;
366  nodes[base + 3].n_inputs = 2;
367  nodes[base + 3].n_outputs = 1;
368 
369  // Node 4: LN2 = RMSNorm(residual)
370  nodes[base + 4].id.layer = (uint16_t)layer;
371  nodes[base + 4].id.node = 4;
372  nodes[base + 4].op = CK_OP_RMSNORM;
373  nodes[base + 4].inputs[0].producer.layer = (uint16_t)layer;
374  nodes[base + 4].inputs[0].producer.node = 3;
375  nodes[base + 4].inputs[0].out_index = 0;
376  nodes[base + 4].n_inputs = 1;
377  nodes[base + 4].n_outputs = 1;
378 
379  // Node 5: W1 Linear
380  nodes[base + 5].id.layer = (uint16_t)layer;
381  nodes[base + 5].id.node = 5;
382  nodes[base + 5].op = CK_OP_LINEAR;
383  nodes[base + 5].inputs[0].producer.layer = (uint16_t)layer;
384  nodes[base + 5].inputs[0].producer.node = 4;
385  nodes[base + 5].inputs[0].out_index = 0;
386  nodes[base + 5].n_inputs = 1;
387  nodes[base + 5].n_outputs = 1;
388 
389  // Node 6: Split into (a, b)
390  nodes[base + 6].id.layer = (uint16_t)layer;
391  nodes[base + 6].id.node = 6;
392  nodes[base + 6].op = CK_OP_SPLIT;
393  nodes[base + 6].inputs[0].producer.layer = (uint16_t)layer;
394  nodes[base + 6].inputs[0].producer.node = 5;
395  nodes[base + 6].inputs[0].out_index = 0;
396  nodes[base + 6].n_inputs = 1;
397  nodes[base + 6].n_outputs = 2;
398 
399  // Node 7: SwiGLU(a,b)
400  nodes[base + 7].id.layer = (uint16_t)layer;
401  nodes[base + 7].id.node = 7;
402  nodes[base + 7].op = CK_OP_SWIGLU;
403  nodes[base + 7].inputs[0].producer.layer = (uint16_t)layer;
404  nodes[base + 7].inputs[0].producer.node = 6;
405  nodes[base + 7].inputs[0].out_index = 0; // a
406  nodes[base + 7].inputs[1].producer.layer = (uint16_t)layer;
407  nodes[base + 7].inputs[1].producer.node = 6;
408  nodes[base + 7].inputs[1].out_index = 1; // b
409  nodes[base + 7].n_inputs = 2;
410  nodes[base + 7].n_outputs = 1;
411 
412  // Node 8: W2 Linear
413  nodes[base + 8].id.layer = (uint16_t)layer;
414  nodes[base + 8].id.node = 8;
415  nodes[base + 8].op = CK_OP_LINEAR;
416  nodes[base + 8].inputs[0].producer.layer = (uint16_t)layer;
417  nodes[base + 8].inputs[0].producer.node = 7;
418  nodes[base + 8].inputs[0].out_index = 0;
419  nodes[base + 8].n_inputs = 1;
420  nodes[base + 8].n_outputs = 1;
421 
422  // Node 9: Add residual: out = residual + mlp_out
423  nodes[base + 9].id.layer = (uint16_t)layer;
424  nodes[base + 9].id.node = 9;
425  nodes[base + 9].op = CK_OP_ADD;
426  nodes[base + 9].inputs[0].producer.layer = (uint16_t)layer;
427  nodes[base + 9].inputs[0].producer.node = 3; // first residual output
428  nodes[base + 9].inputs[0].out_index = 0;
429  nodes[base + 9].inputs[1].producer.layer = (uint16_t)layer;
430  nodes[base + 9].inputs[1].producer.node = 8; // mlp_out
431  nodes[base + 9].inputs[1].out_index = 0;
432  nodes[base + 9].n_inputs = 2;
433  nodes[base + 9].n_outputs = 1;
434  }
435 
436  graph->config = *cfg;
437  graph->num_nodes = total_nodes;
438  graph->nodes = nodes;
439  return 0;
440 }
441 
443 {
444  switch (op) {
445  case CK_OP_RMSNORM: return CK_OP_RMSNORM_BWD;
448  case CK_OP_ADD: return CK_OP_ADD_BWD;
449  case CK_OP_LINEAR: return CK_OP_LINEAR_BWD;
450  case CK_OP_SPLIT: return CK_OP_SPLIT_BWD;
451  case CK_OP_SWIGLU: return CK_OP_SWIGLU_BWD;
452  default: return op;
453  }
454 }
455 
456 int ck_build_decoder_backward_ir(const CKIRGraph *forward, CKIRGraph *backward)
457 {
458  if (!forward || !backward) {
459  return -1;
460  }
461  if (forward->num_nodes <= 0 || !forward->nodes) {
462  return -1;
463  }
464 
465  const int N = forward->num_nodes;
466  CKIRNode *nodes = (CKIRNode *)calloc((size_t)N, sizeof(CKIRNode));
467  if (!nodes) {
468  return -1;
469  }
470 
471  for (int i = 0; i < N; ++i) {
472  const CKIRNode *f = &forward->nodes[N - 1 - i]; // reverse order
473  CKIRNode *b = &nodes[i];
474 
475  b->id = f->id; // same layer/node id
476  b->op = map_forward_to_backward(f->op);
477  b->n_inputs = f->n_outputs; // placeholder
478  b->n_outputs= f->n_inputs; // placeholder
479 
480  // For now we simply copy the forward inputs to keep a reference
481  // to which activations this backward op relates to.
482  for (int j = 0; j < f->n_inputs; ++j) {
483  b->inputs[j] = f->inputs[j];
484  }
485  }
486 
487  backward->config = forward->config;
488  backward->num_nodes = N;
489  backward->nodes = nodes;
490  return 0;
491 }
492 
493 void ck_ir_free(CKIRGraph *graph)
494 {
495  if (!graph) {
496  return;
497  }
498  free(graph->nodes);
499  graph->nodes = NULL;
500  graph->num_nodes = 0;
501 }
502 
503 static const char *op_name(CKOpType op)
504 {
505  switch (op) {
506  case CK_OP_RMSNORM: return "RMSNORM";
507  case CK_OP_LINEAR_QKV: return "LINEAR_QKV";
508  case CK_OP_ATTENTION: return "ATTENTION";
509  case CK_OP_ADD: return "ADD";
510  case CK_OP_LINEAR: return "LINEAR";
511  case CK_OP_SPLIT: return "SPLIT";
512  case CK_OP_SWIGLU: return "SWIGLU";
513  case CK_OP_RMSNORM_BWD: return "RMSNORM_BWD";
514  case CK_OP_LINEAR_QKV_BWD: return "LINEAR_QKV_BWD";
515  case CK_OP_ATTENTION_BWD: return "ATTENTION_BWD";
516  case CK_OP_ADD_BWD: return "ADD_BWD";
517  case CK_OP_LINEAR_BWD: return "LINEAR_BWD";
518  case CK_OP_SPLIT_BWD: return "SPLIT_BWD";
519  case CK_OP_SWIGLU_BWD: return "SWIGLU_BWD";
520  default: return "UNKNOWN";
521  }
522 }
523 
524 void ck_ir_dump(const CKIRGraph *graph, FILE *out)
525 {
526  if (!graph || !out) {
527  return;
528  }
529 
530  fprintf(out,
531  "CKIRGraph: layers=%d, hidden_size=%d, intermediate_size=%d, heads=%d, kv_heads=%d, vocab=%d, ctx=%d, eps=%.6g, rope_theta=%.6g\n",
532  graph->config.num_layers,
533  graph->config.hidden_size,
534  graph->config.intermediate_size,
535  graph->config.num_heads,
536  graph->config.num_kv_heads,
537  graph->config.vocab_size,
538  graph->config.context_window,
539  graph->config.rms_norm_eps,
540  graph->config.rope_theta);
541 
542  for (int i = 0; i < graph->num_nodes; ++i) {
543  const CKIRNode *n = &graph->nodes[i];
544  fprintf(out, " L%u N%u %-14s outputs=[",
545  (unsigned)n->id.layer,
546  (unsigned)n->id.node,
547  op_name(n->op));
548  for (int o = 0; o < n->n_outputs; ++o) {
549  if (o > 0) {
550  fputc(',', out);
551  }
552  fprintf(out, "L%u:N%u:%d",
553  (unsigned)n->id.layer,
554  (unsigned)n->id.node,
555  o);
556  }
557  fprintf(out, "] inputs=[");
558  for (int j = 0; j < n->n_inputs; ++j) {
559  const CKInputRef *inp = &n->inputs[j];
560  if (j > 0) {
561  fputc(',', out);
562  }
563  if (inp->producer.node == 0xFFFFu) {
564  fprintf(out, "IN");
565  } else {
566  fprintf(out, "L%u:N%u",
567  (unsigned)inp->producer.layer,
568  (unsigned)inp->producer.node);
569  }
570  }
571  fprintf(out, "]\n");
572  }
573 }
574 
575 int ck_ir_serialize_json(const CKIRGraph *graph, const char *path)
576 {
577  if (!graph || !path) {
578  return -1;
579  }
580 
581  FILE *f = fopen(path, "wb");
582  if (!f) {
583  perror("ck_ir_serialize_json: fopen");
584  return -1;
585  }
586 
587  fprintf(f, "{\n");
588  fprintf(f, " \"config\": {\n");
589  fprintf(f, " \"num_layers\": %d,\n", graph->config.num_layers);
590  fprintf(f, " \"hidden_size\": %d,\n", graph->config.hidden_size);
591  fprintf(f, " \"intermediate_size\": %d,\n", graph->config.intermediate_size);
592  fprintf(f, " \"num_attention_heads\": %d,\n", graph->config.num_heads);
593  fprintf(f, " \"num_key_value_heads\": %d,\n", graph->config.num_kv_heads);
594  fprintf(f, " \"vocab_size\": %d,\n", graph->config.vocab_size);
595  fprintf(f, " \"context_window\": %d,\n", graph->config.context_window);
596  fprintf(f, " \"rms_norm_eps\": %.9g,\n", graph->config.rms_norm_eps);
597  fprintf(f, " \"rope_theta\": %.9g\n", graph->config.rope_theta);
598  fprintf(f, " },\n");
599 
600  // For now we only emit a flat "nodes" array. Higher-level tools can
601  // reorganize this into header/block/footer with per-layer arrays.
602  fprintf(f, " \"nodes\": [\n");
603  for (int i = 0; i < graph->num_nodes; ++i) {
604  const CKIRNode *n = &graph->nodes[i];
605  fprintf(f, " {\n");
606  fprintf(f, " \"layer\": %u,\n", (unsigned)n->id.layer);
607  fprintf(f, " \"node\": %u,\n", (unsigned)n->id.node);
608  fprintf(f, " \"op\": \"%s\",\n", op_name(n->op));
609 
610  // Outputs: derive labels L<layer>:N<node>:slot
611  fprintf(f, " \"outputs\": [");
612  for (int o = 0; o < n->n_outputs; ++o) {
613  if (o > 0) fprintf(f, ", ");
614  fprintf(f, "\"L%u:N%u:%d\"",
615  (unsigned)n->id.layer,
616  (unsigned)n->id.node,
617  o);
618  }
619  fprintf(f, "],\n");
620 
621  // Inputs: either "IN" or L<layer>:N<node>:slot
622  fprintf(f, " \"inputs\": [");
623  for (int j = 0; j < n->n_inputs; ++j) {
624  const CKInputRef *inp = &n->inputs[j];
625  if (j > 0) fprintf(f, ", ");
626  if (inp->producer.node == 0xFFFFu) {
627  fprintf(f, "\"IN\"");
628  } else {
629  fprintf(f, "\"L%u:N%u:%u\"",
630  (unsigned)inp->producer.layer,
631  (unsigned)inp->producer.node,
632  (unsigned)inp->out_index);
633  }
634  }
635  fprintf(f, "]\n");
636 
637  fprintf(f, " }%s\n", (i + 1 < graph->num_nodes) ? "," : "");
638  }
639  fprintf(f, " ]\n");
640  fprintf(f, "}\n");
641 
642  fclose(f);
643  return 0;
644 }
645 
646 static CKOpType parse_op(const char *s)
647 {
648  if (strcmp(s, "RMSNORM") == 0) return CK_OP_RMSNORM;
649  if (strcmp(s, "LINEAR_QKV") == 0) return CK_OP_LINEAR_QKV;
650  if (strcmp(s, "ATTENTION") == 0) return CK_OP_ATTENTION;
651  if (strcmp(s, "ADD") == 0) return CK_OP_ADD;
652  if (strcmp(s, "LINEAR") == 0) return CK_OP_LINEAR;
653  if (strcmp(s, "SPLIT") == 0) return CK_OP_SPLIT;
654  if (strcmp(s, "SWIGLU") == 0) return CK_OP_SWIGLU;
655  if (strcmp(s, "RMSNORM_BWD") == 0) return CK_OP_RMSNORM_BWD;
656  if (strcmp(s, "LINEAR_QKV_BWD") == 0) return CK_OP_LINEAR_QKV_BWD;
657  if (strcmp(s, "ATTENTION_BWD") == 0) return CK_OP_ATTENTION_BWD;
658  if (strcmp(s, "ADD_BWD") == 0) return CK_OP_ADD_BWD;
659  if (strcmp(s, "LINEAR_BWD") == 0) return CK_OP_LINEAR_BWD;
660  if (strcmp(s, "SPLIT_BWD") == 0) return CK_OP_SPLIT_BWD;
661  if (strcmp(s, "SWIGLU_BWD") == 0) return CK_OP_SWIGLU_BWD;
662  return CK_OP_RMSNORM; // default fallback
663 }
664 
665 int ck_ir_parse_json(const char *path, CKIRGraph *graph)
666 {
667  if (!path || !graph) {
668  return -1;
669  }
670 
671  FILE *f = fopen(path, "rb");
672  if (!f) {
673  perror("ck_ir_parse_json: fopen");
674  return -1;
675  }
676 
677  if (fseek(f, 0, SEEK_END) != 0) {
678  fclose(f);
679  return -1;
680  }
681  long len = ftell(f);
682  if (len < 0) {
683  fclose(f);
684  return -1;
685  }
686  if (fseek(f, 0, SEEK_SET) != 0) {
687  fclose(f);
688  return -1;
689  }
690 
691  char *buf = (char *)malloc((size_t)len + 1);
692  if (!buf) {
693  fclose(f);
694  return -1;
695  }
696  size_t nread = fread(buf, 1, (size_t)len, f);
697  fclose(f);
698  buf[nread] = '\0';
699 
700  CKIRGraph tmp;
701  memset(&tmp, 0, sizeof(tmp));
702 
703  // Parse config using the same helper as HF-style JSON.
704  if (parse_int_field(buf, "\"num_layers\"", &tmp.config.num_layers) != 0) {
705  fprintf(stderr, "ck_ir_parse_json: missing num_layers\n");
706  }
707  if (parse_int_field(buf, "\"hidden_size\"", &tmp.config.hidden_size) != 0) {
708  fprintf(stderr, "ck_ir_parse_json: missing hidden_size\n");
709  }
710  if (parse_int_field(buf, "\"intermediate_size\"", &tmp.config.intermediate_size) != 0) {
711  fprintf(stderr, "ck_ir_parse_json: missing intermediate_size\n");
712  }
713  if (parse_int_field(buf, "\"num_attention_heads\"", &tmp.config.num_heads) != 0) {
714  fprintf(stderr, "ck_ir_parse_json: missing num_attention_heads\n");
715  }
716  if (parse_int_field(buf, "\"num_key_value_heads\"", &tmp.config.num_kv_heads) != 0) {
718  }
719 
720  // Optional: vocab_size / context_window may be present
721  if (parse_int_field(buf, "\"vocab_size\"", &tmp.config.vocab_size) != 0) {
722  tmp.config.vocab_size = 0;
723  }
724  if (parse_int_field(buf, "\"context_window\"", &tmp.config.context_window) != 0) {
725  tmp.config.context_window = 0;
726  }
727  if (parse_float_field_in_range(buf, nread, "\"rms_norm_eps\"", &tmp.config.rms_norm_eps) != 0) {
728  tmp.config.rms_norm_eps = 1e-5f;
729  }
730  if (parse_float_field_in_range(buf, nread, "\"rope_theta\"", &tmp.config.rope_theta) != 0) {
731  tmp.config.rope_theta = 0.0f;
732  }
733 
734  // Count nodes by scanning for "layer" keys in the nodes array.
735  char *nodes_begin = strstr(buf, "\"nodes\"");
736  if (!nodes_begin) {
737  free(buf);
738  return -1;
739  }
740  char *p = nodes_begin;
741  int count = 0;
742  while ((p = strstr(p, "\"layer\"")) != NULL) {
743  count++;
744  p += 7;
745  }
746  if (count <= 0) {
747  free(buf);
748  return -1;
749  }
750 
751  CKIRNode *nodes = (CKIRNode *)calloc((size_t)count, sizeof(CKIRNode));
752  if (!nodes) {
753  free(buf);
754  return -1;
755  }
756 
757  // Parse each node sequentially.
758  p = nodes_begin;
759  for (int i = 0; i < count; ++i) {
760  // layer
761  char *pl = strstr(p, "\"layer\"");
762  if (!pl) { free(nodes); free(buf); return -1; }
763  int layer = 0;
764  if (sscanf(strchr(pl, ':'), " : %d", &layer) != 1) {
765  free(nodes); free(buf); return -1;
766  }
767 
768  // node
769  char *pn = strstr(pl, "\"node\"");
770  if (!pn) { free(nodes); free(buf); return -1; }
771  int node = 0;
772  if (sscanf(strchr(pn, ':'), " : %d", &node) != 1) {
773  free(nodes); free(buf); return -1;
774  }
775 
776  // op string
777  char *po = strstr(pn, "\"op\"");
778  if (!po) { free(nodes); free(buf); return -1; }
779  char op_str[64] = {0};
780  if (sscanf(strchr(po, ':'), " : \"%63[^\"]\"", op_str) != 1) {
781  free(nodes); free(buf); return -1;
782  }
783 
784  CKIRNode *n = &nodes[i];
785  n->id.layer = (uint16_t)layer;
786  n->id.node = (uint16_t)node;
787  n->op = parse_op(op_str);
788 
789  // outputs: count entries between [ ... ]
790  char *pout = strstr(po, "\"outputs\"");
791  if (!pout) { free(nodes); free(buf); return -1; }
792  char *bo = strchr(pout, '[');
793  char *eo = strchr(pout, ']');
794  int out_count = 0;
795  if (bo && eo && eo > bo) {
796  char *q = bo;
797  while ((q = strchr(q, '"')) && q < eo) {
798  out_count++;
799  q = strchr(q + 1, '"');
800  if (!q || q >= eo) break;
801  // Skip closing quote
802  q++;
803  }
804  }
805  n->n_outputs = (uint8_t)out_count;
806 
807  // inputs
808  char *pin = strstr(po, "\"inputs\"");
809  if (!pin) { free(nodes); free(buf); return -1; }
810  char *bi = strchr(pin, '[');
811  char *ei = strchr(pin, ']');
812  int in_count = 0;
813  if (bi && ei && ei > bi) {
814  // Simple scan for tokens "IN" or "Lx:Nx:s"
815  char *q = bi;
816  while ((q = strchr(q, '"')) && q < ei) {
817  char tok[64] = {0};
818  if (sscanf(q, "\"%63[^\"]\"", tok) != 1) {
819  break;
820  }
821  if (strcmp(tok, "IN") == 0) {
822  n->inputs[in_count].producer.layer = (uint16_t)layer;
823  n->inputs[in_count].producer.node = 0xFFFFu;
824  n->inputs[in_count].out_index = 0;
825  } else {
826  unsigned plh = 0, pnn = 0, slot = 0;
827  if (sscanf(tok, "L%u:N%u:%u", &plh, &pnn, &slot) == 3) {
828  n->inputs[in_count].producer.layer = (uint16_t)plh;
829  n->inputs[in_count].producer.node = (uint16_t)pnn;
830  n->inputs[in_count].out_index = (uint8_t)slot;
831  }
832  }
833  in_count++;
834  // Move q past this token
835  q = strchr(q + 1, '"');
836  if (!q || q >= ei) break;
837  q++;
838  }
839  }
840  n->n_inputs = (uint8_t)in_count;
841 
842  // Move p forward for the next iteration
843  p = pin + 7;
844  }
845 
846  free(buf);
847  graph->config = tmp.config;
848  graph->num_nodes = count;
849  graph->nodes = nodes;
850  return 0;
851 }
CKOpType
Definition: ckernel_ir.h:35
@ CK_OP_LINEAR_BWD
Definition: ckernel_ir.h:48
@ CK_OP_SWIGLU
Definition: ckernel_ir.h:42
@ CK_OP_RMSNORM_BWD
Definition: ckernel_ir.h:44
@ CK_OP_SWIGLU_BWD
Definition: ckernel_ir.h:50
@ CK_OP_ADD
Definition: ckernel_ir.h:39
@ CK_OP_SPLIT
Definition: ckernel_ir.h:41
@ CK_OP_LINEAR_QKV_BWD
Definition: ckernel_ir.h:45
@ CK_OP_ATTENTION_BWD
Definition: ckernel_ir.h:46
@ CK_OP_SPLIT_BWD
Definition: ckernel_ir.h:49
@ CK_OP_LINEAR_QKV
Definition: ckernel_ir.h:37
@ CK_OP_LINEAR
Definition: ckernel_ir.h:40
@ CK_OP_RMSNORM
Definition: ckernel_ir.h:36
@ CK_OP_ADD_BWD
Definition: ckernel_ir.h:47
@ CK_OP_ATTENTION
Definition: ckernel_ir.h:38
int ck_build_decoder_ir(const CKModelConfig *cfg, CKIRGraph *graph)
static const char * op_name(CKOpType op)
int ck_ir_serialize_json(const CKIRGraph *graph, const char *path)
int ck_build_decoder_backward_ir(const CKIRGraph *forward, CKIRGraph *backward)
static int parse_int_field_in_range(const char *json, size_t len, const char *key, int *out_value)
Definition: ckernel_ir_v6.c:35
static int parse_float_field_in_range(const char *json, size_t len, const char *key, float *out_value)
Definition: ckernel_ir_v6.c:87
static int parse_int_field(const char *json, const char *key, int *out_value)
Definition: ckernel_ir_v6.c:8
void ck_ir_dump(const CKIRGraph *graph, FILE *out)
static int parse_float_field_any(const char *json, size_t len, const char *const *keys, float *out_value)
static CKOpType map_forward_to_backward(CKOpType op)
static CKOpType parse_op(const char *s)
int ck_model_config_from_hf_json(const char *path, CKModelConfig *cfg)
int ck_ir_parse_json(const char *path, CKIRGraph *graph)
static int parse_int_field_any(const char *json, size_t len, const char *const *keys, int *out_value)
Definition: ckernel_ir_v6.c:71
static int find_object_range(const char *json, const char *key, const char **out_start, size_t *out_len)
void ck_ir_free(CKIRGraph *graph)
CKIRNode * nodes
Definition: ckernel_ir.h:75
int num_nodes
Definition: ckernel_ir.h:74
CKModelConfig config
Definition: ckernel_ir.h:73
CKOpType op
Definition: ckernel_ir.h:65
CKInputRef inputs[4]
Definition: ckernel_ir.h:66
CKKernelId id
Definition: ckernel_ir.h:64
uint8_t n_inputs
Definition: ckernel_ir.h:67
uint8_t n_outputs
Definition: ckernel_ir.h:68
CKKernelId producer
Definition: ckernel_ir.h:59
uint8_t out_index
Definition: ckernel_ir.h:60
uint16_t node
Definition: ckernel_ir.h:55
uint16_t layer
Definition: ckernel_ir.h:54
int context_window
Definition: ckernel_ir.h:30
int intermediate_size
Definition: ck_model_api.h:37
float rms_norm_eps
Definition: ckernel_ir.h:31
float rope_theta
Definition: ckernel_ir.h:32
const int32_t int int * out_len
Definition: tokenizer.h:445
uint32_t end
Definition: utf8.c:215
uint32_t start
Definition: utf8.c:214