← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_ir_v2_builder.c File Reference
#include "ckernel_ir_v2.h"
#include "ckernel_kernel_specs.h"
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

Go to the source code of this file.

Functions

int ck_ir_v2_apply_meta (const char *path, CKIRV2Graph *graph)
 
static int ck_ir_v2_apply_weight_dtypes (const char *json, const char *end, CKIRV2Graph *graph)
 
int ck_ir_v2_build_decoder (const CKModelConfig *cfg, CKIRV2Graph *graph)
 
int ck_ir_v2_build_decoder_backward (const CKIRV2Graph *forward, CKIRV2Graph *backward)
 
static int ck_ir_v2_copy_buffer_spec (const CKBufferSpec *spec, CKIRV2Buffer *out)
 
static void ck_ir_v2_copy_shape (CKDimToken *dst, const CKDimToken *src)
 
static int ck_ir_v2_find_buffer_index (const CKIRV2Graph *graph, const char *name)
 
static const CKKernelSpec * ck_ir_v2_find_kernel_spec (const char *name)
 
static const char * ck_ir_v2_find_key (const char *json, const char *key, const char *end)
 
static int ck_ir_v2_parse_bool (const char *json, const char *key, const char *end, int *out)
 
static CKDataType ck_ir_v2_parse_dtype (const char *s)
 
static int ck_ir_v2_parse_string (const char *start, const char *end, char **out)
 
static const char * ck_ir_v2_select_kernel (const CKKernelSpec *spec, CKDataType dtype, int backward)
 
static const char * ck_ir_v2_skip_ws (const char *cur, const char *end)
 
static char * ck_ir_v2_strdup (const char *s)
 

Function Documentation

◆ ck_ir_v2_apply_meta()

/**
 * @brief Apply model metadata flags from a JSON file to an existing graph.
 *
 * Reads the whole file at @p path and looks for the boolean keys
 * "has_pos_emb", "tie_word_embeddings", "fused_qkv" and "gated_mlp", plus an
 * optional "weight_dtypes" object (see ck_ir_v2_apply_weight_dtypes()).
 * A key that is absent, or whose value is not true/false, leaves the
 * corresponding graph field unchanged.
 *
 * When "tie_word_embeddings" is present and the graph has an
 * "lm_head_weight" buffer, that buffer's alias is set to "token_emb" when
 * tied, or cleared when untied.
 *
 * @param path   Path to the metadata JSON file.
 * @param graph  Graph to update in place.
 * @return 0 on success; -1 on invalid arguments, I/O failure, or when the
 *         lm_head alias string could not be allocated.
 */
int ck_ir_v2_apply_meta(const char *path, CKIRV2Graph *graph)
{
    if (!path || !graph) {
        return -1;
    }

    FILE *f = fopen(path, "rb");
    if (!f) {
        perror("ck_ir_v2_apply_meta: fopen");
        return -1;
    }

    /* Determine the file size so the whole document can be read at once. */
    long len = -1;
    if (fseek(f, 0, SEEK_END) != 0 || (len = ftell(f)) < 0 || fseek(f, 0, SEEK_SET) != 0) {
        fclose(f);
        return -1;
    }

    char *buf = (char *)malloc((size_t)len + 1);
    if (!buf) {
        fclose(f);
        return -1;
    }
    size_t nread = fread(buf, 1, (size_t)len, f);
    /* A short read is tolerated, but a stream error means the data is unreliable. */
    int read_failed = ferror(f);
    fclose(f);
    if (read_failed) {
        free(buf);
        return -1;
    }
    buf[nread] = '\0';
    const char *end = buf + nread;

    int rc = 0;
    int val = 0;
    if (ck_ir_v2_parse_bool(buf, "\"has_pos_emb\"", end, &val) == 0) {
        graph->has_pos_emb = val ? 1 : 0;
    }
    if (ck_ir_v2_parse_bool(buf, "\"tie_word_embeddings\"", end, &val) == 0) {
        graph->tie_word_embeddings = val ? 1 : 0;
        int idx = ck_ir_v2_find_buffer_index(graph, "lm_head_weight");
        if (idx >= 0) {
            CKIRV2Buffer *lm_head = &graph->buffers[idx];
            free(lm_head->alias_of);
            lm_head->alias_of = NULL;
            if (graph->tie_word_embeddings) {
                lm_head->alias_of = ck_ir_v2_strdup("token_emb");
                if (!lm_head->alias_of) {
                    /* OOM: report it instead of silently leaving lm_head untied. */
                    rc = -1;
                }
            }
        }
    }
    if (ck_ir_v2_parse_bool(buf, "\"fused_qkv\"", end, &val) == 0) {
        graph->fused_qkv = val ? 1 : 0;
    }
    if (ck_ir_v2_parse_bool(buf, "\"gated_mlp\"", end, &val) == 0) {
        graph->gated_mlp = val ? 1 : 0;
    }
    /* weight_dtypes is optional; a malformed object is deliberately ignored. */
    (void)ck_ir_v2_apply_weight_dtypes(buf, end, graph);

    free(buf);
    return rc;
}
static char * ck_ir_v2_strdup(const char *s)
static int ck_ir_v2_apply_weight_dtypes(const char *json, const char *end, CKIRV2Graph *graph)
static int ck_ir_v2_parse_bool(const char *json, const char *key, const char *end, int *out)
static int ck_ir_v2_find_buffer_index(const CKIRV2Graph *graph, const char *name)
char * alias_of
Definition: ckernel_ir_v2.h:31
int tie_word_embeddings
Definition: ckernel_ir_v2.h:58
CKIRV2Buffer * buffers
Definition: ckernel_ir_v2.h:62
uint32_t end
Definition: utf8.c:215

