← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_ir_v2_lower.c
Go to the documentation of this file.
1 #include "ckernel_ir_v2_lower.h"
2 
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <string.h>
6 
/*
 * Duplicate a NUL-terminated string into freshly malloc'd storage.
 *
 * Returns NULL when `s` is NULL or when the allocation fails; otherwise the
 * caller owns the returned copy and releases it with free().
 */
static char *ck_ir_v2_lower_strdup(const char *s)
{
    char *copy = NULL;

    if (s != NULL) {
        /* Byte count includes the terminating NUL, so one memcpy suffices. */
        size_t nbytes = strlen(s) + 1;

        copy = (char *)malloc(nbytes);
        if (copy != NULL) {
            memcpy(copy, s, nbytes);
        }
    }
    return copy;
}
21 
23 {
 /*
  * Maps a lowering mode to a static string: "prefill", "decode" or
  * "backward", with "unknown" for any value not handled by the switch.
  * NOTE(review): the signature line and the three `case` labels are absent
  * from this listing (listing lines 22, 25, 27, 29). Presumably this is
  * ck_ir_v2_lower_mode_name(CKIRV2LowerMode mode) with one case per
  * CK_IR_V2_LOWER_* value -- confirm against the real source file.
  */
24  switch (mode) {
26  return "prefill";
28  return "decode";
30  return "backward";
31  default:
32  return "unknown";
33  }
34 }
35 
36 int ck_ir_v2_lower_mode_from_string(const char *name, CKIRV2LowerMode *out_mode)
37 {
38  if (!name || !out_mode) {
39  return -1;
40  }
41  if (strcmp(name, "prefill") == 0) {
42  *out_mode = CK_IR_V2_LOWER_PREFILL;
43  return 0;
44  }
45  if (strcmp(name, "decode") == 0) {
46  *out_mode = CK_IR_V2_LOWER_DECODE;
47  return 0;
48  }
49  if (strcmp(name, "backward") == 0) {
50  *out_mode = CK_IR_V2_LOWER_BACKWARD;
51  return 0;
52  }
53  return -1;
54 }
55 
57 {
 /*
  * Decides whether `node` is kept for lowering mode `mode`.
  * Returns 0 for a NULL node, 1 for a node without a condition string or
  * with a condition that matches none of the forms below, and otherwise the
  * result of comparing `mode` against what the condition names.
  * NOTE(review): the signature line (listing line 56) is absent from this
  * listing; presumably
  * static int ck_ir_v2_lower_node_enabled(const CKIRV2Node *node, CKIRV2LowerMode mode)
  * -- confirm against the real source file.
  */
58  if (!node) {
59  return 0;
60  }
61  if (node->condition) {
62  const char *cond = node->condition;
63  const char *cur = cond;
 /* Form 1: "mode == <name>", with optional spaces/tabs around each token. */
64  while (*cur == ' ' || *cur == '\t') {
65  cur++;
66  }
67  if (strncmp(cur, "mode", 4) == 0) {
68  cur += 4;
69  while (*cur == ' ' || *cur == '\t') {
70  cur++;
71  }
72  if (cur[0] == '=' && cur[1] == '=') {
73  cur += 2;
74  while (*cur == ' ' || *cur == '\t') {
75  cur++;
76  }
 /* Collect at most 15 characters from [a-z_-] as the mode name. */
77  char mode_buf[16] = {0};
78  int idx = 0;
79  while (*cur && idx < (int)(sizeof(mode_buf) - 1) &&
80  ((*cur >= 'a' && *cur <= 'z') || *cur == '_' || *cur == '-')) {
81  mode_buf[idx++] = *cur++;
82  }
83  mode_buf[idx] = '\0';
 /*
  * Anything following the collected name is not inspected. If the name
  * is not a known mode, control falls through to the keyword checks.
  */
84  CKIRV2LowerMode target;
85  if (ck_ir_v2_lower_mode_from_string(mode_buf, &target) == 0) {
86  return mode == target;
87  }
88  }
89  }
 /* Form 2: exact keyword conditions (whole-string match, no trimming). */
90  if (strcmp(node->condition, "training_enabled") == 0 ||
91  strcmp(node->condition, "backward_only") == 0) {
92  return mode == CK_IR_V2_LOWER_BACKWARD;
93  }
94  if (strcmp(node->condition, "inference_only") == 0) {
95  return mode != CK_IR_V2_LOWER_BACKWARD;
96  }
97  if (strcmp(node->condition, "prefill_only") == 0) {
98  return mode == CK_IR_V2_LOWER_PREFILL;
99  }
100  if (strcmp(node->condition, "decode_only") == 0) {
101  return mode == CK_IR_V2_LOWER_DECODE;
102  }
103  }
 /* No condition, or one this function does not recognise: keep the node. */
104  return 1;
105 }
106 
/*
 * Copies every buffer descriptor of `input` into a newly calloc'd array on
 * `output` and sets output->num_buffers to match.
 * Returns 0 on success, -1 when the array allocation yields NULL.
 * NOTE(review): the result of duplicating `name` is not checked, so a failed
 * string allocation leaves dst->name NULL while still returning 0.
 * NOTE(review): calloc with a zero element count may legitimately return
 * NULL, which this function would report as -1 for a graph with no
 * buffers -- confirm whether that case can occur.
 */
