← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_ir_v2_lower.c File Reference
#include "ckernel_ir_v2_lower.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

Go to the source code of this file.

Functions

static int ck_ir_v2_lower_copy_buffers (const CKIRV2Graph *input, CKIRV2Graph *output)
 
static int ck_ir_v2_lower_copy_nodes (const CKIRV2Graph *input, CKIRV2LowerMode mode, CKIRV2Graph *output)
 
int ck_ir_v2_lower_emit_json (const CKIRV2Graph *input, CKIRV2LowerMode mode, const char *path)
 
int ck_ir_v2_lower_graph (const CKIRV2Graph *input, CKIRV2LowerMode mode, CKIRV2Graph *output, CKMemPlan *plan)
 
int ck_ir_v2_lower_mode_from_string (const char *name, CKIRV2LowerMode *out_mode)
 
const char * ck_ir_v2_lower_mode_name (CKIRV2LowerMode mode)
 
static int ck_ir_v2_lower_node_enabled (const CKIRV2Node *node, CKIRV2LowerMode mode)
 
static char * ck_ir_v2_lower_strdup (const char *s)
 

Function Documentation

◆ ck_ir_v2_lower_copy_buffers()

static int ck_ir_v2_lower_copy_buffers ( const CKIRV2Graph *  input,
CKIRV2Graph *  output 
)
static

Definition at line 107 of file ckernel_ir_v2_lower.c.

108 {
109  output->num_buffers = input->num_buffers;
110  output->buffers = (CKIRV2Buffer *)calloc((size_t)output->num_buffers, sizeof(CKIRV2Buffer));
111  if (!output->buffers) {
112  return -1;
113  }
114  for (int i = 0; i < input->num_buffers; ++i) {
115  const CKIRV2Buffer *src = &input->buffers[i];
116  CKIRV2Buffer *dst = &output->buffers[i];
117  dst->name = ck_ir_v2_lower_strdup(src->name);
118  dst->scope = src->scope;
119  dst->role = src->role;
120  dst->dtype = src->dtype;
121  memcpy(dst->shape, src->shape, sizeof(dst->shape));
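122  dst->optional = src->optional;
123  dst->alias_of = ck_ir_v2_lower_strdup(src->alias_of);
124  dst->condition = ck_ir_v2_lower_strdup(src->condition);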
125  }
126  return 0;
127 }
static char * ck_ir_v2_lower_strdup(const char *s)
CKDimToken shape[4]
Definition: ckernel_ir_v2.h:29
CKBufferRole role
Definition: ckernel_ir_v2.h:27
char * condition
Definition: ckernel_ir_v2.h:32
CKBufferScope scope
Definition: ckernel_ir_v2.h:26
char * alias_of
Definition: ckernel_ir_v2.h:31
CKDataType dtype
Definition: ckernel_ir_v2.h:28
CKIRV2Buffer * buffers
Definition: ckernel_ir_v2.h:62

References CKIRV2Buffer::alias_of, CKIRV2Graph::buffers, ck_ir_v2_lower_strdup(), CKIRV2Buffer::condition, CKIRV2Buffer::dtype, CKIRV2Buffer::name, CKIRV2Graph::num_buffers, CKIRV2Buffer::optional, CKIRV2Buffer::role, CKIRV2Buffer::scope, and CKIRV2Buffer::shape.

Referenced by ck_ir_v2_lower_graph().

◆ ck_ir_v2_lower_copy_nodes()

static int ck_ir_v2_lower_copy_nodes ( const CKIRV2Graph *  input,
CKIRV2LowerMode  mode,
CKIRV2Graph *  output 
)
static

Definition at line 129 of file ckernel_ir_v2_lower.c.

132 {
133  int count = 0;
134  for (int i = 0; i < input->num_nodes; ++i) {
135  if (ck_ir_v2_lower_node_enabled(&input->nodes[i], mode)) {
136  count++;
137  }
138  }
139  output->num_nodes = count;
140  output->nodes = (CKIRV2Node *)calloc((size_t)count, sizeof(CKIRV2Node));
141  if (!output->nodes) {
142  return -1;
143  }
144  int idx = 0;
145  for (int i = 0; i < input->num_nodes; ++i) {
146  const CKIRV2Node *src = &input->nodes[i];
147  if (!ck_ir_v2_lower_node_enabled(src, mode)) {
148  continue;
149  }
150  CKIRV2Node *dst = &output->nodes[idx++];
151  dst->op = ck_ir_v2_lower_strdup(src->op);
152  dst->kernel = ck_ir_v2_lower_strdup(src->kernel);
153  dst->kernel_dtype = src->kernel_dtype;
154  dst->condition = ck_ir_v2_lower_strdup(src->condition);
155  dst->layer = src->layer;
156  dst->flags = src->flags;
157  if (mode != CK_IR_V2_LOWER_BACKWARD) {
158  dst->flags = (uint8_t)(dst->flags | CK_IR_V2_NODE_INFERENCE_ONLY);
159  }
160  dst->n_bindings = src->n_bindings;
161  for (int b = 0; b < src->n_bindings; ++b) {
162  dst->bindings[b].arg = ck_ir_v2_lower_strdup(src->bindings[b].arg);
163  dst->bindings[b].buffer = src->bindings[b].buffer;
164  }
165  dst->n_inputs = src->n_inputs;
166  dst->n_outputs = src->n_outputs;
167  for (int j = 0; j < dst->n_inputs; ++j) {
168  dst->inputs[j] = src->inputs[j];
169  }
170  for (int j = 0; j < dst->n_outputs; ++j) {
171  dst->outputs[j] = src->outputs[j];
172  }
173  }
174  return 0;
175 }
@ CK_IR_V2_NODE_INFERENCE_ONLY
Definition: ckernel_ir_v2.h:21
static int ck_ir_v2_lower_node_enabled(const CKIRV2Node *node, CKIRV2LowerMode mode)
@ CK_IR_V2_LOWER_BACKWARD
int32_t buffer
Definition: ckernel_ir_v2.h:37
CKIRV2Node * nodes
Definition: ckernel_ir_v2.h:64
uint8_t n_bindings
Definition: ckernel_ir_v2.h:48
int32_t inputs[8]
Definition: ckernel_ir_v2.h:49
char * condition
Definition: ckernel_ir_v2.h:44
uint16_t layer
Definition: ckernel_ir_v2.h:45
uint8_t flags
Definition: ckernel_ir_v2.h:46
CKIRV2Binding bindings[24]
Definition: ckernel_ir_v2.h:47
int32_t outputs[4]
Definition: ckernel_ir_v2.h:51
CKDataType kernel_dtype
Definition: ckernel_ir_v2.h:43
uint8_t n_outputs
Definition: ckernel_ir_v2.h:52
char * kernel
Definition: ckernel_ir_v2.h:42
uint8_t n_inputs
Definition: ckernel_ir_v2.h:50

References CKIRV2Binding::arg, CKIRV2Node::bindings, CKIRV2Binding::buffer, CK_IR_V2_LOWER_BACKWARD, ck_ir_v2_lower_node_enabled(), ck_ir_v2_lower_strdup(), CK_IR_V2_NODE_INFERENCE_ONLY, CKIRV2Node::condition, CKIRV2Node::flags, CKIRV2Node::inputs, CKIRV2Node::kernel, CKIRV2Node::kernel_dtype, CKIRV2Node::layer, CKIRV2Node::n_bindings, CKIRV2Node::n_inputs, CKIRV2Node::n_outputs, CKIRV2Graph::nodes, CKIRV2Graph::num_nodes, CKIRV2Node::op, and CKIRV2Node::outputs.

Referenced by ck_ir_v2_lower_graph().

◆ ck_ir_v2_lower_emit_json()

int ck_ir_v2_lower_emit_json ( const CKIRV2Graph *  input,
CKIRV2LowerMode  mode,
const char *  path 
)

Definition at line 219 of file ckernel_ir_v2_lower.c.

222 {
223  if (!input || !path) {
224  return -1;
225  }
226  CKIRV2Graph lowered = {0};
227  CKMemPlan plan = {0};
228  if (ck_ir_v2_lower_graph(input, mode, &lowered, &plan) != 0) {
229  return -1;
230  }
231  int tokens_override = (mode == CK_IR_V2_LOWER_DECODE) ? 1 : -1;
232  int base_context = (tokens_override >= 0) ? input->config.context_window : -1;
233  int rc = ck_ir_v2_serialize_json_with_plan(&lowered, &plan,
234  ck_ir_v2_lower_mode_name(mode),
235  tokens_override,
236  base_context,
237  path);
238  ck_ir_v2_free(&lowered);
239  ck_mem_plan_free(&plan);
240  return rc;
241 }
int ck_ir_v2_serialize_json_with_plan(const CKIRV2Graph *graph, const struct CKMemPlan *plan, const char *mode, int tokens_override, int base_context_window, const char *path)
void ck_ir_v2_free(CKIRV2Graph *graph)
Definition: ckernel_ir_v2.c:34
const char * ck_ir_v2_lower_mode_name(CKIRV2LowerMode mode)
int ck_ir_v2_lower_graph(const CKIRV2Graph *input, CKIRV2LowerMode mode, CKIRV2Graph *output, CKMemPlan *plan)
@ CK_IR_V2_LOWER_DECODE
void ck_mem_plan_free(CKMemPlan *plan)
CKModelConfig config
Definition: ckernel_ir_v2.h:56
int context_window
Definition: ckernel_ir.h:30

References ck_ir_v2_free(), CK_IR_V2_LOWER_DECODE, ck_ir_v2_lower_graph(), ck_ir_v2_lower_mode_name(), ck_ir_v2_serialize_json_with_plan(), ck_mem_plan_free(), CKIRV2Graph::config, and CKModelConfig::context_window.

Referenced by main().

◆ ck_ir_v2_lower_graph()

int ck_ir_v2_lower_graph ( const CKIRV2Graph *  input,
CKIRV2LowerMode  mode,
CKIRV2Graph *  output,
CKMemPlan *  plan 
)

Definition at line 177 of file ckernel_ir_v2_lower.c.

181 {
182  if (!input || !output || !plan) {
183  return -1;
184  }
185  memset(output, 0, sizeof(*output));
186  memset(plan, 0, sizeof(*plan));
187 
188  output->config = input->config;
189  output->has_pos_emb = input->has_pos_emb;
190  output->tie_word_embeddings = input->tie_word_embeddings;
191  output->fused_qkv = input->fused_qkv;
192  output->gated_mlp = input->gated_mlp;
193 
194  if (ck_ir_v2_lower_copy_buffers(input, output) != 0 ||
195  ck_ir_v2_lower_copy_nodes(input, mode, output) != 0) {
196  ck_ir_v2_free(output);
197  return -1;
198  }
199 
200  int tokens_override = (mode == CK_IR_V2_LOWER_DECODE) ? 1 : -1;
201  int rc = 0;
202  if (mode == CK_IR_V2_LOWER_BACKWARD) {
203  rc = ck_mem_plan_build_training_with_tokens(output, plan,
204  CK_MEM_PLAN_DEFAULT_ALIGN,
205  tokens_override);
206  } else {
207  rc = ck_mem_plan_build_inference_with_tokens(output, plan,
208  CK_MEM_PLAN_DEFAULT_ALIGN,
209  tokens_override);
210  }
211  if (rc != 0) {
212  ck_ir_v2_free(output);
213  ck_mem_plan_free(plan);
214  return -1;
215  }
216  return 0;
217 }
static int ck_ir_v2_lower_copy_nodes(const CKIRV2Graph *input, CKIRV2LowerMode mode, CKIRV2Graph *output)
static int ck_ir_v2_lower_copy_buffers(const CKIRV2Graph *input, CKIRV2Graph *output)
int ck_mem_plan_build_training_with_tokens(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int tokens_override)
#define CK_MEM_PLAN_DEFAULT_ALIGN
int ck_mem_plan_build_inference_with_tokens(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int tokens_override)
int tie_word_embeddings
Definition: ckernel_ir_v2.h:58

References ck_ir_v2_free(), CK_IR_V2_LOWER_BACKWARD, ck_ir_v2_lower_copy_buffers(), ck_ir_v2_lower_copy_nodes(), CK_IR_V2_LOWER_DECODE, ck_mem_plan_build_inference_with_tokens(), ck_mem_plan_build_training_with_tokens(), CK_MEM_PLAN_DEFAULT_ALIGN, ck_mem_plan_free(), CKIRV2Graph::config, CKIRV2Graph::fused_qkv, CKIRV2Graph::gated_mlp, CKIRV2Graph::has_pos_emb, and CKIRV2Graph::tie_word_embeddings.

Referenced by ck_ir_v2_lower_emit_json().

◆ ck_ir_v2_lower_mode_from_string()

int ck_ir_v2_lower_mode_from_string ( const char *  name,
CKIRV2LowerMode *  out_mode 
)

Definition at line 36 of file ckernel_ir_v2_lower.c.

37 {
38  if (!name || !out_mode) {
39  return -1;
40  }
41  if (strcmp(name, "prefill") == 0) {
42  *out_mode = CK_IR_V2_LOWER_PREFILL;
43  return 0;
44  }
45  if (strcmp(name, "decode") == 0) {
46  *out_mode = CK_IR_V2_LOWER_DECODE;
47  return 0;
48  }
49  if (strcmp(name, "backward") == 0) {
50  *out_mode = CK_IR_V2_LOWER_BACKWARD;
51  return 0;
52  }
53  return -1;
54 }
@ CK_IR_V2_LOWER_PREFILL

References CK_IR_V2_LOWER_BACKWARD, CK_IR_V2_LOWER_DECODE, and CK_IR_V2_LOWER_PREFILL.

Referenced by ck_ir_v2_lower_node_enabled(), and main().

◆ ck_ir_v2_lower_mode_name()

const char* ck_ir_v2_lower_mode_name ( CKIRV2LowerMode  mode)

Definition at line 22 of file ckernel_ir_v2_lower.c.

23 {
24  switch (mode) {
25  case CK_IR_V2_LOWER_PREFILL:
26  return "prefill";
27  case CK_IR_V2_LOWER_DECODE:
28  return "decode";
29  case CK_IR_V2_LOWER_BACKWARD:
30  return "backward";
31  default:
32  return "unknown";
33  }
34 }

References CK_IR_V2_LOWER_BACKWARD, CK_IR_V2_LOWER_DECODE, and CK_IR_V2_LOWER_PREFILL.

Referenced by ck_ir_v2_lower_emit_json(), and main().

◆ ck_ir_v2_lower_node_enabled()

static int ck_ir_v2_lower_node_enabled ( const CKIRV2Node *  node,
CKIRV2LowerMode  mode 
)
static

Definition at line 56 of file ckernel_ir_v2_lower.c.

57 {
58  if (!node) {
59  return 0;
60  }
61  if (node->condition) {
62  const char *cond = node->condition;
63  const char *cur = cond;
64  while (*cur == ' ' || *cur == '\t') {
65  cur++;
66  }
67  if (strncmp(cur, "mode", 4) == 0) {
68  cur += 4;
69  while (*cur == ' ' || *cur == '\t') {
70  cur++;
71  }
72  if (cur[0] == '=' && cur[1] == '=') {
73  cur += 2;
74  while (*cur == ' ' || *cur == '\t') {
75  cur++;
76  }
77  char mode_buf[16] = {0};
78  int idx = 0;
79  while (*cur && idx < (int)(sizeof(mode_buf) - 1) &&
80  ((*cur >= 'a' && *cur <= 'z') || *cur == '_' || *cur == '-')) {
81  mode_buf[idx++] = *cur++;
82  }
83  mode_buf[idx] = '\0';
84  CKIRV2LowerMode target;
85  if (ck_ir_v2_lower_mode_from_string(mode_buf, &target) == 0) {
86  return mode == target;
87  }
88  }
89  }
90  if (strcmp(node->condition, "training_enabled") == 0 ||
91  strcmp(node->condition, "backward_only") == 0) {
92  return mode == CK_IR_V2_LOWER_BACKWARD;
93  }
94  if (strcmp(node->condition, "inference_only") == 0) {
95  return mode != CK_IR_V2_LOWER_BACKWARD;
96  }
97  if (strcmp(node->condition, "prefill_only") == 0) {
98  return mode == CK_IR_V2_LOWER_PREFILL;
99  }
100  if (strcmp(node->condition, "decode_only") == 0) {
101  return mode == CK_IR_V2_LOWER_DECODE;
102  }
103  }
104  return 1;
105 }
int ck_ir_v2_lower_mode_from_string(const char *name, CKIRV2LowerMode *out_mode)
CKIRV2LowerMode

References CK_IR_V2_LOWER_BACKWARD, CK_IR_V2_LOWER_DECODE, ck_ir_v2_lower_mode_from_string(), CK_IR_V2_LOWER_PREFILL, and CKIRV2Node::condition.

Referenced by ck_ir_v2_lower_copy_nodes().

◆ ck_ir_v2_lower_strdup()

static char* ck_ir_v2_lower_strdup ( const char *  s)
static

Definition at line 7 of file ckernel_ir_v2_lower.c.

8 {
9  if (!s) {
10  return NULL;
11  }
12  size_t len = strlen(s);
13  char *out = (char *)malloc(len + 1);
14  if (!out) {
15  return NULL;
16  }
17  memcpy(out, s, len);
18  out[len] = '\0';
19  return out;
20 }

Referenced by ck_ir_v2_lower_copy_buffers(), and ck_ir_v2_lower_copy_nodes().