References CKIRV2Buffer::alias_of, CKIRV2Graph::buffers, ck_ir_v2_apply_weight_dtypes(), ck_ir_v2_find_buffer_index(), ck_ir_v2_parse_bool(), ck_ir_v2_strdup(), end, CKIRV2Graph::fused_qkv, CKIRV2Graph::gated_mlp, CKIRV2Graph::has_pos_emb, and CKIRV2Graph::tie_word_embeddings.

Referenced by main().

◆ ck_ir_v2_apply_weight_dtypes()

/* Return a pointer just past the closing quote of the JSON string literal
 * that begins at `cur` (which must point at '"'), or NULL if no terminating
 * quote occurs before `end`. Uses the same escape rule as
 * ck_ir_v2_parse_string(): a backslash skips the following character. */
static const char *ck_ir_v2_skip_json_string(const char *cur, const char *end)
{
    if (!cur || cur >= end || *cur != '"') {
        return NULL;
    }
    const char *p = cur + 1;
    while (p < end) {
        if (*p == '"') {
            return p + 1;
        }
        p += (*p == '\\' && p + 1 < end) ? 2 : 1;
    }
    return NULL;
}

/**
 * @brief Apply per-buffer dtype overrides from a "weight_dtypes" JSON object.
 *
 * Expects `"weight_dtypes": { "<buffer name>": "<dtype>", ... }`. For every
 * entry whose name matches a graph buffer, that buffer's dtype is set from
 * ck_ir_v2_parse_dtype(); names that match no buffer are skipped. Parsing
 * is bounded by `end` and stays inside the object, so keys that follow the
 * object in the document are never misread as dtype entries. Entries parsed
 * before a syntax error remain applied.
 *
 * @param json   Start of the JSON text.
 * @param end    One past the last byte of the JSON text.
 * @param graph  Graph whose buffer dtypes are updated.
 * @return 0 if the key is absent or the object parsed cleanly; -1 on
 *         malformed input.
 */
static int ck_ir_v2_apply_weight_dtypes(const char *json, const char *end, CKIRV2Graph *graph)
{
    static const char key_literal[] = "\"weight_dtypes\"";
    const char *cur = ck_ir_v2_find_key(json, key_literal, end);
    if (!cur) {
        return 0;
    }
    cur = ck_ir_v2_skip_ws(cur + (sizeof(key_literal) - 1), end);
    if (cur >= end || *cur != ':') {
        return -1;
    }
    cur = ck_ir_v2_skip_ws(cur + 1, end);
    if (cur >= end || *cur != '{') {
        return -1;
    }
    cur++;

    for (;;) {
        cur = ck_ir_v2_skip_ws(cur, end);
        if (cur >= end) {
            return -1; /* unterminated object */
        }
        if (*cur == '}') {
            return 0;
        }

        char *key = NULL;
        if (ck_ir_v2_parse_string(cur, end, &key) != 0) {
            return -1;
        }
        /* Step over the whole key first, so a ':' inside the key text is
         * not mistaken for the key/value separator. */
        cur = ck_ir_v2_skip_json_string(cur, end);
        cur = cur ? ck_ir_v2_skip_ws(cur, end) : end;
        if (cur >= end || *cur != ':') {
            free(key);
            return -1;
        }
        cur = ck_ir_v2_skip_ws(cur + 1, end);

        char *val = NULL;
        if (ck_ir_v2_parse_string(cur, end, &val) != 0) {
            free(key);
            return -1;
        }
        cur = ck_ir_v2_skip_json_string(cur, end);

        int idx = ck_ir_v2_find_buffer_index(graph, key);
        if (idx >= 0) {
            graph->buffers[idx].dtype = ck_ir_v2_parse_dtype(val);
        }
        free(key);
        free(val);

        /* Only ',' (next entry) or '}' (end of object) may follow a value. */
        cur = cur ? ck_ir_v2_skip_ws(cur, end) : end;
        if (cur < end && *cur == ',') {
            cur++;
            continue;
        }
        return (cur < end && *cur == '}') ? 0 : -1;
    }
}
static const char * ck_ir_v2_skip_ws(const char *cur, const char *end)
static int ck_ir_v2_parse_string(const char *start, const char *end, char **out)
static const char * ck_ir_v2_find_key(const char *json, const char *key, const char *end)
static CKDataType ck_ir_v2_parse_dtype(const char *s)
CKDataType dtype
Definition: ckernel_ir_v2.h:28
uint32_t start
Definition: utf8.c:214