107 static int ck_ir_v2_lower_copy_buffers(const CKIRV2Graph *input, CKIRV2Graph *output)
108 {
109  output->num_buffers = input->num_buffers;
110  output->buffers = (CKIRV2Buffer *)calloc((size_t)output->num_buffers, sizeof(CKIRV2Buffer));
111  if (!output->buffers) {
112  return -1;
113  }
114  for (int i = 0; i < input->num_buffers; ++i) {
115  const CKIRV2Buffer *src = &input->buffers[i];
116  CKIRV2Buffer *dst = &output->buffers[i];
 /* The name is duplicated; the remaining visible fields are copied by value. */
117  dst->name = ck_ir_v2_lower_strdup(src->name);
118  dst->scope = src->scope;
119  dst->role = src->role;
120  dst->dtype = src->dtype;
121  memcpy(dst->shape, src->shape, sizeof(dst->shape));
122  dst->optional = src->optional;
 /*
  * NOTE(review): listing lines 123-124 are absent here; presumably they
  * copy the buffer's alias_of and condition strings -- confirm against
  * the real source file.
  */
125  }
126  return 0;
127 }
128 
/*
 * Copies into `output` only those nodes of `input` that
 * ck_ir_v2_lower_node_enabled() accepts for `mode`, preserving their order.
 * Works in two passes: count the enabled nodes, then allocate and fill.
 * Returns 0 on success, -1 when the node array allocation yields NULL.
 * NOTE(review): results of the string duplications (op, kernel, binding arg)
 * are not checked, so an allocation failure there still returns 0.
 * NOTE(review): calloc with a zero element count may legitimately return
 * NULL, which would be reported as -1 when no node is enabled -- confirm
 * whether that case can occur.
 */
