← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_ir_v2.c File Reference
#include "ckernel_ir_v2.h"
#include "ckernel_mem_plan.h"
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

Go to the source code of this file.

Functions

static size_t ck_ir_v2_align_up_bytes (size_t n, size_t align)
 
static size_t ck_ir_v2_align_up_elems (size_t elems, size_t elem_bytes, size_t align_bytes)
 
static CKDimKind ck_ir_v2_dim_kind_from_name (const char *name)
 
static const char * ck_ir_v2_dim_name (CKDimKind dim)
 
static const char * ck_ir_v2_dtype_name (CKDataType dtype)
 
static void ck_ir_v2_emit_dimensions (FILE *out, const CKModelConfig *cfg, const CKIRV2AlignInfo *align, int tokens_override)
 
static void ck_ir_v2_emit_memory_plan (FILE *out, const CKIRV2Graph *graph, const CKMemPlan *plan)
 
static void ck_ir_v2_emit_resolved_shape (FILE *out, const CKModelConfig *cfg, const CKIRV2AlignInfo *align, const CKDimToken *shape, int tokens_override)
 
static int ck_ir_v2_emit_shape (FILE *out, const CKDimToken *shape)
 
static const char * ck_ir_v2_find_array_end (const char *open, const char *end)
 
static int ck_ir_v2_find_buffer_index (const CKIRV2Graph *graph, const char *name)
 
static const CKBufferSpec * ck_ir_v2_find_buffer_spec (const char *name)
 
static const char * ck_ir_v2_find_key (const char *json, const char *key, const char *end)
 
void ck_ir_v2_free (CKIRV2Graph *graph)
 
static void ck_ir_v2_free_buffer (CKIRV2Buffer *buf)
 
static void ck_ir_v2_free_node (CKIRV2Node *node)
 
static const char * ck_ir_v2_mem_arena_name (CKMemArenaKind arena)
 
static const char * ck_ir_v2_next_object (const char *cur, const char *end, const char **obj_start, const char **obj_end)
 
static int ck_ir_v2_parse_bindings (const char *obj_start, const char *obj_end, CKIRV2Graph *graph, CKIRV2Node *node)
 
static int ck_ir_v2_parse_bool (const char *json, const char *key, const char *end, int *out_val)
 
static int ck_ir_v2_parse_buffers (const char *json, const char *end, CKIRV2Graph *graph)
 
static CKDimKind ck_ir_v2_parse_dim_kind (const char *obj_start, const char *obj_end)
 
static CKDataType ck_ir_v2_parse_dtype (const char *s)
 
static int ck_ir_v2_parse_float (const char *json, const char *key, const char *end, float *out_val)
 
static int ck_ir_v2_parse_int (const char *json, const char *key, const char *end, int *out_val)
 
int ck_ir_v2_parse_json (const char *path, CKIRV2Graph *graph)
 
static int ck_ir_v2_parse_nodes (const char *json, const char *end, CKIRV2Graph *graph)
 
static CKBufferRole ck_ir_v2_parse_role (const char *s)
 
static CKBufferScope ck_ir_v2_parse_scope (const char *s)
 
static int ck_ir_v2_parse_shape (const char *obj_start, const char *obj_end, CKDimToken *shape_out)
 
static int ck_ir_v2_parse_string (const char *start, const char *end, char **out_str)
 
static int ck_ir_v2_parse_string_field (const char *json, const char *key, const char *end, char **out_str)
 
static void ck_ir_v2_resolve_align (const CKModelConfig *cfg, size_t alignment_bytes, CKIRV2AlignInfo *align)
 
static size_t ck_ir_v2_resolve_dim_value (const CKModelConfig *cfg, const CKIRV2AlignInfo *align, CKDimKind dim, int tokens_override)
 
static const char * ck_ir_v2_role_name (CKBufferRole role)
 
static const char * ck_ir_v2_scope_name (CKBufferScope scope)
 
int ck_ir_v2_serialize_json (const CKIRV2Graph *graph, const char *path)
 
static int ck_ir_v2_serialize_json_internal (const CKIRV2Graph *graph, const CKMemPlan *plan, const char *mode, int tokens_override, int base_context_window, const char *path)
 
int ck_ir_v2_serialize_json_with_plan (const CKIRV2Graph *graph, const CKMemPlan *plan, const char *mode, int tokens_override, int base_context_window, const char *path)
 
static const char * ck_ir_v2_skip_string (const char *cur, const char *end)
 

Function Documentation

◆ ck_ir_v2_align_up_bytes()

static size_t ck_ir_v2_align_up_bytes ( size_t  n,
size_t  align 
)
static

Definition at line 164 of file ckernel_ir_v2.c.

165 {
166  if (align == 0) {
167  return n;
168  }
169  return (n + align - 1) & ~(align - 1);
170 }

Referenced by ck_ir_v2_align_up_elems().

◆ ck_ir_v2_align_up_elems()

static size_t ck_ir_v2_align_up_elems ( size_t  elems,
size_t  elem_bytes,
size_t  align_bytes 
)
static

Definition at line 172 of file ckernel_ir_v2.c.

173 {
174  size_t bytes = elems * elem_bytes;
175  bytes = ck_ir_v2_align_up_bytes(bytes, align_bytes);
176  return bytes / elem_bytes;
177 }
static size_t ck_ir_v2_align_up_bytes(size_t n, size_t align)

References ck_ir_v2_align_up_bytes().

Referenced by ck_ir_v2_resolve_align().

◆ ck_ir_v2_dim_kind_from_name()

static CKDimKind ck_ir_v2_dim_kind_from_name ( const char *  name)
static

Definition at line 138 of file ckernel_ir_v2.c.

139 {
140  if (!name) {
141  return CK_DIM_END;
142  }
143  if (strcmp(name, "tokens") == 0) return CK_DIM_TOKENS;
144  if (strcmp(name, "embed") == 0) return CK_DIM_EMBED;
145  if (strcmp(name, "aligned_embed") == 0) return CK_DIM_ALIGNED_EMBED;
146  if (strcmp(name, "head_dim") == 0) return CK_DIM_HEAD_DIM;
147  if (strcmp(name, "aligned_head") == 0) return CK_DIM_ALIGNED_HEAD;
148  if (strcmp(name, "num_heads") == 0) return CK_DIM_NUM_HEADS;
149  if (strcmp(name, "num_kv_heads") == 0) return CK_DIM_NUM_KV_HEADS;
150  if (strcmp(name, "aligned_ctx") == 0) return CK_DIM_ALIGNED_CTX;
151  if (strcmp(name, "intermediate") == 0) return CK_DIM_INTERMEDIATE;
152  if (strcmp(name, "aligned_intermediate") == 0) return CK_DIM_ALIGNED_INTERMEDIATE;
153  if (strcmp(name, "vocab") == 0) return CK_DIM_VOCAB;
154  return CK_DIM_END;
155 }
@ CK_DIM_ALIGNED_INTERMEDIATE
@ CK_DIM_NUM_HEADS
@ CK_DIM_ALIGNED_EMBED
@ CK_DIM_TOKENS
@ CK_DIM_INTERMEDIATE
@ CK_DIM_ALIGNED_CTX
@ CK_DIM_END
@ CK_DIM_ALIGNED_HEAD
@ CK_DIM_HEAD_DIM
@ CK_DIM_NUM_KV_HEADS
@ CK_DIM_VOCAB
@ CK_DIM_EMBED

References CK_DIM_ALIGNED_CTX, CK_DIM_ALIGNED_EMBED, CK_DIM_ALIGNED_HEAD, CK_DIM_ALIGNED_INTERMEDIATE, CK_DIM_EMBED, CK_DIM_END, CK_DIM_HEAD_DIM, CK_DIM_INTERMEDIATE, CK_DIM_NUM_HEADS, CK_DIM_NUM_KV_HEADS, CK_DIM_TOKENS, and CK_DIM_VOCAB.

Referenced by ck_ir_v2_parse_dim_kind().

◆ ck_ir_v2_dim_name()

static const char* ck_ir_v2_dim_name ( CKDimKind  dim)
static

Definition at line 108 of file ckernel_ir_v2.c.

109 {
110  switch (dim) {
111  case CK_DIM_TOKENS:
112  return "tokens";
113  case CK_DIM_EMBED:
114  return "embed";
115  case CK_DIM_ALIGNED_EMBED:
116  return "aligned_embed";
117  case CK_DIM_HEAD_DIM:
118  return "head_dim";
119  case CK_DIM_ALIGNED_HEAD:
120  return "aligned_head";
121  case CK_DIM_NUM_HEADS:
122  return "num_heads";
123  case CK_DIM_NUM_KV_HEADS:
124  return "num_kv_heads";
125  case CK_DIM_ALIGNED_CTX:
126  return "aligned_ctx";
127  case CK_DIM_INTERMEDIATE:
128  return "intermediate";
129  case CK_DIM_ALIGNED_INTERMEDIATE:
130  return "aligned_intermediate";
131  case CK_DIM_VOCAB:
132  return "vocab";
133  default:
134  return "unknown";
135  }
136 }

References CK_DIM_ALIGNED_CTX, CK_DIM_ALIGNED_EMBED, CK_DIM_ALIGNED_HEAD, CK_DIM_ALIGNED_INTERMEDIATE, CK_DIM_EMBED, CK_DIM_HEAD_DIM, CK_DIM_INTERMEDIATE, CK_DIM_NUM_HEADS, CK_DIM_NUM_KV_HEADS, CK_DIM_TOKENS, and CK_DIM_VOCAB.

Referenced by ck_ir_v2_emit_dimensions(), and ck_ir_v2_emit_shape().

◆ ck_ir_v2_dtype_name()

static const char* ck_ir_v2_dtype_name ( CKDataType  dtype)
static

Definition at line 86 of file ckernel_ir_v2.c.

87 {
88  switch (dtype) {
89  case CK_DT_FP32:
90  return "fp32";
91  case CK_DT_BF16:
92  return "bf16";
93  case CK_DT_FP16:
94  return "fp16";
95  case CK_DT_Q4_0:
96  return "q4_0";
97  case CK_DT_Q4_K:
98  return "q4_k";
99  case CK_DT_Q6_K:
100  return "q6_k";
101  case CK_DT_Q8_0:
102  return "q8_0";
103  default:
104  return "unknown";
105  }
106 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
@ CK_DT_Q4_0
Definition: ckernel_dtype.h:38
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
@ CK_DT_FP16
Definition: ckernel_dtype.h:31
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
@ CK_DT_BF16
Definition: ckernel_dtype.h:30

References CK_DT_BF16, CK_DT_FP16, CK_DT_FP32, CK_DT_Q4_0, CK_DT_Q4_K, CK_DT_Q6_K, and CK_DT_Q8_0.

Referenced by ck_ir_v2_serialize_json_internal().

◆ ck_ir_v2_emit_dimensions()

static void ck_ir_v2_emit_dimensions ( FILE *  out,
const CKModelConfig *  cfg,
const CKIRV2AlignInfo *  align,
int  tokens_override 
)
static

Definition at line 286 of file ckernel_ir_v2.c.

290 {
291  fprintf(out, " \"dimensions\": [\n");
292  CKDimKind dims[] = {
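293  CK_DIM_TOKENS,
294  CK_DIM_EMBED,
295  CK_DIM_ALIGNED_EMBED,
296  CK_DIM_HEAD_DIM,
297  CK_DIM_ALIGNED_HEAD,
298  CK_DIM_NUM_HEADS,
299  CK_DIM_NUM_KV_HEADS,
300  CK_DIM_ALIGNED_CTX,
301  CK_DIM_INTERMEDIATE,
302  CK_DIM_ALIGNED_INTERMEDIATE,
303  CK_DIM_VOCAB,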
304  };
305  size_t dim_count = sizeof(dims) / sizeof(dims[0]);
306  for (size_t i = 0; i < dim_count; ++i) {
307  CKDimKind dim = dims[i];
308  size_t value = ck_ir_v2_resolve_dim_value(cfg, align, dim, tokens_override);
309  fprintf(out,
310  " {\"id\": %d, \"name\": \"%s\", \"value\": %zu}%s\n",
311  (int)dim,
312  ck_ir_v2_dim_name(dim),
313  value,
314  (i + 1 == dim_count) ? "" : ",");
315  }
316  fprintf(out, " ],\n");
317 }
static const char * ck_ir_v2_dim_name(CKDimKind dim)
static size_t ck_ir_v2_resolve_dim_value(const CKModelConfig *cfg, const CKIRV2AlignInfo *align, CKDimKind dim, int tokens_override)

References CK_DIM_ALIGNED_CTX, CK_DIM_ALIGNED_EMBED, CK_DIM_ALIGNED_HEAD, CK_DIM_ALIGNED_INTERMEDIATE, CK_DIM_EMBED, CK_DIM_HEAD_DIM, CK_DIM_INTERMEDIATE, CK_DIM_NUM_HEADS, CK_DIM_NUM_KV_HEADS, CK_DIM_TOKENS, CK_DIM_VOCAB, ck_ir_v2_dim_name(), and ck_ir_v2_resolve_dim_value().

Referenced by ck_ir_v2_serialize_json_internal().

◆ ck_ir_v2_emit_memory_plan()

static void ck_ir_v2_emit_memory_plan ( FILE *  out,
const CKIRV2Graph *  graph,
const CKMemPlan *  plan 
)
static

Definition at line 333 of file ckernel_ir_v2.c.

336 {
337  if (!plan) {
338  return;
339  }
340  fprintf(out, " \"memory_plan\": {\n");
341  fprintf(out, " \"alignment_bytes\": %zu,\n", plan->alignment_bytes);
342  fprintf(out, " \"total_bytes\": {\n");
343  fprintf(out, " \"weights\": %zu,\n", plan->total_bytes[CK_MEM_ARENA_WEIGHTS]);
344  fprintf(out, " \"activations\": %zu,\n", plan->total_bytes[CK_MEM_ARENA_ACTIVATIONS]);
345  fprintf(out, " \"grads\": %zu\n", plan->total_bytes[CK_MEM_ARENA_GRADS]);
346  fprintf(out, " },\n");
347  fprintf(out, " \"buffers\": [\n");
348  for (int i = 0; i < graph->num_buffers; ++i) {
349  const CKMemSpan *span = &plan->spans[i];
350  const char *name = graph->buffers[i].name ? graph->buffers[i].name : "";
351  int enabled = span->size_bytes > 0;
352  fprintf(out,
353  " {\"name\": \"%s\", \"arena\": \"%s\", \"offset_bytes\": %zu, \"size_bytes\": %zu, \"enabled\": %s}%s\n",
354  name,
355  ck_ir_v2_mem_arena_name(span->arena),
356  span->offset_bytes,
357  span->size_bytes,
358  enabled ? "true" : "false",
359  (i + 1 == graph->num_buffers) ? "" : ",");
360  }
361  fprintf(out, " ]\n");
362  fprintf(out, " },\n");
363 }
static const char * ck_ir_v2_mem_arena_name(CKMemArenaKind arena)
@ CK_MEM_ARENA_GRADS
@ CK_MEM_ARENA_WEIGHTS
@ CK_MEM_ARENA_ACTIVATIONS
CKIRV2Buffer * buffers
Definition: ckernel_ir_v2.h:62
CKMemSpan * spans
size_t alignment_bytes
size_t total_bytes[CK_MEM_ARENA_COUNT]
size_t size_bytes
CKMemArenaKind arena
size_t offset_bytes

References CKMemPlan::alignment_bytes, CKMemSpan::arena, CKIRV2Graph::buffers, ck_ir_v2_mem_arena_name(), CK_MEM_ARENA_ACTIVATIONS, CK_MEM_ARENA_GRADS, CK_MEM_ARENA_WEIGHTS, CKIRV2Buffer::name, CKIRV2Graph::num_buffers, CKMemSpan::offset_bytes, CKMemSpan::size_bytes, CKMemPlan::spans, and CKMemPlan::total_bytes.

Referenced by ck_ir_v2_serialize_json_internal().

◆ ck_ir_v2_emit_resolved_shape()

static void ck_ir_v2_emit_resolved_shape ( FILE *  out,
const CKModelConfig *  cfg,
const CKIRV2AlignInfo *  align,
const CKDimToken *  shape,
int  tokens_override 
)
static

Definition at line 261 of file ckernel_ir_v2.c.

266 {
267  fprintf(out, "[");
268  int first = 1;
269  for (int i = 0; i < CK_IR_V2_MAX_DIMS; ++i) {
270  if (shape[i].dim == CK_DIM_END) {
271  break;
272  }
273  size_t dim = ck_ir_v2_resolve_dim_value(cfg, align, shape[i].dim, tokens_override);
274  size_t mult = (size_t)(shape[i].mult > 0 ? shape[i].mult : 1);
275  size_t div = (size_t)(shape[i].div > 0 ? shape[i].div : 1);
276  size_t resolved = (div == 0) ? 0 : (dim * mult / div);
277  if (!first) {
278  fprintf(out, ", ");
279  }
280  fprintf(out, "%zu", resolved);
281  first = 0;
282  }
283  fprintf(out, "]");
284 }
#define CK_IR_V2_MAX_DIMS
Definition: ckernel_ir_v2.h:13

References CK_DIM_END, CK_IR_V2_MAX_DIMS, and ck_ir_v2_resolve_dim_value().

Referenced by ck_ir_v2_serialize_json_internal().

◆ ck_ir_v2_emit_shape()

static int ck_ir_v2_emit_shape ( FILE *  out,
const CKDimToken *  shape 
)
static

Definition at line 239 of file ckernel_ir_v2.c.

240 {
241  fprintf(out, "[");
242  int first = 1;
243  for (int i = 0; i < CK_IR_V2_MAX_DIMS; ++i) {
244  if (shape[i].dim == CK_DIM_END) {
245  break;
246  }
247  if (!first) {
248  fprintf(out, ", ");
249  }
250  fprintf(out, "{\"dim\":\"%s\",\"dim_id\":%d,\"mult\":%d,\"div\":%d}",
251  ck_ir_v2_dim_name(shape[i].dim),
252  (int)shape[i].dim,
253  shape[i].mult,
254  shape[i].div);
255  first = 0;
256  }
257  fprintf(out, "]");
258  return 0;
259 }

References CK_DIM_END, ck_ir_v2_dim_name(), and CK_IR_V2_MAX_DIMS.

Referenced by ck_ir_v2_serialize_json_internal().

◆ ck_ir_v2_find_array_end()

static const char* ck_ir_v2_find_array_end ( const char *  open,
const char *  end 
)
static

Definition at line 610 of file ckernel_ir_v2.c.

611 {
612  if (!open || open >= end || *open != '[') {
613  return NULL;
614  }
615  int depth = 0;
616  for (const char *p = open; p < end; ++p) {
617  if (*p == '"') {
618  p = ck_ir_v2_skip_string(p, end) - 1;
619  continue;
620  }
621  if (*p == '[') {
622  depth++;
623  continue;
624  }
625  if (*p == ']') {
626  if (depth > 0) {
627  depth--;
628  if (depth == 0) {
629  return p;
630  }
631  }
632  }
633  }
634  return NULL;
635 }
static const char * ck_ir_v2_skip_string(const char *cur, const char *end)
uint32_t end
Definition: utf8.c:215

References ck_ir_v2_skip_string(), and end.

Referenced by ck_ir_v2_parse_bindings(), ck_ir_v2_parse_buffers(), ck_ir_v2_parse_nodes(), and ck_ir_v2_parse_shape().

◆ ck_ir_v2_find_buffer_index()

static int ck_ir_v2_find_buffer_index ( const CKIRV2Graph *  graph,
const char *  name 
)
static

Definition at line 868 of file ckernel_ir_v2.c.

869 {
870  if (!graph || !name) {
871  return -1;
872  }
873  for (int i = 0; i < graph->num_buffers; ++i) {
874  if (graph->buffers[i].name && strcmp(graph->buffers[i].name, name) == 0) {
875  return i;
876  }
877  }
878  return -1;
879 }

References CKIRV2Graph::buffers, CKIRV2Buffer::name, and CKIRV2Graph::num_buffers.

Referenced by ck_ir_v2_parse_bindings().

◆ ck_ir_v2_find_buffer_spec()

static const CKBufferSpec* ck_ir_v2_find_buffer_spec ( const char *  name)
static

Definition at line 855 of file ckernel_ir_v2.c.

856 {
857  if (!name) {
858  return NULL;
859  }
860  for (size_t i = 0; i < ck_decoder_buffer_count; ++i) {
861  if (strcmp(ck_decoder_buffers[i].name, name) == 0) {
862  return &ck_decoder_buffers[i];
863  }
864  }
865  return NULL;
866 }
const CKBufferSpec ck_decoder_buffers[]
const size_t ck_decoder_buffer_count

References ck_decoder_buffer_count, and ck_decoder_buffers.

Referenced by ck_ir_v2_parse_buffers().

◆ ck_ir_v2_find_key()

static const char* ck_ir_v2_find_key ( const char *  json,
const char *  key,
const char *  end 
)
static

Definition at line 637 of file ckernel_ir_v2.c.

640 {
641  size_t key_len = strlen(key);
642  const char *cur = json;
643  while (cur + key_len < end) {
644  if (memcmp(cur, key, key_len) == 0) {
645  return cur;
646  }
647  cur++;
648  }
649  return NULL;
650 }

References end.

Referenced by ck_ir_v2_parse_bindings(), ck_ir_v2_parse_bool(), ck_ir_v2_parse_buffers(), ck_ir_v2_parse_float(), ck_ir_v2_parse_int(), ck_ir_v2_parse_json(), ck_ir_v2_parse_nodes(), ck_ir_v2_parse_shape(), and ck_ir_v2_parse_string_field().

◆ ck_ir_v2_free()

void ck_ir_v2_free ( CKIRV2Graph *  graph)

Definition at line 34 of file ckernel_ir_v2.c.

35 {
36  if (!graph) {
37  return;
38  }
39  if (graph->buffers) {
40  for (int i = 0; i < graph->num_buffers; ++i) {
41  ck_ir_v2_free_buffer(&graph->buffers[i]);
42  }
43  free(graph->buffers);
44  }
45  if (graph->nodes) {
46  for (int i = 0; i < graph->num_nodes; ++i) {
47  ck_ir_v2_free_node(&graph->nodes[i]);
48  }
49  free(graph->nodes);
50  }
51  memset(graph, 0, sizeof(*graph));
52 }
static void ck_ir_v2_free_node(CKIRV2Node *node)
Definition: ckernel_ir_v2.c:20
static void ck_ir_v2_free_buffer(CKIRV2Buffer *buf)
Definition: ckernel_ir_v2.c:9
CKIRV2Node * nodes
Definition: ckernel_ir_v2.h:64

References CKIRV2Graph::buffers, ck_ir_v2_free_buffer(), ck_ir_v2_free_node(), CKIRV2Graph::nodes, CKIRV2Graph::num_buffers, and CKIRV2Graph::num_nodes.

Referenced by ck_ir_v2_build_decoder(), ck_ir_v2_build_decoder_backward(), ck_ir_v2_lower_emit_json(), ck_ir_v2_lower_graph(), ck_ir_v2_parse_json(), and main().

◆ ck_ir_v2_free_buffer()

static void ck_ir_v2_free_buffer ( CKIRV2Buffer *  buf)
static

Definition at line 9 of file ckernel_ir_v2.c.

10 {
11  if (!buf) {
12  return;
13  }
14  free(buf->name);
15  free(buf->alias_of);
16  free(buf->condition);
17  memset(buf, 0, sizeof(*buf));
18 }
char * condition
Definition: ckernel_ir_v2.h:32
char * alias_of
Definition: ckernel_ir_v2.h:31

References CKIRV2Buffer::alias_of, CKIRV2Buffer::condition, and CKIRV2Buffer::name.

Referenced by ck_ir_v2_free().

◆ ck_ir_v2_free_node()

static void ck_ir_v2_free_node ( CKIRV2Node *  node)
static

Definition at line 20 of file ckernel_ir_v2.c.

21 {
22  if (!node) {
23  return;
24  }
25  for (int i = 0; i < node->n_bindings; ++i) {
26  free(node->bindings[i].arg);
27  }
28  free(node->op);
29  free(node->kernel);
30  free(node->condition);
31  memset(node, 0, sizeof(*node));
32 }
uint8_t n_bindings
Definition: ckernel_ir_v2.h:48
char * condition
Definition: ckernel_ir_v2.h:44
CKIRV2Binding bindings[24]
Definition: ckernel_ir_v2.h:47
char * kernel
Definition: ckernel_ir_v2.h:42

References CKIRV2Binding::arg, CKIRV2Node::bindings, CKIRV2Node::condition, CKIRV2Node::kernel, CKIRV2Node::n_bindings, and CKIRV2Node::op.

Referenced by ck_ir_v2_free(), and ck_ir_v2_parse_nodes().

◆ ck_ir_v2_mem_arena_name()

static const char* ck_ir_v2_mem_arena_name ( CKMemArenaKind  arena)
static

Definition at line 319 of file ckernel_ir_v2.c.

320 {
321  switch (arena) {
322  case CK_MEM_ARENA_WEIGHTS:
323  return "weights";
324  case CK_MEM_ARENA_ACTIVATIONS:
325  return "activations";
326  case CK_MEM_ARENA_GRADS:
327  return "grads";
328  default:
329  return "unknown";
330  }
331 }

References CK_MEM_ARENA_ACTIVATIONS, CK_MEM_ARENA_GRADS, and CK_MEM_ARENA_WEIGHTS.

Referenced by ck_ir_v2_emit_memory_plan().

◆ ck_ir_v2_next_object()

static const char* ck_ir_v2_next_object ( const char *  cur,
const char *  end,
const char **  obj_start,
const char **  obj_end 
)
static

Definition at line 577 of file ckernel_ir_v2.c.

581 {
582  int depth = 0;
583  const char *start = NULL;
584  for (const char *p = cur; p < end; ++p) {
585  if (*p == '"') {
586  p = ck_ir_v2_skip_string(p, end) - 1;
587  continue;
588  }
589  if (*p == '{') {
590  if (depth == 0) {
591  start = p;
592  }
593  depth++;
594  continue;
595  }
596  if (*p == '}') {
597  if (depth > 0) {
598  depth--;
599  if (depth == 0 && start) {
600  *obj_start = start;
601  *obj_end = p;
602  return p + 1;
603  }
604  }
605  }
606  }
607  return NULL;
608 }
uint32_t start
Definition: utf8.c:214

References ck_ir_v2_skip_string(), end, and start.

Referenced by ck_ir_v2_parse_bindings(), ck_ir_v2_parse_buffers(), ck_ir_v2_parse_json(), ck_ir_v2_parse_nodes(), and ck_ir_v2_parse_shape().

◆ ck_ir_v2_parse_bindings()

static int ck_ir_v2_parse_bindings ( const char *  obj_start,
const char *  obj_end,
CKIRV2Graph *  graph,
CKIRV2Node *  node 
)
static

Definition at line 965 of file ckernel_ir_v2.c.

969 {
970  const char *start = ck_ir_v2_find_key(obj_start, "\"bindings\"", obj_end);
971  if (!start) {
972  return 0;
973  }
974  const char *open = strchr(start, '[');
975  const char *close = ck_ir_v2_find_array_end(open, obj_end);
976  if (!open || !close || close > obj_end) {
977  return -1;
978  }
979 
980  const char *cur = open;
981  const char *bstart = NULL;
982  const char *bend = NULL;
983  while (node->n_bindings < CK_IR_V2_MAX_BINDINGS &&
984  (cur = ck_ir_v2_next_object(cur, close, &bstart, &bend)) != NULL) {
985 
986  char *arg = NULL;
987  char *buf = NULL;
988  ck_ir_v2_parse_string_field(bstart, "\"arg\"", bend, &arg);
989  ck_ir_v2_parse_string_field(bstart, "\"buffer\"", bend, &buf);
990 
991  if (arg) {
992  int buf_idx = ck_ir_v2_find_buffer_index(graph, buf);
993  node->bindings[node->n_bindings].arg = arg;
994  node->bindings[node->n_bindings].buffer = buf_idx;
995  node->n_bindings++;
996  } else {
997  free(arg);
998  }
999  free(buf);
1000  }
1001 
1002  return 0;
1003 }
static const char * ck_ir_v2_next_object(const char *cur, const char *end, const char **obj_start, const char **obj_end)
static const char * ck_ir_v2_find_array_end(const char *open, const char *end)
static const char * ck_ir_v2_find_key(const char *json, const char *key, const char *end)
static int ck_ir_v2_parse_string_field(const char *json, const char *key, const char *end, char **out_str)
static int ck_ir_v2_find_buffer_index(const CKIRV2Graph *graph, const char *name)
#define CK_IR_V2_MAX_BINDINGS
Definition: ckernel_ir_v2.h:16
int32_t buffer
Definition: ckernel_ir_v2.h:37

References CKIRV2Binding::arg, CKIRV2Node::bindings, CKIRV2Binding::buffer, ck_ir_v2_find_array_end(), ck_ir_v2_find_buffer_index(), ck_ir_v2_find_key(), CK_IR_V2_MAX_BINDINGS, ck_ir_v2_next_object(), ck_ir_v2_parse_string_field(), CKIRV2Node::n_bindings, and start.

Referenced by ck_ir_v2_parse_nodes().

◆ ck_ir_v2_parse_bool()

static int ck_ir_v2_parse_bool ( const char *  json,
const char *  key,
const char *  end,
int *  out_val 
)
static

Definition at line 673 of file ckernel_ir_v2.c.

677 {
678  const char *p = ck_ir_v2_find_key(json, key, end);
679  if (!p) {
680  return -1;
681  }
682  const char *colon = strchr(p, ':');
683  if (!colon || colon >= end) {
684  return -1;
685  }
686  const char *cur = colon + 1;
687  while (cur < end && (*cur == ' ' || *cur == '\t' || *cur == '\n' || *cur == '\r')) {
688  cur++;
689  }
690  if (cur + 4 <= end && memcmp(cur, "true", 4) == 0) {
691  *out_val = 1;
692  return 0;
693  }
694  if (cur + 5 <= end && memcmp(cur, "false", 5) == 0) {
695  *out_val = 0;
696  return 0;
697  }
698  return -1;
699 }

References ck_ir_v2_find_key(), and end.

Referenced by ck_ir_v2_parse_json().

◆ ck_ir_v2_parse_buffers()

static int ck_ir_v2_parse_buffers ( const char *  json,
const char *  end,
CKIRV2Graph *  graph 
)
static

Definition at line 881 of file ckernel_ir_v2.c.

884 {
885  const char *start = ck_ir_v2_find_key(json, "\"buffers\"", end);
886  if (!start) {
887  return -1;
888  }
889  const char *open = strchr(start, '[');
890  const char *close = ck_ir_v2_find_array_end(open, end);
891  if (!open || !close) {
892  return -1;
893  }
894 
895  int count = 0;
896  const char *cur = open;
897  const char *obj_start = NULL;
898  const char *obj_end = NULL;
899  while ((cur = ck_ir_v2_next_object(cur, close, &obj_start, &obj_end)) != NULL) {
900  count++;
901  }
902  if (count <= 0) {
903  return -1;
904  }
905 
906  CKIRV2Buffer *buffers = (CKIRV2Buffer *)calloc((size_t)count, sizeof(CKIRV2Buffer));
907  if (!buffers) {
908  return -1;
909  }
910  int idx = 0;
911  cur = open;
912  while (idx < count &&
913  (cur = ck_ir_v2_next_object(cur, close, &obj_start, &obj_end)) != NULL) {
914 
915  CKIRV2Buffer buf = {0};
916  char *name = NULL;
917  char *scope = NULL;
918  char *role = NULL;
919  char *dtype = NULL;
920  char *alias = NULL;
921  char *cond = NULL;
922  int optional = 0;
923 
924  ck_ir_v2_parse_string_field(obj_start, "\"name\"", obj_end, &name);
925  ck_ir_v2_parse_string_field(obj_start, "\"scope\"", obj_end, &scope);
926  ck_ir_v2_parse_string_field(obj_start, "\"role\"", obj_end, &role);
927  ck_ir_v2_parse_string_field(obj_start, "\"dtype\"", obj_end, &dtype);
928  ck_ir_v2_parse_string_field(obj_start, "\"alias_of\"", obj_end, &alias);
929  ck_ir_v2_parse_string_field(obj_start, "\"condition\"", obj_end, &cond);
930  ck_ir_v2_parse_int(obj_start, "\"optional\"", obj_end, &optional);
931 
932  buf.name = name;
933  buf.scope = ck_ir_v2_parse_scope(scope);
934  buf.role = ck_ir_v2_parse_role(role);
935  buf.dtype = ck_ir_v2_parse_dtype(dtype);
936  buf.optional = optional;
937  buf.alias_of = alias;
938  buf.condition = cond;
939 
940  if (ck_ir_v2_parse_shape(obj_start, obj_end, buf.shape) != 0) {
941  const CKBufferSpec *spec = ck_ir_v2_find_buffer_spec(name);
942  if (spec) {
943  memcpy(buf.shape, spec->shape, sizeof(buf.shape));
944  } else {
945  for (int i = 0; i < CK_IR_V2_MAX_DIMS; ++i) {
946  buf.shape[i].dim = CK_DIM_END;
947  buf.shape[i].mult = 0;
948  buf.shape[i].div = 0;
949  }
950  }
951  }
952 
953  free(scope);
954  free(role);
955  free(dtype);
956 
957  buffers[idx++] = buf;
958  }
959 
960  graph->buffers = buffers;
961  graph->num_buffers = idx;
962  return 0;
963 }
static int ck_ir_v2_parse_shape(const char *obj_start, const char *obj_end, CKDimToken *shape_out)
static const CKBufferSpec * ck_ir_v2_find_buffer_spec(const char *name)
static CKBufferRole ck_ir_v2_parse_role(const char *s)
static int ck_ir_v2_parse_int(const char *json, const char *key, const char *end, int *out_val)
static CKDataType ck_ir_v2_parse_dtype(const char *s)
static CKBufferScope ck_ir_v2_parse_scope(const char *s)
CKDimToken shape[4]
CKDimToken shape[4]
Definition: ckernel_ir_v2.h:29
CKBufferRole role
Definition: ckernel_ir_v2.h:27
CKBufferScope scope
Definition: ckernel_ir_v2.h:26
CKDataType dtype
Definition: ckernel_ir_v2.h:28

References CKIRV2Buffer::alias_of, CKIRV2Graph::buffers, CK_DIM_END, ck_ir_v2_find_array_end(), ck_ir_v2_find_buffer_spec(), ck_ir_v2_find_key(), CK_IR_V2_MAX_DIMS, ck_ir_v2_next_object(), ck_ir_v2_parse_dtype(), ck_ir_v2_parse_int(), ck_ir_v2_parse_role(), ck_ir_v2_parse_scope(), ck_ir_v2_parse_shape(), ck_ir_v2_parse_string_field(), CKIRV2Buffer::condition, CKDimToken::dim, CKDimToken::div, CKIRV2Buffer::dtype, end, CKDimToken::mult, CKIRV2Buffer::name, CKIRV2Graph::num_buffers, CKIRV2Buffer::optional, CKIRV2Buffer::role, CKIRV2Buffer::scope, CKIRV2Buffer::shape, CKBufferSpec::shape, and start.

Referenced by ck_ir_v2_parse_json().

◆ ck_ir_v2_parse_dim_kind()

static CKDimKind ck_ir_v2_parse_dim_kind ( const char *  obj_start,
const char *  obj_end 
)
static

Definition at line 785 of file ckernel_ir_v2.c.

787 {
788  char *dim_str = NULL;
789  CKDimKind kind = CK_DIM_END;
790  if (ck_ir_v2_parse_string_field(obj_start, "\"dim\"", obj_end, &dim_str) == 0 &&
791  dim_str) {
792  kind = ck_ir_v2_dim_kind_from_name(dim_str);
793  }
794  if (dim_str) {
795  free(dim_str);
796  }
797  if (kind != CK_DIM_END) {
798  return kind;
799  }
800  int dim_id = -1;
801  if (ck_ir_v2_parse_int(obj_start, "\"dim_id\"", obj_end, &dim_id) == 0) {
802  return (CKDimKind)dim_id;
803  }
804  if (ck_ir_v2_parse_int(obj_start, "\"dim\"", obj_end, &dim_id) == 0) {
805  return (CKDimKind)dim_id;
806  }
807  return CK_DIM_END;
808 }
static CKDimKind ck_ir_v2_dim_kind_from_name(const char *name)

References CK_DIM_END, ck_ir_v2_dim_kind_from_name(), ck_ir_v2_parse_int(), and ck_ir_v2_parse_string_field().

Referenced by ck_ir_v2_parse_shape().

◆ ck_ir_v2_parse_dtype()

static CKDataType ck_ir_v2_parse_dtype ( const char *  s)
static

Definition at line 772 of file ckernel_ir_v2.c.

773 {
774  if (!s) return CK_DT_FP32;
775  if (strcmp(s, "fp32") == 0) return CK_DT_FP32;
776  if (strcmp(s, "bf16") == 0) return CK_DT_BF16;
777  if (strcmp(s, "fp16") == 0) return CK_DT_FP16;
778  if (strcmp(s, "q4_0") == 0) return CK_DT_Q4_0;
779  if (strcmp(s, "q4_k") == 0) return CK_DT_Q4_K;
780  if (strcmp(s, "q6_k") == 0) return CK_DT_Q6_K;
781  if (strcmp(s, "q8_0") == 0) return CK_DT_Q8_0;
782  return CK_DT_FP32;
783 }

References CK_DT_BF16, CK_DT_FP16, CK_DT_FP32, CK_DT_Q4_0, CK_DT_Q4_K, CK_DT_Q6_K, and CK_DT_Q8_0.

Referenced by ck_ir_v2_parse_buffers(), and ck_ir_v2_parse_nodes().

◆ ck_ir_v2_parse_float()

static int ck_ir_v2_parse_float ( const char *  json,
const char *  key,
const char *  end,
float *  out_val 
)
static

Definition at line 701 of file ckernel_ir_v2.c.

705 {
706  const char *p = ck_ir_v2_find_key(json, key, end);
707  if (!p) {
708  return -1;
709  }
710  const char *colon = strchr(p, ':');
711  if (!colon || colon >= end) {
712  return -1;
713  }
714  float value = 0.0f;
715  if (sscanf(colon + 1, "%f", &value) != 1) {
716  return -1;
717  }
718  *out_val = value;
719  return 0;
720 }

References ck_ir_v2_find_key(), and end.

Referenced by ck_ir_v2_parse_json().

◆ ck_ir_v2_parse_int()

static int ck_ir_v2_parse_int ( const char *  json,
const char *  key,
const char *  end,
int *  out_val 
)
static

Definition at line 652 of file ckernel_ir_v2.c.

656 {
657  const char *p = ck_ir_v2_find_key(json, key, end);
658  if (!p) {
659  return -1;
660  }
661  const char *colon = strchr(p, ':');
662  if (!colon || colon >= end) {
663  return -1;
664  }
665  int value = 0;
666  if (sscanf(colon + 1, "%d", &value) != 1) {
667  return -1;
668  }
669  *out_val = value;
670  return 0;
671 }

References ck_ir_v2_find_key(), and end.

Referenced by ck_ir_v2_parse_buffers(), ck_ir_v2_parse_dim_kind(), ck_ir_v2_parse_json(), ck_ir_v2_parse_nodes(), and ck_ir_v2_parse_shape().

◆ ck_ir_v2_parse_json()

int ck_ir_v2_parse_json ( const char *  path,
CKIRV2Graph *  graph 
)

Definition at line 1100 of file ckernel_ir_v2.c.

1101 {
1102  if (!path || !graph) {
1103  return -1;
1104  }
1105  FILE *f = fopen(path, "rb");
1106  if (!f) {
1107  perror("ck_ir_v2_parse_json: fopen");
1108  return -1;
1109  }
1110  if (fseek(f, 0, SEEK_END) != 0) {
1111  fclose(f);
1112  return -1;
1113  }
1114  long len = ftell(f);
1115  if (len < 0) {
1116  fclose(f);
1117  return -1;
1118  }
1119  if (fseek(f, 0, SEEK_SET) != 0) {
1120  fclose(f);
1121  return -1;
1122  }
1123  char *buf = (char *)malloc((size_t)len + 1);
1124  if (!buf) {
1125  fclose(f);
1126  return -1;
1127  }
1128  size_t nread = fread(buf, 1, (size_t)len, f);
1129  fclose(f);
1130  buf[nread] = '\0';
1131 
1132  CKIRV2Graph tmp = {0};
1133  const char *end = buf + nread;
1134 
1135  tmp.has_pos_emb = 1;
1136  tmp.tie_word_embeddings = -1;
1137  tmp.fused_qkv = -1;
1138  tmp.gated_mlp = -1;
1139 
1140  if (ck_ir_v2_parse_int(buf, "\"num_layers\"", end, &tmp.config.num_layers) != 0 ||
1141  ck_ir_v2_parse_int(buf, "\"hidden_size\"", end, &tmp.config.hidden_size) != 0 ||
1142  ck_ir_v2_parse_int(buf, "\"intermediate_size\"", end, &tmp.config.intermediate_size) != 0 ||
1143  ck_ir_v2_parse_int(buf, "\"num_attention_heads\"", end, &tmp.config.num_heads) != 0 ||
1144  ck_ir_v2_parse_int(buf, "\"num_kv_heads\"", end, &tmp.config.num_kv_heads) != 0) {
1145  free(buf);
1146  return -1;
1147  }
1148 
1149  ck_ir_v2_parse_int(buf, "\"vocab_size\"", end, &tmp.config.vocab_size);
1150  ck_ir_v2_parse_int(buf, "\"context_window\"", end, &tmp.config.context_window);
1151  ck_ir_v2_parse_float(buf, "\"rms_norm_eps\"", end, &tmp.config.rms_norm_eps);
1152  ck_ir_v2_parse_float(buf, "\"rope_theta\"", end, &tmp.config.rope_theta);
1153  const char *meta_key = ck_ir_v2_find_key(buf, "\"meta\"", end);
1154  if (meta_key) {
1155  const char *brace = strchr(meta_key, '{');
1156  const char *obj_start = NULL;
1157  const char *obj_end = NULL;
1158  if (brace && ck_ir_v2_next_object(brace, end, &obj_start, &obj_end)) {
1159  ck_ir_v2_parse_bool(obj_start, "\"has_pos_emb\"", obj_end, &tmp.has_pos_emb);
1160  ck_ir_v2_parse_bool(obj_start, "\"tie_word_embeddings\"", obj_end, &tmp.tie_word_embeddings);
1161  ck_ir_v2_parse_bool(obj_start, "\"fused_qkv\"", obj_end, &tmp.fused_qkv);
1162  ck_ir_v2_parse_bool(obj_start, "\"gated_mlp\"", obj_end, &tmp.gated_mlp);
1163  }
1164  }
1165 
1166  if (ck_ir_v2_parse_buffers(buf, end, &tmp) != 0 ||
1167  ck_ir_v2_parse_nodes(buf, end, &tmp) != 0) {
1168  ck_ir_v2_free(&tmp);
1169  free(buf);
1170  return -1;
1171  }
1172 
1173  *graph = tmp;
1174  free(buf);
1175  return 0;
1176 }
static int ck_ir_v2_parse_buffers(const char *json, const char *end, CKIRV2Graph *graph)
void ck_ir_v2_free(CKIRV2Graph *graph)
Definition: ckernel_ir_v2.c:34
static int ck_ir_v2_parse_bool(const char *json, const char *key, const char *end, int *out_val)
static int ck_ir_v2_parse_nodes(const char *json, const char *end, CKIRV2Graph *graph)
static int ck_ir_v2_parse_float(const char *json, const char *key, const char *end, float *out_val)
CKModelConfig config
Definition: ckernel_ir_v2.h:56
int tie_word_embeddings
Definition: ckernel_ir_v2.h:58
int context_window
Definition: ckernel_ir.h:30
int intermediate_size
Definition: ck_model_api.h:37
float rms_norm_eps
Definition: ckernel_ir.h:31
float rope_theta
Definition: ckernel_ir.h:32

References ck_ir_v2_find_key(), ck_ir_v2_free(), ck_ir_v2_next_object(), ck_ir_v2_parse_bool(), ck_ir_v2_parse_buffers(), ck_ir_v2_parse_float(), ck_ir_v2_parse_int(), ck_ir_v2_parse_nodes(), CKIRV2Graph::config, CKModelConfig::context_window, end, CKIRV2Graph::fused_qkv, CKIRV2Graph::gated_mlp, CKIRV2Graph::has_pos_emb, CKModelConfig::hidden_size, CKModelConfig::intermediate_size, CKModelConfig::num_heads, CKModelConfig::num_kv_heads, CKModelConfig::num_layers, CKModelConfig::rms_norm_eps, CKModelConfig::rope_theta, CKIRV2Graph::tie_word_embeddings, and CKModelConfig::vocab_size.

Referenced by main().

◆ ck_ir_v2_parse_nodes()

static int ck_ir_v2_parse_nodes ( const char *  json,
const char *  end,
CKIRV2Graph *  graph 
)
static

Definition at line 1005 of file ckernel_ir_v2.c.

1008 {
1009  const char *start = ck_ir_v2_find_key(json, "\"nodes\"", end);
1010  if (!start) {
1011  return -1;
1012  }
1013  const char *open = strchr(start, '[');
1014  const char *close = ck_ir_v2_find_array_end(open, end);
1015  if (!open || !close) {
1016  return -1;
1017  }
1018 
1019  int count = 0;
1020  const char *cur = open;
1021  const char *obj_start = NULL;
1022  const char *obj_end = NULL;
1023  while ((cur = ck_ir_v2_next_object(cur, close, &obj_start, &obj_end)) != NULL) {
1024  count++;
1025  }
1026  if (count <= 0) {
1027  return -1;
1028  }
1029 
1030  CKIRV2Node *nodes = (CKIRV2Node *)calloc((size_t)count, sizeof(CKIRV2Node));
1031  if (!nodes) {
1032  return -1;
1033  }
1034  int idx = 0;
1035  cur = open;
1036  while (idx < count &&
1037  (cur = ck_ir_v2_next_object(cur, close, &obj_start, &obj_end)) != NULL) {
1038 
1039  CKIRV2Node node = {0};
1040  char *op = NULL;
1041  char *kernel = NULL;
1042  char *kernel_variant = NULL;
1043  char *kernel_dtype = NULL;
1044  char *cond = NULL;
1045  int layer = 0;
1046  int flags = 0;
1047 
1048  ck_ir_v2_parse_string_field(obj_start, "\"op\"", obj_end, &op);
1049  ck_ir_v2_parse_string_field(obj_start, "\"kernel\"", obj_end, &kernel);
1050  ck_ir_v2_parse_string_field(obj_start, "\"kernel_variant\"", obj_end, &kernel_variant);
1051  ck_ir_v2_parse_string_field(obj_start, "\"kernel_dtype\"", obj_end, &kernel_dtype);
1052  ck_ir_v2_parse_string_field(obj_start, "\"condition\"", obj_end, &cond);
1053  ck_ir_v2_parse_int(obj_start, "\"layer\"", obj_end, &layer);
1054  ck_ir_v2_parse_int(obj_start, "\"flags\"", obj_end, &flags);
1055 
1056  if (!kernel && kernel_variant) {
1057  kernel = kernel_variant;
1058  kernel_variant = NULL;
1059  }
1060 
1061  node.op = op;
1062  node.kernel = kernel;
1063  node.kernel_dtype = ck_ir_v2_parse_dtype(kernel_dtype);
1064  node.condition = cond;
1065  node.layer = (uint16_t)layer;
1066  node.flags = (uint8_t)flags;
1067  node.n_bindings = 0;
1068  node.n_inputs = 0;
1069  node.n_outputs = 0;
1070 
1071  if (kernel_variant) {
1072  free(kernel_variant);
1073  }
1074  if (kernel_dtype) {
1075  free(kernel_dtype);
1076  }
1077 
1078  if (ck_ir_v2_parse_bindings(obj_start, obj_end, graph, &node) != 0) {
1079  free(node.op);
1080  free(node.kernel);
1081  free(node.condition);
1082  for (int b = 0; b < node.n_bindings; ++b) {
1083  free(node.bindings[b].arg);
1084  }
1085  for (int j = 0; j < idx; ++j) {
1086  ck_ir_v2_free_node(&nodes[j]);
1087  }
1088  free(nodes);
1089  return -1;
1090  }
1091 
1092  nodes[idx++] = node;
1093  }
1094 
1095  graph->nodes = nodes;
1096  graph->num_nodes = idx;
1097  return 0;
1098 }
static int ck_ir_v2_parse_bindings(const char *obj_start, const char *obj_end, CKIRV2Graph *graph, CKIRV2Node *node)
uint16_t layer
Definition: ckernel_ir_v2.h:45
uint8_t flags
Definition: ckernel_ir_v2.h:46
CKDataType kernel_dtype
Definition: ckernel_ir_v2.h:43
uint8_t n_outputs
Definition: ckernel_ir_v2.h:52
uint8_t n_inputs
Definition: ckernel_ir_v2.h:50

References CKIRV2Binding::arg, CKIRV2Node::bindings, ck_ir_v2_find_array_end(), ck_ir_v2_find_key(), ck_ir_v2_free_node(), ck_ir_v2_next_object(), ck_ir_v2_parse_bindings(), ck_ir_v2_parse_dtype(), ck_ir_v2_parse_int(), ck_ir_v2_parse_string_field(), CKIRV2Node::condition, end, CKIRV2Node::flags, CKIRV2Node::kernel, CKIRV2Node::kernel_dtype, CKIRV2Node::layer, CKIRV2Node::n_bindings, CKIRV2Node::n_inputs, CKIRV2Node::n_outputs, CKIRV2Graph::nodes, CKIRV2Graph::num_nodes, CKIRV2Node::op, and start.

Referenced by ck_ir_v2_parse_json().

◆ ck_ir_v2_parse_role()

static CKBufferRole ck_ir_v2_parse_role ( const char *  s)
static

Definition at line 760 of file ckernel_ir_v2.c.

761 {
762  if (!s) return CK_ROLE_ACTIVATION;
763  if (strcmp(s, "input") == 0) return CK_ROLE_INPUT;
764  if (strcmp(s, "output") == 0) return CK_ROLE_OUTPUT;
765  if (strcmp(s, "activation") == 0) return CK_ROLE_ACTIVATION;
766  if (strcmp(s, "weight") == 0) return CK_ROLE_WEIGHT;
767  if (strcmp(s, "scratch") == 0) return CK_ROLE_SCRATCH;
768  if (strcmp(s, "grad") == 0) return CK_ROLE_GRAD;
769  return CK_ROLE_ACTIVATION;
770 }
@ CK_ROLE_WEIGHT
@ CK_ROLE_SCRATCH
@ CK_ROLE_GRAD
@ CK_ROLE_ACTIVATION
@ CK_ROLE_INPUT
@ CK_ROLE_OUTPUT

References CK_ROLE_ACTIVATION, CK_ROLE_GRAD, CK_ROLE_INPUT, CK_ROLE_OUTPUT, CK_ROLE_SCRATCH, and CK_ROLE_WEIGHT.

Referenced by ck_ir_v2_parse_buffers().

◆ ck_ir_v2_parse_scope()

static CKBufferScope ck_ir_v2_parse_scope ( const char *  s)
static

Definition at line 752 of file ckernel_ir_v2.c.

753 {
754  if (!s) return CK_SCOPE_GLOBAL;
755  if (strcmp(s, "layer") == 0) return CK_SCOPE_LAYER;
756  if (strcmp(s, "global") == 0) return CK_SCOPE_GLOBAL;
757  return CK_SCOPE_GLOBAL;
758 }
@ CK_SCOPE_LAYER
@ CK_SCOPE_GLOBAL

References CK_SCOPE_GLOBAL, and CK_SCOPE_LAYER.

Referenced by ck_ir_v2_parse_buffers().

◆ ck_ir_v2_parse_shape()

static int ck_ir_v2_parse_shape ( const char *  obj_start,
const char *  obj_end,
CKDimToken *  shape_out 
)
static

Definition at line 810 of file ckernel_ir_v2.c.

813 {
814  if (!shape_out) {
815  return -1;
816  }
817  for (int i = 0; i < CK_IR_V2_MAX_DIMS; ++i) {
818  shape_out[i].dim = CK_DIM_END;
819  shape_out[i].mult = 0;
820  shape_out[i].div = 0;
821  }
822 
823  const char *start = ck_ir_v2_find_key(obj_start, "\"shape\"", obj_end);
824  if (!start) {
825  return -1;
826  }
827  const char *open = strchr(start, '[');
828  const char *close = ck_ir_v2_find_array_end(open, obj_end);
829  if (!open || !close) {
830  return -1;
831  }
832 
833  const char *cur = open;
834  const char *sstart = NULL;
835  const char *send = NULL;
836  int idx = 0;
837  while (idx < CK_IR_V2_MAX_DIMS &&
838  (cur = ck_ir_v2_next_object(cur, close, &sstart, &send)) != NULL) {
839  CKDimKind dim = ck_ir_v2_parse_dim_kind(sstart, send);
840  if (dim == CK_DIM_END) {
841  continue;
842  }
843  int mult = 1;
844  int div = 1;
845  ck_ir_v2_parse_int(sstart, "\"mult\"", send, &mult);
846  ck_ir_v2_parse_int(sstart, "\"div\"", send, &div);
847  shape_out[idx].dim = dim;
848  shape_out[idx].mult = mult;
849  shape_out[idx].div = div;
850  idx++;
851  }
852  return (idx > 0) ? 0 : -1;
853 }
static CKDimKind ck_ir_v2_parse_dim_kind(const char *obj_start, const char *obj_end)

References CK_DIM_END, ck_ir_v2_find_array_end(), ck_ir_v2_find_key(), CK_IR_V2_MAX_DIMS, ck_ir_v2_next_object(), ck_ir_v2_parse_dim_kind(), ck_ir_v2_parse_int(), CKDimToken::dim, CKDimToken::div, CKDimToken::mult, and start.

Referenced by ck_ir_v2_parse_buffers().

◆ ck_ir_v2_parse_string()

static int ck_ir_v2_parse_string ( const char *  start,
const char *  end,
char **  out_str 
)
static

Definition at line 529 of file ckernel_ir_v2.c.

/*
 * Copy the JSON string literal that begins at start (which must point at its
 * opening '"') into a freshly malloc'd, NUL-terminated buffer.
 *
 * While searching for the closing quote, a backslash also consumes the next
 * byte (if it lies before end), so \" does not terminate the literal.
 * Escape sequences are copied verbatim; no unescaping is performed.
 * The scan never reads at or past end.
 *
 * On success, stores the buffer in *out_str (the caller frees it) and
 * returns 0. Returns -1 without touching *out_str when start or out_str is
 * NULL, start is not before end, *start is not '"', no closing quote occurs
 * before end, or malloc fails.
 */
static int ck_ir_v2_parse_string(const char *start, const char *end, char **out_str)
{
    if (start == NULL || out_str == NULL || !(start < end) || *start != '"') {
        return -1;
    }

    const char *text = start + 1;
    const char *stop = text;
    while (stop < end && *stop != '"') {
        stop += (*stop == '\\' && stop + 1 < end) ? 2 : 1;
    }
    if (!(stop < end)) {
        return -1; /* unterminated literal */
    }

    size_t n = (size_t)(stop - text);
    char *copy = malloc(n + 1);
    if (copy == NULL) {
        return -1;
    }
    memcpy(copy, text, n);
    copy[n] = '\0';
    *out_str = copy;
    return 0;
}

References end, and start.

Referenced by ck_ir_v2_parse_string_field().

◆ ck_ir_v2_parse_string_field()

static int ck_ir_v2_parse_string_field ( const char *  json,
const char *  key,
const char *  end,
char **  out_str 
)
static

Definition at line 722 of file ckernel_ir_v2.c.

/*
 * Locate key inside [json, end) and parse its value as a JSON string or null.
 *
 * key must include its quotes, e.g. "\"op\"". Whitespace (space, tab, CR,
 * LF) after the ':' is skipped. A literal null stores NULL in *out_str and
 * returns 0. A quoted string is copied into a new malloc'd buffer by
 * ck_ir_v2_parse_string() (escape sequences are kept verbatim).
 *
 * Returns -1, leaving *out_str untouched, when ck_ir_v2_find_key() finds
 * nothing, no ':' occurs before end, the value is missing, the value is
 * neither null nor a string, or the copy fails.
 */
static int ck_ir_v2_parse_string_field(const char *json, const char *key, const char *end,
                                       char **out_str)
{
    static const char k_null[] = "null";
    const size_t null_len = sizeof k_null - 1;

    const char *p = ck_ir_v2_find_key(json, key, end);
    if (!p) {
        return -1;
    }
    const char *colon = strchr(p, ':');
    if (!colon || colon >= end) {
        return -1;
    }
    const char *cur = colon + 1;
    while (cur < end && (*cur == ' ' || *cur == '\t' || *cur == '\n' || *cur == '\r')) {
        cur++;
    }
    if (cur >= end) {
        return -1;
    }
    /* Keep the comparison inside [cur, end): an unbounded strncmp could read
     * past the range and match a "null" that straddles end. */
    if ((size_t)(end - cur) >= null_len && memcmp(cur, k_null, null_len) == 0) {
        *out_str = NULL;
        return 0;
    }
    if (*cur != '"') {
        return -1;
    }
    return ck_ir_v2_parse_string(cur, end, out_str);
}
static int ck_ir_v2_parse_string(const char *start, const char *end, char **out_str)

References ck_ir_v2_find_key(), ck_ir_v2_parse_string(), and end.

Referenced by ck_ir_v2_parse_bindings(), ck_ir_v2_parse_buffers(), ck_ir_v2_parse_dim_kind(), and ck_ir_v2_parse_nodes().

◆ ck_ir_v2_resolve_align()

static void ck_ir_v2_resolve_align ( const CKModelConfig *  cfg,
size_t  alignment_bytes,
CKIRV2AlignInfo *  align 
)
static

Definition at line 179 of file ckernel_ir_v2.c.

182 {
183  if (!align) {
184  return;
185  }
186  memset(align, 0, sizeof(*align));
187  if (!cfg) {
188  return;
189  }
190  if (alignment_bytes == 0) {
191  alignment_bytes = CK_MEM_PLAN_DEFAULT_ALIGN;
192  }
193  size_t elem_bytes = sizeof(float);
194  size_t head_dim = (cfg->num_heads > 0) ? (size_t)(cfg->hidden_size / cfg->num_heads) : 0;
195  align->aligned_embed = ck_ir_v2_align_up_elems((size_t)cfg->hidden_size, elem_bytes, alignment_bytes);
196  align->aligned_head = ck_ir_v2_align_up_elems(head_dim, elem_bytes, alignment_bytes);
197  align->aligned_intermediate = ck_ir_v2_align_up_elems((size_t)cfg->intermediate_size,
198  elem_bytes, alignment_bytes);
199  align->aligned_context = ck_ir_v2_align_up_elems((size_t)cfg->context_window, elem_bytes, alignment_bytes);
200 }
static size_t ck_ir_v2_align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)
#define CK_MEM_PLAN_DEFAULT_ALIGN

References ck_ir_v2_align_up_elems(), CK_MEM_PLAN_DEFAULT_ALIGN, CKModelConfig::context_window, CKModelConfig::hidden_size, CKModelConfig::intermediate_size, and CKModelConfig::num_heads.

Referenced by ck_ir_v2_serialize_json_internal().

◆ ck_ir_v2_resolve_dim_value()

static size_t ck_ir_v2_resolve_dim_value ( const CKModelConfig *  cfg,
const CKIRV2AlignInfo *  align,
CKDimKind  dim,
int  tokens_override 
)
static

Definition at line 202 of file ckernel_ir_v2.c.

206 {
207  switch (dim) {
208  case CK_DIM_TOKENS:
209  if (tokens_override >= 0) {
210  return (size_t)tokens_override;
211  }
212  return cfg ? (size_t)cfg->context_window : 0;
213  case CK_DIM_EMBED:
214  return cfg ? (size_t)cfg->hidden_size : 0;
216  return align ? align->aligned_embed : 0;
217  case CK_DIM_HEAD_DIM:
218  return (cfg && cfg->num_heads > 0) ? (size_t)(cfg->hidden_size / cfg->num_heads) : 0;
219  case CK_DIM_ALIGNED_HEAD:
220  return align ? align->aligned_head : 0;
221  case CK_DIM_NUM_HEADS:
222  return cfg ? (size_t)cfg->num_heads : 0;
223  case CK_DIM_NUM_KV_HEADS:
224  return cfg ? (size_t)cfg->num_kv_heads : 0;
225  case CK_DIM_ALIGNED_CTX:
226  return align ? align->aligned_context : 0;
227  case CK_DIM_INTERMEDIATE:
228  return cfg ? (size_t)cfg->intermediate_size : 0;
230  return align ? align->aligned_intermediate : 0;
231  case CK_DIM_VOCAB:
232  return cfg ? (size_t)cfg->vocab_size : 0;
233  case CK_DIM_END:
234  default:
235  return 0;
236  }
237 }

References CK_DIM_ALIGNED_CTX, CK_DIM_ALIGNED_EMBED, CK_DIM_ALIGNED_HEAD, CK_DIM_ALIGNED_INTERMEDIATE, CK_DIM_EMBED, CK_DIM_END, CK_DIM_HEAD_DIM, CK_DIM_INTERMEDIATE, CK_DIM_NUM_HEADS, CK_DIM_NUM_KV_HEADS, CK_DIM_TOKENS, CK_DIM_VOCAB, CKModelConfig::context_window, CKModelConfig::hidden_size, CKModelConfig::intermediate_size, CKModelConfig::num_heads, CKModelConfig::num_kv_heads, and CKModelConfig::vocab_size.

Referenced by ck_ir_v2_emit_dimensions(), and ck_ir_v2_emit_resolved_shape().

◆ ck_ir_v2_role_name()

static const char* ck_ir_v2_role_name ( CKBufferRole  role)
static

Definition at line 66 of file ckernel_ir_v2.c.

67 {
68  switch (role) {
69  case CK_ROLE_INPUT:
70  return "input";
71  case CK_ROLE_OUTPUT:
72  return "output";
73  case CK_ROLE_ACTIVATION:
74  return "activation";
75  case CK_ROLE_WEIGHT:
76  return "weight";
77  case CK_ROLE_SCRATCH:
78  return "scratch";
79  case CK_ROLE_GRAD:
80  return "grad";
81  default:
82  return "unknown";
83  }
84 }

References CK_ROLE_ACTIVATION, CK_ROLE_GRAD, CK_ROLE_INPUT, CK_ROLE_OUTPUT, CK_ROLE_SCRATCH, and CK_ROLE_WEIGHT.

Referenced by ck_ir_v2_serialize_json_internal().

◆ ck_ir_v2_scope_name()

static const char* ck_ir_v2_scope_name ( CKBufferScope  scope)
static

Definition at line 54 of file ckernel_ir_v2.c.

55 {
56  switch (scope) {
57  case CK_SCOPE_LAYER:
58  return "layer";
59  case CK_SCOPE_GLOBAL:
60  return "global";
61  default:
62  return "unknown";
63  }
64 }

References CK_SCOPE_GLOBAL, and CK_SCOPE_LAYER.

Referenced by ck_ir_v2_serialize_json_internal().

◆ ck_ir_v2_serialize_json()

int ck_ir_v2_serialize_json ( const CKIRV2Graph *  graph,
const char *  path 
)

Definition at line 511 of file ckernel_ir_v2.c.

512 {
513  return ck_ir_v2_serialize_json_internal(graph, NULL, NULL, -1, -1, path);
514 }
static int ck_ir_v2_serialize_json_internal(const CKIRV2Graph *graph, const CKMemPlan *plan, const char *mode, int tokens_override, int base_context_window, const char *path)

References ck_ir_v2_serialize_json_internal().

Referenced by main().

◆ ck_ir_v2_serialize_json_internal()

static int ck_ir_v2_serialize_json_internal ( const CKIRV2Graph *  graph,
const CKMemPlan *  plan,
const char *  mode,
int  tokens_override,
int  base_context_window,
const char *  path 
)
static

Definition at line 365 of file ckernel_ir_v2.c.

371 {
372  if (!graph || !path) {
373  return -1;
374  }
375  FILE *out = fopen(path, "wb");
376  if (!out) {
377  fprintf(stderr, "ck_ir_v2_serialize_json: failed to open %s: %s\n",
378  path, strerror(errno));
379  return -1;
380  }
381 
382  CKIRV2AlignInfo align = {0};
384 
385  fprintf(out, "{\n");
386  fprintf(out, " \"version\": 2,\n");
387  fprintf(out, " \"notes\": [\n");
388  fprintf(out, " \"shape.dim uses symbolic names; see dimensions for resolved values\",\n");
389  fprintf(out,
390  " \"resolved_shape applies mult/div using alignment_bytes=%d and elem_bytes=4\",\n",
392  fprintf(out, " \"kernel is the selected impl; kernel_dtype records dtype selection\"\n");
393  fprintf(out, " ],\n");
394  fprintf(out, " \"config\": {\n");
395  fprintf(out, " \"num_layers\": %d,\n", graph->config.num_layers);
396  fprintf(out, " \"hidden_size\": %d,\n", graph->config.hidden_size);
397  fprintf(out, " \"intermediate_size\": %d,\n", graph->config.intermediate_size);
398  fprintf(out, " \"num_attention_heads\": %d,\n", graph->config.num_heads);
399  fprintf(out, " \"num_kv_heads\": %d,\n", graph->config.num_kv_heads);
400  fprintf(out, " \"vocab_size\": %d,\n", graph->config.vocab_size);
401  fprintf(out, " \"context_window\": %d,\n", graph->config.context_window);
402  fprintf(out, " \"rms_norm_eps\": %.9g,\n", graph->config.rms_norm_eps);
403  fprintf(out, " \"rope_theta\": %.9g\n", graph->config.rope_theta);
404  fprintf(out, " },\n");
405 
406  ck_ir_v2_emit_dimensions(out, &graph->config, &align, tokens_override);
407 
408  fprintf(out, " \"meta\": {\n");
409  fprintf(out, " \"has_pos_emb\": %s,\n", graph->has_pos_emb ? "true" : "false");
410  if (graph->tie_word_embeddings < 0) {
411  fprintf(out, " \"tie_word_embeddings\": null,\n");
412  } else {
413  fprintf(out, " \"tie_word_embeddings\": %s,\n", graph->tie_word_embeddings ? "true" : "false");
414  }
415  if (graph->fused_qkv < 0) {
416  fprintf(out, " \"fused_qkv\": null,\n");
417  } else {
418  fprintf(out, " \"fused_qkv\": %s,\n", graph->fused_qkv ? "true" : "false");
419  }
420  if (graph->gated_mlp < 0) {
421  fprintf(out, " \"gated_mlp\": null\n");
422  } else {
423  fprintf(out, " \"gated_mlp\": %s\n", graph->gated_mlp ? "true" : "false");
424  }
425  fprintf(out, " },\n");
426 
427  if (plan) {
428  int training = (mode && strcmp(mode, "backward") == 0);
429  int tokens = (tokens_override >= 0) ? tokens_override : graph->config.context_window;
430  fprintf(out, " \"lowering\": {\n");
431  fprintf(out, " \"mode\": \"%s\",\n", mode ? mode : "unknown");
432  fprintf(out, " \"training\": %s,\n", training ? "true" : "false");
433  fprintf(out, " \"tokens\": %d", tokens);
434  if (base_context_window >= 0) {
435  fprintf(out, ",\n \"base_context_window\": %d\n", base_context_window);
436  } else {
437  fprintf(out, "\n");
438  }
439  fprintf(out, " },\n");
440  ck_ir_v2_emit_memory_plan(out, graph, plan);
441  }
442 
443  fprintf(out, " \"buffers\": [\n");
444  for (int i = 0; i < graph->num_buffers; ++i) {
445  const CKIRV2Buffer *buf = &graph->buffers[i];
446  fprintf(out, " {\n");
447  fprintf(out, " \"name\": \"%s\",\n", buf->name ? buf->name : "");
448  fprintf(out, " \"scope\": \"%s\",\n", ck_ir_v2_scope_name(buf->scope));
449  fprintf(out, " \"role\": \"%s\",\n", ck_ir_v2_role_name(buf->role));
450  fprintf(out, " \"dtype\": \"%s\",\n", ck_ir_v2_dtype_name(buf->dtype));
451  fprintf(out, " \"optional\": %d,\n", buf->optional ? 1 : 0);
452  fprintf(out, " \"shape\": ");
453  ck_ir_v2_emit_shape(out, buf->shape);
454  fprintf(out, ",\n");
455  fprintf(out, " \"resolved_shape\": ");
456  ck_ir_v2_emit_resolved_shape(out, &graph->config, &align, buf->shape,
457  tokens_override);
458  fprintf(out, ",\n");
459  if (buf->alias_of) {
460  fprintf(out, " \"alias_of\": \"%s\",\n", buf->alias_of);
461  } else {
462  fprintf(out, " \"alias_of\": null,\n");
463  }
464  if (buf->condition) {
465  fprintf(out, " \"condition\": \"%s\"\n", buf->condition);
466  } else {
467  fprintf(out, " \"condition\": null\n");
468  }
469  fprintf(out, " }%s\n", (i + 1 == graph->num_buffers) ? "" : ",");
470  }
471  fprintf(out, " ],\n");
472 
473  fprintf(out, " \"nodes\": [\n");
474  for (int i = 0; i < graph->num_nodes; ++i) {
475  const CKIRV2Node *node = &graph->nodes[i];
476  fprintf(out, " {\n");
477  fprintf(out, " \"layer\": %d,\n", (int)node->layer);
478  fprintf(out, " \"op\": \"%s\",\n", node->op ? node->op : "");
479  fprintf(out, " \"kernel\": \"%s\",\n", node->kernel ? node->kernel : "");
480  fprintf(out, " \"kernel_variant\": \"%s\",\n", node->kernel ? node->kernel : "");
481  fprintf(out, " \"kernel_dtype\": \"%s\",\n", ck_ir_v2_dtype_name(node->kernel_dtype));
482  fprintf(out, " \"flags\": %u,\n", (unsigned)node->flags);
483  if (node->condition) {
484  fprintf(out, " \"condition\": \"%s\",\n", node->condition);
485  } else {
486  fprintf(out, " \"condition\": null,\n");
487  }
488  fprintf(out, " \"bindings\": [\n");
489  for (int b = 0; b < node->n_bindings; ++b) {
490  const CKIRV2Binding *bind = &node->bindings[b];
491  const char *buf_name = "";
492  if (bind->buffer >= 0 && bind->buffer < graph->num_buffers) {
493  buf_name = graph->buffers[bind->buffer].name ? graph->buffers[bind->buffer].name : "";
494  }
495  fprintf(out,
496  " {\"arg\": \"%s\", \"buffer\": \"%s\"}%s\n",
497  bind->arg ? bind->arg : "",
498  buf_name,
499  (b + 1 == node->n_bindings) ? "" : ",");
500  }
501  fprintf(out, " ]\n");
502  fprintf(out, " }%s\n", (i + 1 == graph->num_nodes) ? "" : ",");
503  }
504  fprintf(out, " ]\n");
505  fprintf(out, "}\n");
506 
507  fclose(out);
508  return 0;
509 }
static void ck_ir_v2_resolve_align(const CKModelConfig *cfg, size_t alignment_bytes, CKIRV2AlignInfo *align)
static void ck_ir_v2_emit_resolved_shape(FILE *out, const CKModelConfig *cfg, const CKIRV2AlignInfo *align, const CKDimToken *shape, int tokens_override)
static void ck_ir_v2_emit_dimensions(FILE *out, const CKModelConfig *cfg, const CKIRV2AlignInfo *align, int tokens_override)
static const char * ck_ir_v2_dtype_name(CKDataType dtype)
Definition: ckernel_ir_v2.c:86
static const char * ck_ir_v2_scope_name(CKBufferScope scope)
Definition: ckernel_ir_v2.c:54
static const char * ck_ir_v2_role_name(CKBufferRole role)
Definition: ckernel_ir_v2.c:66
static void ck_ir_v2_emit_memory_plan(FILE *out, const CKIRV2Graph *graph, const CKMemPlan *plan)
static int ck_ir_v2_emit_shape(FILE *out, const CKDimToken *shape)

References CKIRV2Buffer::alias_of, CKIRV2Binding::arg, CKIRV2Node::bindings, CKIRV2Binding::buffer, CKIRV2Graph::buffers, ck_ir_v2_dtype_name(), ck_ir_v2_emit_dimensions(), ck_ir_v2_emit_memory_plan(), ck_ir_v2_emit_resolved_shape(), ck_ir_v2_emit_shape(), ck_ir_v2_resolve_align(), ck_ir_v2_role_name(), ck_ir_v2_scope_name(), CK_MEM_PLAN_DEFAULT_ALIGN, CKIRV2Buffer::condition, CKIRV2Node::condition, CKIRV2Graph::config, CKModelConfig::context_window, CKIRV2Buffer::dtype, CKIRV2Node::flags, CKIRV2Graph::fused_qkv, CKIRV2Graph::gated_mlp, CKIRV2Graph::has_pos_emb, CKModelConfig::hidden_size, CKModelConfig::intermediate_size, CKIRV2Node::kernel, CKIRV2Node::kernel_dtype, CKIRV2Node::layer, CKIRV2Node::n_bindings, CKIRV2Buffer::name, CKIRV2Graph::nodes, CKIRV2Graph::num_buffers, CKModelConfig::num_heads, CKModelConfig::num_kv_heads, CKModelConfig::num_layers, CKIRV2Graph::num_nodes, CKIRV2Node::op, CKIRV2Buffer::optional, CKModelConfig::rms_norm_eps, CKIRV2Buffer::role, CKModelConfig::rope_theta, CKIRV2Buffer::scope, CKIRV2Buffer::shape, CKIRV2Graph::tie_word_embeddings, and CKModelConfig::vocab_size.

Referenced by ck_ir_v2_serialize_json(), and ck_ir_v2_serialize_json_with_plan().

◆ ck_ir_v2_serialize_json_with_plan()

int ck_ir_v2_serialize_json_with_plan ( const CKIRV2Graph *  graph,
const CKMemPlan *  plan,
const char *  mode,
int  tokens_override,
int  base_context_window,
const char *  path 
)

Definition at line 516 of file ckernel_ir_v2.c.

522 {
523  return ck_ir_v2_serialize_json_internal(graph, plan, mode,
524  tokens_override,
525  base_context_window,
526  path);
527 }

References ck_ir_v2_serialize_json_internal().

◆ ck_ir_v2_skip_string()

static const char* ck_ir_v2_skip_string ( const char *  cur,
const char *  end 
)
static

Definition at line 558 of file ckernel_ir_v2.c.

/*
 * Return the position just past the JSON string literal that starts at cur.
 *
 * If cur is NULL, not before end, or not at a '"', cur is returned
 * unchanged. A backslash also skips the byte after it (when that byte lies
 * before end), so an escaped quote does not close the literal. If no closing
 * quote occurs before end, end is returned.
 */
static const char *ck_ir_v2_skip_string(const char *cur, const char *end)
{
    if (cur == NULL || !(cur < end) || cur[0] != '"') {
        return cur;
    }
    for (const char *p = cur + 1; p < end; ++p) {
        if (p[0] == '\\' && p + 1 < end) {
            ++p; /* step over the escaped byte as well */
            continue;
        }
        if (p[0] == '"') {
            return p + 1;
        }
    }
    return end;
}

References end.

Referenced by ck_ir_v2_find_array_end(), and ck_ir_v2_next_object().