References CKIRV2Graph::buffers, ck_ir_v2_find_buffer_index(), ck_ir_v2_find_key(), ck_ir_v2_parse_dtype(), ck_ir_v2_parse_string(), ck_ir_v2_skip_ws(), CKIRV2Buffer::dtype, end, and start.

Referenced by ck_ir_v2_apply_meta().

◆ ck_ir_v2_build_decoder()

/**
 * @brief Build the forward decoder IR graph from a model configuration.
 *
 * Deep-copies the static decoder buffer table (ck_decoder_buffers) into
 * @p graph, tags the "pos_emb" buffer with the condition "has_pos_emb", and
 * expands the forward plan (ck_decoder_forward_plan_v2) once per layer,
 * giving cfg->num_layers * plan-length nodes.
 *
 * @param cfg    Model configuration; copied into graph->config.
 * @param graph  Output graph; zeroed first and owned by the caller afterwards.
 * @return 0 on success; -1 on NULL arguments, a negative or overflowing node
 *         count, or allocation failure. After an allocation failure the
 *         graph is passed to ck_ir_v2_free().
 */
int ck_ir_v2_build_decoder(const CKModelConfig *cfg, CKIRV2Graph *graph)
{
    if (!cfg || !graph) {
        return -1;
    }
    memset(graph, 0, sizeof(*graph));
    graph->config = *cfg;
    graph->has_pos_emb = 1;
    /* -1 presumably means "unspecified"; ck_ir_v2_apply_meta() may set these. */
    graph->tie_word_embeddings = -1;
    graph->fused_qkv = -1;
    graph->gated_mlp = -1;

    const int plan_count = (int)ck_decoder_forward_plan_v2_count;
    /* num_layers * plan_count must be non-negative and fit in an int. */
    if (cfg->num_layers < 0 ||
        (plan_count > 0 && cfg->num_layers > INT_MAX / plan_count)) {
        return -1;
    }

    graph->num_buffers = (int)ck_decoder_buffer_count;
    if (graph->num_buffers > 0) {
        graph->buffers = (CKIRV2Buffer *)calloc((size_t)graph->num_buffers, sizeof(CKIRV2Buffer));
        if (!graph->buffers) {
            graph->num_buffers = 0; /* keep count and pointer consistent */
            return -1;
        }
    }
    for (int i = 0; i < graph->num_buffers; ++i) {
        CKIRV2Buffer *buf = &graph->buffers[i];
        if (ck_ir_v2_copy_buffer_spec(&ck_decoder_buffers[i], buf) != 0) {
            goto fail;
        }
        if (buf->name && strcmp(buf->name, "pos_emb") == 0) {
            free(buf->condition);
            buf->condition = ck_ir_v2_strdup("has_pos_emb");
            if (!buf->condition) {
                goto fail;
            }
        }
    }

    graph->num_nodes = cfg->num_layers * plan_count;
    if (graph->num_nodes > 0) {
        graph->nodes = (CKIRV2Node *)calloc((size_t)graph->num_nodes, sizeof(CKIRV2Node));
        if (!graph->nodes) {
            graph->num_nodes = 0;
            goto fail;
        }
    }

    int idx = 0;
    for (int layer = 0; layer < cfg->num_layers; ++layer) {
        for (int p = 0; p < plan_count; ++p) {
            const CKPlanStepV2 *step = &ck_decoder_forward_plan_v2[p];
            const CKKernelSpec *spec = ck_ir_v2_find_kernel_spec(step->kernel);
            CKDataType dtype = spec ? spec->default_dtype : CK_DT_FP32;
            const char *impl = ck_ir_v2_select_kernel(spec, dtype, 0);
            const char *kernel_name = impl ? impl : step->kernel;
            CKIRV2Node *node = &graph->nodes[idx++];
            node->layer = (uint16_t)layer;
            node->op = ck_ir_v2_strdup(step->kernel);
            node->kernel = ck_ir_v2_strdup(kernel_name);
            node->kernel_dtype = dtype;
            node->condition = ck_ir_v2_strdup(step->condition);
            node->flags = 0;
            node->n_bindings = 0;
            node->n_inputs = 0;
            node->n_outputs = 0;
            /* ck_ir_v2_strdup() returns NULL for NULL input as well as on OOM;
             * only a NULL result for a non-NULL source is an error. */
            if ((step->kernel && !node->op) || (kernel_name && !node->kernel) ||
                (step->condition && !node->condition)) {
                goto fail;
            }
            if (step->bindings && step->num_bindings > 0) {
                int limit = (int)step->num_bindings;
                if (limit > CK_IR_V2_MAX_BINDINGS) {
                    limit = CK_IR_V2_MAX_BINDINGS; /* node->bindings is fixed-size */
                }
                for (int b = 0; b < limit; ++b) {
                    const CKPlanBinding *binding = &step->bindings[b];
                    CKIRV2Binding *dst = &node->bindings[node->n_bindings];
                    dst->arg = ck_ir_v2_strdup(binding->arg);
                    /* Unknown buffer names resolve to -1. */
                    dst->buffer = ck_ir_v2_find_buffer_index(graph, binding->buffer);
                    node->n_bindings++;
                    if (binding->arg && !dst->arg) {
                        goto fail;
                    }
                }
            }
        }
    }
    return 0;

fail:
    ck_ir_v2_free(graph);
    return -1;
}
CKDataType
Supported data types in C-Kernel-Engine.
Definition: ckernel_dtype.h:27
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
#define CK_IR_V2_MAX_BINDINGS
Definition: ckernel_ir_v2.h:16
void ck_ir_v2_free(CKIRV2Graph *graph)
Definition: ckernel_ir_v2.c:34
static int ck_ir_v2_copy_buffer_spec(const CKBufferSpec *spec, CKIRV2Buffer *out)
static const char * ck_ir_v2_select_kernel(const CKKernelSpec *spec, CKDataType dtype, int backward)
static const CKKernelSpec * ck_ir_v2_find_kernel_spec(const char *name)
const CKPlanStepV2 ck_decoder_forward_plan_v2[]
const size_t ck_decoder_forward_plan_v2_count
const CKBufferSpec ck_decoder_buffers[]
const size_t ck_decoder_buffer_count
int32_t buffer
Definition: ckernel_ir_v2.h:37
char * condition
Definition: ckernel_ir_v2.h:32
CKModelConfig config
Definition: ckernel_ir_v2.h:56
CKIRV2Node * nodes
Definition: ckernel_ir_v2.h:64
uint8_t n_bindings
Definition: ckernel_ir_v2.h:48
char * condition
Definition: ckernel_ir_v2.h:44
uint16_t layer
Definition: ckernel_ir_v2.h:45
uint8_t flags
Definition: ckernel_ir_v2.h:46
CKIRV2Binding bindings[24]
Definition: ckernel_ir_v2.h:47
CKDataType kernel_dtype
Definition: ckernel_ir_v2.h:43
uint8_t n_outputs
Definition: ckernel_ir_v2.h:52
char * kernel
Definition: ckernel_ir_v2.h:42
uint8_t n_inputs
Definition: ckernel_ir_v2.h:50
CKDataType default_dtype
const char * buffer
const char * kernel
const char * condition
const CKPlanBinding * bindings