129 static int ck_ir_v2_lower_copy_nodes(const CKIRV2Graph *input,
130  CKIRV2LowerMode mode,
131  CKIRV2Graph *output)
132 {
 /* Pass 1: count the nodes that survive filtering. */
133  int count = 0;
134  for (int i = 0; i < input->num_nodes; ++i) {
135  if (ck_ir_v2_lower_node_enabled(&input->nodes[i], mode)) {
136  count++;
137  }
138  }
139  output->num_nodes = count;
140  output->nodes = (CKIRV2Node *)calloc((size_t)count, sizeof(CKIRV2Node));
141  if (!output->nodes) {
142  return -1;
143  }
 /* Pass 2: copy each enabled node into the next free output slot. */
144  int idx = 0;
145  for (int i = 0; i < input->num_nodes; ++i) {
146  const CKIRV2Node *src = &input->nodes[i];
147  if (!ck_ir_v2_lower_node_enabled(src, mode)) {
148  continue;
149  }
150  CKIRV2Node *dst = &output->nodes[idx++];
151  dst->op = ck_ir_v2_lower_strdup(src->op);
152  dst->kernel = ck_ir_v2_lower_strdup(src->kernel);
153  dst->kernel_dtype = src->kernel_dtype;
 /*
  * NOTE(review): listing line 154 is absent here; presumably it copies
  * the node's condition string -- confirm against the real source file.
  */
155  dst->layer = src->layer;
156  dst->flags = src->flags;
 /* Every node lowered for a non-backward mode gets the inference-only flag. */
157  if (mode != CK_IR_V2_LOWER_BACKWARD) {
158  dst->flags = (uint8_t)(dst->flags | CK_IR_V2_NODE_INFERENCE_ONLY);
159  }
 /* Binding argument names are duplicated; buffer indices are copied as-is. */
160  dst->n_bindings = src->n_bindings;
161  for (int b = 0; b < src->n_bindings; ++b) {
162  dst->bindings[b].arg = ck_ir_v2_lower_strdup(src->bindings[b].arg);
163  dst->bindings[b].buffer = src->bindings[b].buffer;
164  }
165  dst->n_inputs = src->n_inputs;
166  dst->n_outputs = src->n_outputs;
167  for (int j = 0; j < dst->n_inputs; ++j) {
168  dst->inputs[j] = src->inputs[j];
169  }
170  for (int j = 0; j < dst->n_outputs; ++j) {
171  dst->outputs[j] = src->outputs[j];
172  }
173  }
174  return 0;
175 }
176 
178  CKIRV2LowerMode mode,
179  CKIRV2Graph *output,
180  CKMemPlan *plan)
181 {
 /*
  * Builds a mode-specific copy of `input` in `output` and a memory plan for
  * it in `plan`. Returns 0 on success and -1 on a NULL argument, a copy
  * failure, or a plan-building failure; on the failure paths after
  * validation, whatever was built so far is released before returning.
  * NOTE(review): the first signature line (listing line 177) is absent from
  * this listing; presumably
  * int ck_ir_v2_lower_graph(const CKIRV2Graph *input, ...) -- confirm
  * against the real source file.
  */
182  if (!input || !output || !plan) {
183  return -1;
184  }
 /* Start from zeroed outputs so the cleanup calls below see a clean state. */
185  memset(output, 0, sizeof(*output));
186  memset(plan, 0, sizeof(*plan));
187 
 /* Graph-level configuration and flags are carried over unchanged. */
188  output->config = input->config;
189  output->has_pos_emb = input->has_pos_emb;
190  output->tie_word_embeddings = input->tie_word_embeddings;
191  output->fused_qkv = input->fused_qkv;
192  output->gated_mlp = input->gated_mlp;
193 
194  if (ck_ir_v2_lower_copy_buffers(input, output) != 0 ||
195  ck_ir_v2_lower_copy_nodes(input, mode, output) != 0) {
196  ck_ir_v2_free(output);
197  return -1;
198  }
199 
 /* Decode mode passes a token override of 1; other modes pass -1. */
200  int tokens_override = (mode == CK_IR_V2_LOWER_DECODE) ? 1 : -1;
201  int rc = 0;
 /*
  * Backward mode uses the training plan builder, all others the inference
  * one. NOTE(review): listing lines 204 and 208 (the argument between
  * `plan` and `tokens_override`) are absent; presumably the alignment
  * argument -- confirm against the real source file.
  */
202  if (mode == CK_IR_V2_LOWER_BACKWARD) {
203  rc = ck_mem_plan_build_training_with_tokens(output, plan,
205  tokens_override);
206  } else {
207  rc = ck_mem_plan_build_inference_with_tokens(output, plan,
209  tokens_override);
210  }
211  if (rc != 0) {
212  ck_ir_v2_free(output);
213  ck_mem_plan_free(plan);
214  return -1;
215  }
216  return 0;
217 }
218 
220  CKIRV2LowerMode mode,
221  const char *path)
222 {
 /*
  * Lowers `input` for `mode` into a temporary graph and memory plan, passes
  * them to the JSON serializer together with `path`, then frees both
  * temporaries. Returns -1 on a NULL argument or a lowering failure;
  * otherwise returns the serializer's result.
  * NOTE(review): the first signature line (listing line 219) is absent from
  * this listing; presumably
  * int ck_ir_v2_lower_emit_json(const CKIRV2Graph *input, ...) -- confirm
  * against the real source file.
  */
223  if (!input || !path) {
224  return -1;
225  }
226  CKIRV2Graph lowered = {0};
227  CKMemPlan plan = {0};
228  if (ck_ir_v2_lower_graph(input, mode, &lowered, &plan) != 0) {
229  return -1;
230  }
 /*
  * Same override rule as in lowering: 1 for decode, -1 otherwise. The
  * input's context window is forwarded only when an override is active.
  */
231  int tokens_override = (mode == CK_IR_V2_LOWER_DECODE) ? 1 : -1;
232  int base_context = (tokens_override >= 0) ? input->config.context_window : -1;
 /*
  * NOTE(review): listing line 234 (the argument between `&plan` and
  * `tokens_override`) is absent; presumably the mode name string --
  * confirm against the real source file.
  */
233  int rc = ck_ir_v2_serialize_json_with_plan(&lowered, &plan,
235  tokens_override,
236  base_context,
237  path);
 /* The temporaries are released whether or not serialization succeeded. */
238  ck_ir_v2_free(&lowered);
239  ck_mem_plan_free(&plan);
240  return rc;
241 }
int ck_ir_v2_serialize_json_with_plan(const CKIRV2Graph *graph, const struct CKMemPlan *plan, const char *mode, int tokens_override, int base_context_window, const char *path)
void ck_ir_v2_free(CKIRV2Graph *graph)
Definition: ckernel_ir_v2.c:34
@ CK_IR_V2_NODE_INFERENCE_ONLY
Definition: ckernel_ir_v2.h:21
static int ck_ir_v2_lower_copy_nodes(const CKIRV2Graph *input, CKIRV2LowerMode mode, CKIRV2Graph *output)
static int ck_ir_v2_lower_node_enabled(const CKIRV2Node *node, CKIRV2LowerMode mode)
static char * ck_ir_v2_lower_strdup(const char *s)
const char * ck_ir_v2_lower_mode_name(CKIRV2LowerMode mode)
int ck_ir_v2_lower_emit_json(const CKIRV2Graph *input, CKIRV2LowerMode mode, const char *path)
int ck_ir_v2_lower_mode_from_string(const char *name, CKIRV2LowerMode *out_mode)
int ck_ir_v2_lower_graph(const CKIRV2Graph *input, CKIRV2LowerMode mode, CKIRV2Graph *output, CKMemPlan *plan)
static int ck_ir_v2_lower_copy_buffers(const CKIRV2Graph *input, CKIRV2Graph *output)
CKIRV2LowerMode
@ CK_IR_V2_LOWER_BACKWARD
@ CK_IR_V2_LOWER_DECODE
@ CK_IR_V2_LOWER_PREFILL
int ck_mem_plan_build_training_with_tokens(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int tokens_override)
void ck_mem_plan_free(CKMemPlan *plan)
#define CK_MEM_PLAN_DEFAULT_ALIGN
int ck_mem_plan_build_inference_with_tokens(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int tokens_override)
int32_t buffer
Definition: ckernel_ir_v2.h:37
CKDimToken shape[4]
Definition: ckernel_ir_v2.h:29
CKBufferRole role
Definition: ckernel_ir_v2.h:27
char * condition
Definition: ckernel_ir_v2.h:32
CKBufferScope scope
Definition: ckernel_ir_v2.h:26
char * alias_of
Definition: ckernel_ir_v2.h:31
CKDataType dtype
Definition: ckernel_ir_v2.h:28
CKModelConfig config
Definition: ckernel_ir_v2.h:56
int tie_word_embeddings
Definition: ckernel_ir_v2.h:58
CKIRV2Node * nodes
Definition: ckernel_ir_v2.h:64
CKIRV2Buffer * buffers
Definition: ckernel_ir_v2.h:62
uint8_t n_bindings
Definition: ckernel_ir_v2.h:48
int32_t inputs[8]
Definition: ckernel_ir_v2.h:49
char * condition
Definition: ckernel_ir_v2.h:44
uint16_t layer
Definition: ckernel_ir_v2.h:45
uint8_t flags
Definition: ckernel_ir_v2.h:46
CKIRV2Binding bindings[24]
Definition: ckernel_ir_v2.h:47
int32_t outputs[4]
Definition: ckernel_ir_v2.h:51
CKDataType kernel_dtype
Definition: ckernel_ir_v2.h:43
uint8_t n_outputs
Definition: ckernel_ir_v2.h:52
char * kernel
Definition: ckernel_ir_v2.h:42
uint8_t n_inputs
Definition: ckernel_ir_v2.h:50
int context_window
Definition: ckernel_ir.h:30