References CKIRV2Binding::arg, CKPlanBinding::arg, CKIRV2Node::bindings, CKPlanStepV2::bindings, CKIRV2Binding::buffer, CKPlanBinding::buffer, CKIRV2Graph::buffers, ck_decoder_buffer_count, ck_decoder_buffers, ck_decoder_forward_plan_v2, ck_decoder_forward_plan_v2_count, CK_DT_FP32, ck_ir_v2_copy_buffer_spec(), ck_ir_v2_find_buffer_index(), ck_ir_v2_find_kernel_spec(), ck_ir_v2_free(), CK_IR_V2_MAX_BINDINGS, ck_ir_v2_select_kernel(), ck_ir_v2_strdup(), CKIRV2Buffer::condition, CKIRV2Node::condition, CKPlanStepV2::condition, CKIRV2Graph::config, CKKernelSpec::default_dtype, CKIRV2Node::flags, CKIRV2Graph::fused_qkv, CKIRV2Graph::gated_mlp, CKIRV2Graph::has_pos_emb, CKIRV2Node::kernel, CKPlanStepV2::kernel, CKIRV2Node::kernel_dtype, CKIRV2Node::layer, CKIRV2Node::n_bindings, CKIRV2Node::n_inputs, CKIRV2Node::n_outputs, CKIRV2Buffer::name, CKIRV2Graph::nodes, CKPlanStepV2::num_bindings, CKIRV2Graph::num_buffers, CKModelConfig::num_layers, CKIRV2Graph::num_nodes, CKIRV2Node::op, and CKIRV2Graph::tie_word_embeddings.

Referenced by main().

◆ ck_ir_v2_build_decoder_backward()

/**
 * @brief Build the backward decoder IR graph from an existing forward graph.
 *
 * Copies the config and graph-level flags from @p forward, deep-copies
 * every forward buffer, and expands the backward plan
 * (ck_decoder_backward_plan_v2) once per layer. Kernels are chosen with
 * ck_ir_v2_select_kernel(..., backward = 1).
 *
 * @param forward   Source forward graph (not modified).
 * @param backward  Output graph; zeroed first and owned by the caller afterwards.
 * @return 0 on success; -1 on NULL arguments, an inconsistent forward graph,
 *         a negative or overflowing node count, or allocation failure. After
 *         an allocation failure the graph is passed to ck_ir_v2_free().
 */
int ck_ir_v2_build_decoder_backward(const CKIRV2Graph *forward, CKIRV2Graph *backward)
{
    if (!forward || !backward) {
        return -1;
    }
    memset(backward, 0, sizeof(*backward));
    backward->config = forward->config;
    /* Keep the graph-level flags in sync with the forward graph. The buffer
     * conditions and aliases copied below (e.g. "has_pos_emb") presumably
     * refer to these flags. */
    backward->has_pos_emb = forward->has_pos_emb;
    backward->tie_word_embeddings = forward->tie_word_embeddings;
    backward->fused_qkv = forward->fused_qkv;
    backward->gated_mlp = forward->gated_mlp;

    const int num_layers = forward->config.num_layers;
    const int plan_count = (int)ck_decoder_backward_plan_v2_count;
    if (forward->num_buffers < 0 || (forward->num_buffers > 0 && !forward->buffers) ||
        num_layers < 0 || (plan_count > 0 && num_layers > INT_MAX / plan_count)) {
        return -1;
    }

    backward->num_buffers = forward->num_buffers;
    if (backward->num_buffers > 0) {
        backward->buffers = (CKIRV2Buffer *)calloc((size_t)backward->num_buffers, sizeof(CKIRV2Buffer));
        if (!backward->buffers) {
            backward->num_buffers = 0; /* keep count and pointer consistent */
            return -1;
        }
    }
    for (int i = 0; i < backward->num_buffers; ++i) {
        const CKIRV2Buffer *src = &forward->buffers[i];
        /* Re-express the forward buffer as a spec so the shared deep-copy routine applies. */
        CKBufferSpec spec = {0};
        spec.name = src->name;
        spec.scope = src->scope;
        spec.role = src->role;
        spec.dtype = src->dtype;
        spec.optional = src->optional;
        spec.alias_of = src->alias_of;
        spec.condition = src->condition;
        memcpy(spec.shape, src->shape, sizeof(spec.shape));
        if (ck_ir_v2_copy_buffer_spec(&spec, &backward->buffers[i]) != 0) {
            goto fail;
        }
    }

    backward->num_nodes = num_layers * plan_count;
    if (backward->num_nodes > 0) {
        backward->nodes = (CKIRV2Node *)calloc((size_t)backward->num_nodes, sizeof(CKIRV2Node));
        if (!backward->nodes) {
            backward->num_nodes = 0;
            goto fail;
        }
    }

    int idx = 0;
    for (int layer = 0; layer < num_layers; ++layer) {
        for (int p = 0; p < plan_count; ++p) {
            const CKPlanStepV2 *step = &ck_decoder_backward_plan_v2[p];
            const CKKernelSpec *spec = ck_ir_v2_find_kernel_spec(step->kernel);
            CKDataType dtype = spec ? spec->default_dtype : CK_DT_FP32;
            const char *impl = ck_ir_v2_select_kernel(spec, dtype, 1);
            const char *kernel_name = impl ? impl : step->kernel;
            CKIRV2Node *node = &backward->nodes[idx++];
            node->layer = (uint16_t)layer;
            node->op = ck_ir_v2_strdup(step->kernel);
            node->kernel = ck_ir_v2_strdup(kernel_name);
            node->kernel_dtype = dtype;
            node->condition = ck_ir_v2_strdup(step->condition);
            node->flags = 0;
            node->n_bindings = 0;
            node->n_inputs = 0;
            node->n_outputs = 0;
            /* ck_ir_v2_strdup() returns NULL for NULL input as well as on OOM;
             * only a NULL result for a non-NULL source is an error. */
            if ((step->kernel && !node->op) || (kernel_name && !node->kernel) ||
                (step->condition && !node->condition)) {
                goto fail;
            }
            if (step->bindings && step->num_bindings > 0) {
                int limit = (int)step->num_bindings;
                if (limit > CK_IR_V2_MAX_BINDINGS) {
                    limit = CK_IR_V2_MAX_BINDINGS; /* node->bindings is fixed-size */
                }
                for (int b = 0; b < limit; ++b) {
                    const CKPlanBinding *binding = &step->bindings[b];
                    CKIRV2Binding *dst = &node->bindings[node->n_bindings];
                    dst->arg = ck_ir_v2_strdup(binding->arg);
                    /* Unknown buffer names resolve to -1. */
                    dst->buffer = ck_ir_v2_find_buffer_index(backward, binding->buffer);
                    node->n_bindings++;
                    if (binding->arg && !dst->arg) {
                        goto fail;
                    }
                }
            }
        }
    }
    return 0;

fail:
    ck_ir_v2_free(backward);
    return -1;
}
const size_t ck_decoder_backward_plan_v2_count
const CKPlanStepV2 ck_decoder_backward_plan_v2[]
CKBufferRole role
const char * alias_of
const char * condition
CKBufferScope scope
const char * name
CKDimToken shape[4]
CKDimToken shape[4]
Definition: ckernel_ir_v2.h:29
CKBufferRole role
Definition: ckernel_ir_v2.h:27
CKBufferScope scope
Definition: ckernel_ir_v2.h:26

References CKIRV2Buffer::alias_of, CKBufferSpec::alias_of, CKIRV2Binding::arg, CKPlanBinding::arg, CKIRV2Node::bindings, CKPlanStepV2::bindings, CKIRV2Binding::buffer, CKPlanBinding::buffer, CKIRV2Graph::buffers, ck_decoder_backward_plan_v2, ck_decoder_backward_plan_v2_count, CK_DT_FP32, ck_ir_v2_copy_buffer_spec(), ck_ir_v2_find_buffer_index(), ck_ir_v2_find_kernel_spec(), ck_ir_v2_free(), CK_IR_V2_MAX_BINDINGS, ck_ir_v2_select_kernel(), ck_ir_v2_strdup(), CKIRV2Buffer::condition, CKIRV2Node::condition, CKBufferSpec::condition, CKPlanStepV2::condition, CKIRV2Graph::config, CKKernelSpec::default_dtype, CKIRV2Buffer::dtype, CKBufferSpec::dtype, CKIRV2Node::flags, CKIRV2Node::kernel, CKPlanStepV2::kernel, CKIRV2Node::kernel_dtype, CKIRV2Node::layer, CKIRV2Node::n_bindings, CKIRV2Node::n_inputs, CKIRV2Node::n_outputs, CKIRV2Buffer::name, CKBufferSpec::name, CKIRV2Graph::nodes, CKPlanStepV2::num_bindings, CKIRV2Graph::num_buffers, CKModelConfig::num_layers, CKIRV2Graph::num_nodes, CKIRV2Node::op, CKIRV2Buffer::optional, CKBufferSpec::optional, CKIRV2Buffer::role, CKBufferSpec::role, CKIRV2Buffer::scope, CKBufferSpec::scope, CKIRV2Buffer::shape, and CKBufferSpec::shape.

Referenced by main().

◆ ck_ir_v2_copy_buffer_spec()

/**
 * @brief Deep-copy a static buffer spec into an IR buffer.
 *
 * @p out is zeroed first. The name, alias_of and condition strings are
 * duplicated with ck_ir_v2_strdup(); all other fields and the shape are
 * copied by value.
 *
 * @param spec  Source spec.
 * @param out   Destination buffer; overwritten.
 * @return 0 on success; -1 if either pointer is NULL or a string could not be
 *         duplicated. On failure @p out holds no owned strings (all three are
 *         NULL), so it is safe to free it later.
 */
static int ck_ir_v2_copy_buffer_spec(const CKBufferSpec *spec, CKIRV2Buffer *out)
{
    if (!spec || !out) {
        return -1;
    }
    memset(out, 0, sizeof(*out));
    out->scope = spec->scope;
    out->role = spec->role;
    out->dtype = spec->dtype;
    out->optional = spec->optional;
    ck_ir_v2_copy_shape(out->shape, spec->shape);

    out->name = ck_ir_v2_strdup(spec->name);
    out->alias_of = ck_ir_v2_strdup(spec->alias_of);
    out->condition = ck_ir_v2_strdup(spec->condition);
    /* A NULL result is only an error when the source string was non-NULL. */
    if ((spec->name && !out->name) ||
        (spec->alias_of && !out->alias_of) ||
        (spec->condition && !out->condition)) {
        free(out->name);
        free(out->alias_of);
        free(out->condition);
        out->name = NULL;
        out->alias_of = NULL;
        out->condition = NULL;
        return -1;
    }
    return 0;
}
static void ck_ir_v2_copy_shape(CKDimToken *dst, const CKDimToken *src)

References CKIRV2Buffer::alias_of, CKBufferSpec::alias_of, ck_ir_v2_copy_shape(), ck_ir_v2_strdup(), CKIRV2Buffer::condition, CKBufferSpec::condition, CKIRV2Buffer::dtype, CKBufferSpec::dtype, CKIRV2Buffer::name, CKBufferSpec::name, CKIRV2Buffer::optional, CKBufferSpec::optional, CKIRV2Buffer::role, CKBufferSpec::role, CKIRV2Buffer::scope, CKBufferSpec::scope, CKIRV2Buffer::shape, and CKBufferSpec::shape.

Referenced by ck_ir_v2_build_decoder(), and ck_ir_v2_build_decoder_backward().

◆ ck_ir_v2_copy_shape()

/*
 * Copy a full shape: exactly CK_IR_V2_MAX_DIMS dimension tokens from src to
 * dst. Both arrays must hold at least that many entries and must not overlap.
 */
static void ck_ir_v2_copy_shape(CKDimToken *dst, const CKDimToken *src)
{
    const size_t shape_bytes = (size_t)CK_IR_V2_MAX_DIMS * sizeof *dst;
    memcpy(dst, src, shape_bytes);
}
#define CK_IR_V2_MAX_DIMS
Definition: ckernel_ir_v2.h:13

References CK_IR_V2_MAX_DIMS.

Referenced by ck_ir_v2_copy_buffer_spec().

◆ ck_ir_v2_find_buffer_index()

static int ck_ir_v2_find_buffer_index ( const CKIRV2Graph graph,
const char *  name 
)
static

Definition at line 44 of file ckernel_ir_v2_builder.c.

45 {
46  if (!graph || !name) {
47  return -1;
48  }
49  for (int i = 0; i < graph->num_buffers; ++i) {
50  if (graph->buffers[i].name && strcmp(graph->buffers[i].name, name) == 0) {
51  return i;
52  }
53  }
54  return -1;
55 }

References CKIRV2Graph::buffers, CKIRV2Buffer::name, and CKIRV2Graph::num_buffers.

Referenced by ck_ir_v2_apply_meta(), ck_ir_v2_apply_weight_dtypes(), ck_ir_v2_build_decoder(), and ck_ir_v2_build_decoder_backward().

◆ ck_ir_v2_find_kernel_spec()

static const CKKernelSpec* ck_ir_v2_find_kernel_spec ( const char *  name)
static

Definition at line 57 of file ckernel_ir_v2_builder.c.

58 {
59  if (!name) {
60  return NULL;
61  }
62  for (size_t i = 0; i < ck_kernel_spec_count; ++i) {
63  if (strcmp(ck_kernel_specs[i].name, name) == 0) {
64  return &ck_kernel_specs[i];
65  }
66  }
67  return NULL;
68 }
const CKKernelSpec ck_kernel_specs[]
const size_t ck_kernel_spec_count

References ck_kernel_spec_count, and ck_kernel_specs.

Referenced by ck_ir_v2_build_decoder(), and ck_ir_v2_build_decoder_backward().

◆ ck_ir_v2_find_key()

/*
 * Find the first occurrence of the byte sequence `key` (e.g. "\"fused_qkv\"",
 * quotes included) in the range [json, end). This is a raw substring search:
 * it does not track JSON structure, so a match may also lie inside a string
 * value.
 *
 * Returns a pointer to the start of the match, or NULL if there is none or
 * an argument is invalid. A match ending exactly at `end` is found, and the
 * search never forms a pointer past `end`.
 */
static const char *ck_ir_v2_find_key(const char *json, const char *key, const char *end)
{
    if (!json || !key || !end || end < json) {
        return NULL;
    }
    const size_t key_len = strlen(key);
    /* Compare remaining length instead of computing cur + key_len, which
     * could point past the end of the buffer (UB) and made the old bound
     * miss a key that ends exactly at `end`. */
    for (const char *cur = json; (size_t)(end - cur) >= key_len; ++cur) {
        if (memcmp(cur, key, key_len) == 0) {
            return cur;
        }
    }
    return NULL;
}

References end.

Referenced by ck_ir_v2_apply_weight_dtypes(), and ck_ir_v2_parse_bool().

◆ ck_ir_v2_parse_bool()

static int ck_ir_v2_parse_bool ( const char *  json,
const char *  key,
const char *  end,
int *  out 
)
static

Definition at line 139 of file ckernel_ir_v2_builder.c.

140 {
141  const char *p = ck_ir_v2_find_key(json, key, end);
142  if (!p) {
143  return -1;
144  }
145  p = strchr(p, ':');
146  if (!p || p >= end) {
147  return -1;
148  }
149  p = ck_ir_v2_skip_ws(p + 1, end);
150  if (p + 4 <= end && memcmp(p, "true", 4) == 0) {
151  *out = 1;
152  return 0;
153  }
154  if (p + 5 <= end && memcmp(p, "false", 5) == 0) {
155  *out = 0;
156  return 0;
157  }
158  return -1;
159 }

References ck_ir_v2_find_key(), ck_ir_v2_skip_ws(), and end.

Referenced by ck_ir_v2_apply_meta().

◆ ck_ir_v2_parse_dtype()

static CKDataType ck_ir_v2_parse_dtype ( const char *  s)
static

Definition at line 161 of file ckernel_ir_v2_builder.c.

162 {
163  if (!s) return CK_DT_FP32;
164  if (strcmp(s, "fp32") == 0) return CK_DT_FP32;
165  if (strcmp(s, "bf16") == 0) return CK_DT_BF16;
166  if (strcmp(s, "fp16") == 0) return CK_DT_FP16;
167  if (strcmp(s, "q4_0") == 0) return CK_DT_Q4_0;
168  if (strcmp(s, "q4_k") == 0) return CK_DT_Q4_K;
169  if (strcmp(s, "q6_k") == 0) return CK_DT_Q6_K;
170  if (strcmp(s, "q8_0") == 0) return CK_DT_Q8_0;
171  return CK_DT_FP32;
172 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
@ CK_DT_Q4_0
Definition: ckernel_dtype.h:38
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_FP16
Definition: ckernel_dtype.h:31
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
@ CK_DT_BF16
Definition: ckernel_dtype.h:30

References CK_DT_BF16, CK_DT_FP16, CK_DT_FP32, CK_DT_Q4_0, CK_DT_Q4_K, CK_DT_Q6_K, and CK_DT_Q8_0.

Referenced by ck_ir_v2_apply_weight_dtypes().

◆ ck_ir_v2_parse_string()

/*
 * Extract the JSON string literal that begins at `start` (which must point
 * at '"') and ends before `end`.
 *
 * The raw bytes between the quotes are copied verbatim into a new
 * heap-allocated, NUL-terminated string: escape sequences are NOT decoded.
 * A backslash only prevents the next character from closing the string.
 * On success *out receives the copy (caller frees) and 0 is returned.
 * On a missing opening quote, an unterminated string or OOM, -1 is
 * returned and *out is not written.
 */
static int ck_ir_v2_parse_string(const char *start, const char *end, char **out)
{
    if (start == NULL || start >= end || start[0] != '"') {
        return -1;
    }
    const char *body = start + 1;
    const char *close = body;
    for (; close < end; ++close) {
        if (*close == '"') {
            break;
        }
        if (*close == '\\' && close + 1 < end) {
            ++close; /* skip the escaped character as well */
        }
    }
    if (close >= end) {
        return -1;
    }
    const size_t n = (size_t)(close - body);
    char *copy = (char *)malloc(n + 1);
    if (copy == NULL) {
        return -1;
    }
    memcpy(copy, body, n);
    copy[n] = '\0';
    *out = copy;
    return 0;
}

References end, and start.

Referenced by ck_ir_v2_apply_weight_dtypes().

◆ ck_ir_v2_select_kernel()

/*
 * Choose the implementation name for a kernel spec.
 *
 * Uses spec->backward when `backward` is non-zero, else spec->forward.
 * An out-of-range dtype is replaced by spec->default_dtype. The entry for
 * that dtype is preferred; if it is NULL or empty, the first non-empty
 * entry in dtype order is used; failing that, spec->name is returned.
 * Returns NULL only when spec is NULL.
 */
static const char *ck_ir_v2_select_kernel(const CKKernelSpec *spec, CKDataType dtype, int backward)
{
    if (spec == NULL) {
        return NULL;
    }
    const char *const *table = backward ? spec->backward : spec->forward;
    if (dtype < 0 || dtype >= CK_DT_COUNT) {
        dtype = spec->default_dtype;
    }
    const char *preferred = table[dtype];
    if (preferred != NULL && preferred[0] != '\0') {
        return preferred;
    }
    for (int slot = 0; slot < CK_DT_COUNT; ++slot) {
        const char *candidate = table[slot];
        if (candidate != NULL && candidate[0] != '\0') {
            return candidate;
        }
    }
    return spec->name;
}
@ CK_DT_COUNT
Definition: ckernel_dtype.h:48
const char * name
const char * backward[CK_DT_COUNT]
const char * forward[CK_DT_COUNT]

References CKKernelSpec::backward, CK_DT_COUNT, CKKernelSpec::default_dtype, CKKernelSpec::forward, and CKKernelSpec::name.

Referenced by ck_ir_v2_build_decoder(), and ck_ir_v2_build_decoder_backward().

◆ ck_ir_v2_skip_ws()

/*
 * Advance past JSON whitespace (space, '\n', '\r', '\t') without reading
 * beyond `end`. Returns the first non-whitespace position, or `end`.
 */
static const char *ck_ir_v2_skip_ws(const char *cur, const char *end)
{
    for (; cur < end; ++cur) {
        switch (*cur) {
        case ' ':
        case '\n':
        case '\r':
        case '\t':
            continue;
        default:
            return cur;
        }
    }
    return cur;
}

References end.

Referenced by ck_ir_v2_apply_weight_dtypes(), and ck_ir_v2_parse_bool().

◆ ck_ir_v2_strdup()

/*
 * Heap-duplicate a NUL-terminated string (portable stand-in for strdup).
 * Returns NULL when s is NULL or allocation fails; the caller frees the copy.
 */
static char *ck_ir_v2_strdup(const char *s)
{
    char *copy = NULL;
    if (s != NULL) {
        const size_t size = strlen(s) + 1; /* include the terminator */
        copy = (char *)malloc(size);
        if (copy != NULL) {
            memcpy(copy, s, size);
        }
    }
    return copy;
}

Referenced by ck_ir_v2_apply_meta(), ck_ir_v2_build_decoder(), ck_ir_v2_build_decoder_backward(), and ck_ir_v2_copy_buffer_spec().