← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_codegen_v2_sections.c
Go to the documentation of this file.
2 
3 #include "ckernel_dtype.h"
4 
5 #include <stdio.h>
6 #include <string.h>
7 
/* Byte boundary that span offsets and aligned dimensions are rounded up to. */
#define CK_V2_SECTION_ALIGN 64

/*
 * Dimension sizes, in elements, after rounding each one up so that its
 * byte size is a multiple of CK_V2_SECTION_ALIGN (filled by compute_align()).
 */
typedef struct {
    size_t aligned_embed;        /* derived from hidden_size */
    size_t aligned_head;         /* derived from hidden_size / num_heads */
    size_t aligned_intermediate; /* derived from intermediate_size */
    size_t aligned_context;      /* derived from context_window */
} CKV2AlignInfo;
16 
/*
 * Round n up to the next multiple of align.
 *
 * align == 0 means "no alignment" and returns n unchanged. Any non-zero
 * alignment is supported, not only powers of two; for power-of-two
 * alignments the result is identical to the classic
 * (n + align - 1) & ~(align - 1) mask form.
 */
static size_t align_up_bytes(size_t n, size_t align)
{
    if (align == 0) {
        return n;
    }
    /* Work from the remainder so an already-aligned n is never touched. */
    const size_t rem = n % align;
    if (rem == 0) {
        return n;
    }
    return n + (align - rem);
}
24 
/*
 * Round an element count up so that its byte size (elems * elem_bytes) is a
 * multiple of align_bytes, and return the padded size expressed in elements.
 *
 * elem_bytes == 0 would divide by zero, so in that case the count is
 * returned unchanged. The final division truncates, so the result is exact
 * only when the aligned byte size is divisible by elem_bytes.
 */
static size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)
{
    if (elem_bytes == 0) {
        return elems;
    }
    size_t bytes = elems * elem_bytes;
    bytes = align_up_bytes(bytes, align_bytes);
    return bytes / elem_bytes;
}
31 
32 static CKV2AlignInfo compute_align(const CKModelConfig *cfg)
33 {
34  CKV2AlignInfo info = {0};
35  if (!cfg) {
36  return info;
37  }
38  size_t elem_bytes = sizeof(float);
39  size_t head_dim = (cfg->num_heads > 0) ? (size_t)(cfg->hidden_size / cfg->num_heads) : 0;
40  info.aligned_embed = align_up_elems((size_t)cfg->hidden_size, elem_bytes, CK_V2_SECTION_ALIGN);
41  info.aligned_head = align_up_elems(head_dim, elem_bytes, CK_V2_SECTION_ALIGN);
42  info.aligned_intermediate =
43  align_up_elems((size_t)cfg->intermediate_size, elem_bytes, CK_V2_SECTION_ALIGN);
44  info.aligned_context =
45  align_up_elems((size_t)cfg->context_window, elem_bytes, CK_V2_SECTION_ALIGN);
46  return info;
47 }
48 
49 static size_t resolve_dim(const CKModelConfig *cfg,
50  const CKV2AlignInfo *align,
51  CKDimKind kind)
52 {
53  switch (kind) {
54  case CK_DIM_TOKENS:
55  return cfg ? (size_t)cfg->context_window : 0;
56  case CK_DIM_EMBED:
57  return cfg ? (size_t)cfg->hidden_size : 0;
59  return align ? align->aligned_embed : 0;
60  case CK_DIM_HEAD_DIM:
61  return (cfg && cfg->num_heads > 0) ? (size_t)(cfg->hidden_size / cfg->num_heads) : 0;
63  return align ? align->aligned_head : 0;
64  case CK_DIM_NUM_HEADS:
65  return cfg ? (size_t)cfg->num_heads : 0;
67  return cfg ? (size_t)cfg->num_kv_heads : 0;
68  case CK_DIM_ALIGNED_CTX:
69  return align ? align->aligned_context : 0;
71  return cfg ? (size_t)cfg->intermediate_size : 0;
73  return align ? align->aligned_intermediate : 0;
74  case CK_DIM_VOCAB:
75  return cfg ? (size_t)cfg->vocab_size : 0;
76  case CK_DIM_END:
77  default:
78  return 0;
79  }
80 }
81 
82 static size_t resolve_shape_elems(const CKModelConfig *cfg,
83  const CKV2AlignInfo *align,
84  const CKDimToken *shape)
85 {
86  size_t total = 1;
87  for (int i = 0; i < CK_IR_V2_MAX_DIMS; ++i) {
88  if (shape[i].dim == CK_DIM_END) {
89  break;
90  }
91  size_t dim = resolve_dim(cfg, align, shape[i].dim);
92  size_t mult = (size_t)(shape[i].mult > 0 ? shape[i].mult : 1);
93  size_t div = (size_t)(shape[i].div > 0 ? shape[i].div : 1);
94  if (div == 0) div = 1;
95  total = total * dim * mult / div;
96  }
97  return total;
98 }
99 
100 static size_t buffer_bytes(const CKIRV2Buffer *buf,
101  const CKModelConfig *cfg,
102  const CKV2AlignInfo *align)
103 {
104  size_t elems = resolve_shape_elems(cfg, align, buf->shape);
105  return ck_dtype_row_bytes(buf->dtype, elems);
106 }
107 
/*
 * Decide whether a global buffer name belongs to the footer section.
 *
 * Returns 1 when name starts with "final_" or "d_final_", or is exactly
 * "lm_head_weight", "logits" or "d_logits"; returns 0 otherwise, including
 * for a NULL name.
 */
static int is_footer_global(const char *name)
{
    static const char *const footer_prefixes[] = { "final_", "d_final_" };
    static const char *const footer_names[] = { "lm_head_weight", "logits", "d_logits" };

    if (name == NULL) {
        return 0;
    }

    for (size_t p = 0; p < sizeof footer_prefixes / sizeof footer_prefixes[0]; ++p) {
        if (strncmp(name, footer_prefixes[p], strlen(footer_prefixes[p])) == 0) {
            return 1;
        }
    }
    for (size_t k = 0; k < sizeof footer_names / sizeof footer_names[0]; ++k) {
        if (strcmp(name, footer_names[k]) == 0) {
            return 1;
        }
    }
    return 0;
}
130 
/* Write one CKV2Span member declaration named label to out. */
static void emit_span_field(FILE *out, const char *label)
{
    static const char field_format[] = " CKV2Span %s;\n";

    fprintf(out, field_format, label);
}
135 
/*
 * Write one designated initializer ".label = { offset, size }" to out,
 * followed by a comma when comma is non-zero.
 */
static void emit_span_value(FILE *out, const char *label, size_t offset, size_t size, int comma)
{
    const char *separator = "";
    if (comma) {
        separator = ",";
    }
    fprintf(out, " .%s = { %zu, %zu }%s\n", label, offset, size, separator);
}
140 
142 {
143  return role == CK_ROLE_INPUT || role == CK_ROLE_OUTPUT ||
144  role == CK_ROLE_ACTIVATION || role == CK_ROLE_SCRATCH;
145 }
146 
147 static void emit_header_fields(FILE *out,
148  const CKIRV2Graph *graph,
149  CKBufferRole role_filter,
150  int activation_group)
151 {
152  for (int i = 0; i < graph->num_buffers; ++i) {
153  const CKIRV2Buffer *buf = &graph->buffers[i];
154  if (buf->scope != CK_SCOPE_GLOBAL) {
155  continue;
156  }
157  if (activation_group) {
158  if (!is_activation_role(buf->role)) {
159  continue;
160  }
161  } else if (buf->role != role_filter) {
162  continue;
163  }
164  if (is_footer_global(buf->name)) {
165  continue;
166  }
167  emit_span_field(out, buf->name);
168  }
169 }
170 
171 static void emit_body_fields(FILE *out,
172  const CKIRV2Graph *graph,
173  CKBufferRole role_filter,
174  int activation_group)
175 {
176  for (int i = 0; i < graph->num_buffers; ++i) {
177  const CKIRV2Buffer *buf = &graph->buffers[i];
178  if (buf->scope != CK_SCOPE_LAYER) {
179  continue;
180  }
181  if (activation_group) {
182  if (!is_activation_role(buf->role)) {
183  continue;
184  }
185  } else if (buf->role != role_filter) {
186  continue;
187  }
188  emit_span_field(out, buf->name);
189  }
190 }
191 
192 static void emit_footer_fields(FILE *out,
193  const CKIRV2Graph *graph,
194  CKBufferRole role_filter,
195  int activation_group)
196 {
197  for (int i = 0; i < graph->num_buffers; ++i) {
198  const CKIRV2Buffer *buf = &graph->buffers[i];
199  if (buf->scope != CK_SCOPE_GLOBAL) {
200  continue;
201  }
202  if (activation_group) {
203  if (!is_activation_role(buf->role)) {
204  continue;
205  }
206  } else if (buf->role != role_filter) {
207  continue;
208  }
209  if (!is_footer_global(buf->name)) {
210  continue;
211  }
212  emit_span_field(out, buf->name);
213  }
214 }
215 
216 static size_t plan_size(const CKMemPlan *plan, int idx)
217 {
218  if (!plan || !plan->spans || idx < 0 || idx >= plan->num_spans) {
219  return 0;
220  }
221  return plan->spans[idx].size_bytes;
222 }
223 
224 static void emit_header_values(FILE *out,
225  const CKIRV2Graph *graph,
226  const CKMemPlan *plan,
227  CKBufferRole role_filter,
228  int activation_group,
229  size_t *offset)
230 {
231  for (int i = 0; i < graph->num_buffers; ++i) {
232  const CKIRV2Buffer *buf = &graph->buffers[i];
233  if (buf->scope != CK_SCOPE_GLOBAL) {
234  continue;
235  }
236  if (activation_group) {
237  if (!is_activation_role(buf->role)) {
238  continue;
239  }
240  } else if (buf->role != role_filter) {
241  continue;
242  }
243  if (is_footer_global(buf->name)) {
244  continue;
245  }
246  size_t bytes = plan_size(plan, i);
247  *offset = align_up_bytes(*offset, CK_V2_SECTION_ALIGN);
248  emit_span_value(out, buf->name, *offset, bytes, 1);
249  *offset += bytes;
250  }
251 }
252 
253 static size_t emit_body_values(FILE *out,
254  const CKIRV2Graph *graph,
255  const CKMemPlan *plan,
256  CKBufferRole role_filter,
257  int activation_group)
258 {
259  size_t layer_offset = 0;
260  for (int i = 0; i < graph->num_buffers; ++i) {
261  const CKIRV2Buffer *buf = &graph->buffers[i];
262  if (buf->scope != CK_SCOPE_LAYER) {
263  continue;
264  }
265  if (activation_group) {
266  if (!is_activation_role(buf->role)) {
267  continue;
268  }
269  } else if (buf->role != role_filter) {
270  continue;
271  }
272  size_t bytes = plan_size(plan, i);
273  layer_offset = align_up_bytes(layer_offset, CK_V2_SECTION_ALIGN);
274  emit_span_value(out, buf->name, layer_offset, bytes, 1);
275  layer_offset += bytes;
276  }
277  return layer_offset;
278 }
279 
280 static void emit_footer_values(FILE *out,
281  const CKIRV2Graph *graph,
282  const CKMemPlan *plan,
283  CKBufferRole role_filter,
284  int activation_group,
285  size_t *offset)
286 {
287  for (int i = 0; i < graph->num_buffers; ++i) {
288  const CKIRV2Buffer *buf = &graph->buffers[i];
289  if (buf->scope != CK_SCOPE_GLOBAL) {
290  continue;
291  }
292  if (activation_group) {
293  if (!is_activation_role(buf->role)) {
294  continue;
295  }
296  } else if (buf->role != role_filter) {
297  continue;
298  }
299  if (!is_footer_global(buf->name)) {
300  continue;
301  }
302  size_t bytes = plan_size(plan, i);
303  *offset = align_up_bytes(*offset, CK_V2_SECTION_ALIGN);
304  emit_span_value(out, buf->name, *offset, bytes, 1);
305  *offset += bytes;
306  }
307 }
308 
310  const CKIRV2Graph *graph,
311  const CKMemPlan *prefill_plan,
312  const CKMemPlan *decode_plan,
313  const CKMemPlan *backward_plan)
314 {
315  if (!out || !graph) {
316  return;
317  }
318 
319  const int L = graph->config.num_layers;
320 
321  fprintf(out,
322  "typedef struct {\n"
323  " size_t offset;\n"
324  " size_t size;\n"
325  "} CKV2Span;\n\n");
326 
327  fprintf(out, "typedef struct {\n");
328  emit_header_fields(out, graph, CK_ROLE_WEIGHT, 0);
329  fprintf(out, "} CKV2HeaderWeights;\n\n");
330 
331  fprintf(out, "typedef struct {\n");
332  emit_body_fields(out, graph, CK_ROLE_WEIGHT, 0);
333  fprintf(out, "} CKV2LayerWeights;\n\n");
334 
335  fprintf(out, "typedef struct {\n");
336  emit_footer_fields(out, graph, CK_ROLE_WEIGHT, 0);
337  fprintf(out, "} CKV2FooterWeights;\n\n");
338 
339  fprintf(out, "typedef struct {\n");
340  emit_header_fields(out, graph, CK_ROLE_ACTIVATION, 1);
341  fprintf(out, "} CKV2HeaderActivations;\n\n");
342 
343  fprintf(out, "typedef struct {\n");
344  emit_body_fields(out, graph, CK_ROLE_ACTIVATION, 1);
345  fprintf(out, "} CKV2LayerActivations;\n\n");
346 
347  fprintf(out, "typedef struct {\n");
348  emit_footer_fields(out, graph, CK_ROLE_ACTIVATION, 1);
349  fprintf(out, "} CKV2FooterActivations;\n\n");
350 
351  fprintf(out, "typedef struct {\n");
352  emit_header_fields(out, graph, CK_ROLE_GRAD, 0);
353  fprintf(out, "} CKV2HeaderGrads;\n\n");
354 
355  fprintf(out, "typedef struct {\n");
356  emit_body_fields(out, graph, CK_ROLE_GRAD, 0);
357  fprintf(out, "} CKV2LayerGrads;\n\n");
358 
359  fprintf(out, "typedef struct {\n");
360  emit_footer_fields(out, graph, CK_ROLE_GRAD, 0);
361  fprintf(out, "} CKV2FooterGrads;\n\n");
362 
363  fprintf(out,
364  "typedef struct {\n"
365  " CKV2HeaderWeights header;\n"
366  " CKV2LayerWeights body;\n"
367  " CKV2FooterWeights footer;\n"
368  " size_t layer_stride_bytes;\n"
369  " size_t total_bytes;\n"
370  "} CKV2WeightLayout;\n\n");
371 
372  fprintf(out,
373  "typedef struct {\n"
374  " CKV2HeaderActivations header;\n"
375  " CKV2LayerActivations body;\n"
376  " CKV2FooterActivations footer;\n"
377  " size_t layer_stride_bytes;\n"
378  " size_t total_bytes;\n"
379  "} CKV2ActivationLayout;\n\n");
380 
381  fprintf(out,
382  "typedef struct {\n"
383  " CKV2HeaderGrads header;\n"
384  " CKV2LayerGrads body;\n"
385  " CKV2FooterGrads footer;\n"
386  " size_t layer_stride_bytes;\n"
387  " size_t total_bytes;\n"
388  "} CKV2GradLayout;\n\n");
389 
390  fprintf(out,
391  "typedef struct {\n"
392  " int num_layers;\n"
393  " int context_window;\n"
394  " int hidden_size;\n"
395  " int intermediate_size;\n"
396  " int num_heads;\n"
397  " int num_kv_heads;\n"
398  " int head_dim;\n"
399  " int vocab_size;\n"
400  "} CKV2ModelConfig;\n\n");
401 
402  fprintf(out,
403  "typedef struct {\n"
404  " CKV2WeightLayout weights;\n"
405  " CKV2ActivationLayout activations;\n"
406  " CKV2GradLayout grads;\n"
407  "} CKV2SectionLayout;\n\n");
408 
409  fprintf(out,
410  "typedef struct {\n"
411  " CKV2SectionLayout prefill;\n"
412  " CKV2SectionLayout decode;\n"
413  " CKV2SectionLayout backward;\n"
414  "} CKV2SectionLayouts;\n\n");
415 
416  fprintf(out,
417  "typedef struct {\n"
418  " CKV2ModelConfig config;\n"
419  " CKV2SectionLayouts decoder;\n"
420  "} CKV2RuntimeLayout;\n\n");
421 
422  fprintf(out, "static const CKV2RuntimeLayout ck_v2_layout = {\n");
423  fprintf(out, " .config = {\n");
424  fprintf(out, " .num_layers = %d,\n", graph->config.num_layers);
425  fprintf(out, " .context_window = %d,\n", graph->config.context_window);
426  fprintf(out, " .hidden_size = %d,\n", graph->config.hidden_size);
427  fprintf(out, " .intermediate_size = %d,\n", graph->config.intermediate_size);
428  fprintf(out, " .num_heads = %d,\n", graph->config.num_heads);
429  fprintf(out, " .num_kv_heads = %d,\n", graph->config.num_kv_heads);
430  fprintf(out, " .head_dim = %d,\n",
431  graph->config.num_heads > 0 ? graph->config.hidden_size / graph->config.num_heads : 0);
432  fprintf(out, " .vocab_size = %d\n", graph->config.vocab_size);
433  fprintf(out, " },\n");
434  fprintf(out, " .decoder = {\n");
435 
436  const CKMemPlan *plans[3] = { prefill_plan, decode_plan, backward_plan };
437  const char *modes[3] = { "prefill", "decode", "backward" };
438 
439  for (int mode = 0; mode < 3; ++mode) {
440  const CKMemPlan *plan = plans[mode];
441  fprintf(out, " .%s = {\n", modes[mode]);
442 
443  size_t offset = 0;
444  size_t header_end = 0;
445  size_t layer_stride = 0;
446  size_t footer_base = 0;
447 
448  fprintf(out, " .weights = {\n");
449  fprintf(out, " .header = {\n");
450  emit_header_values(out, graph, prefill_plan, CK_ROLE_WEIGHT, 0, &offset);
451  header_end = offset;
452  fprintf(out, " },\n");
453 
454  fprintf(out, " .body = {\n");
455  layer_stride = emit_body_values(out, graph, prefill_plan, CK_ROLE_WEIGHT, 0);
456  fprintf(out, " },\n");
457 
458  footer_base = header_end + layer_stride * (size_t)L;
459  offset = footer_base;
460  fprintf(out, " .footer = {\n");
461  emit_footer_values(out, graph, prefill_plan, CK_ROLE_WEIGHT, 0, &offset);
462  fprintf(out, " },\n");
463 
464  fprintf(out, " .layer_stride_bytes = %zu,\n", layer_stride);
465  fprintf(out, " .total_bytes = %zu\n", offset);
466  fprintf(out, " },\n");
467 
468  offset = 0;
469  header_end = 0;
470  layer_stride = 0;
471  footer_base = 0;
472 
473  fprintf(out, " .activations = {\n");
474  fprintf(out, " .header = {\n");
475  emit_header_values(out, graph, plan, CK_ROLE_ACTIVATION, 1, &offset);
476  header_end = offset;
477  fprintf(out, " },\n");
478 
479  fprintf(out, " .body = {\n");
480  layer_stride = emit_body_values(out, graph, plan, CK_ROLE_ACTIVATION, 1);
481  fprintf(out, " },\n");
482 
483  footer_base = header_end + layer_stride * (size_t)L;
484  offset = footer_base;
485  fprintf(out, " .footer = {\n");
486  emit_footer_values(out, graph, plan, CK_ROLE_ACTIVATION, 1, &offset);
487  fprintf(out, " },\n");
488 
489  fprintf(out, " .layer_stride_bytes = %zu,\n", layer_stride);
490  fprintf(out, " .total_bytes = %zu\n", offset);
491  fprintf(out, " },\n");
492 
493  offset = 0;
494  header_end = 0;
495  layer_stride = 0;
496  footer_base = 0;
497 
498  fprintf(out, " .grads = {\n");
499  fprintf(out, " .header = {\n");
500  emit_header_values(out, graph, plan, CK_ROLE_GRAD, 0, &offset);
501  header_end = offset;
502  fprintf(out, " },\n");
503 
504  fprintf(out, " .body = {\n");
505  layer_stride = emit_body_values(out, graph, plan, CK_ROLE_GRAD, 0);
506  fprintf(out, " },\n");
507 
508  footer_base = header_end + layer_stride * (size_t)L;
509  offset = footer_base;
510  fprintf(out, " .footer = {\n");
511  emit_footer_values(out, graph, plan, CK_ROLE_GRAD, 0, &offset);
512  fprintf(out, " },\n");
513 
514  fprintf(out, " .layer_stride_bytes = %zu,\n", layer_stride);
515  fprintf(out, " .total_bytes = %zu\n", offset);
516  fprintf(out, " }\n");
517 
518  fprintf(out, " }%s\n", mode == 2 ? "" : ",");
519  }
520 
521  fprintf(out, " }\n");
522  fprintf(out, "};\n\n");
523 }
static CKV2AlignInfo compute_align(const CKModelConfig *cfg)
static void emit_footer_fields(FILE *out, const CKIRV2Graph *graph, CKBufferRole role_filter, int activation_group)
#define CK_V2_SECTION_ALIGN
static size_t align_up_bytes(size_t n, size_t align)
static size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)
static size_t resolve_dim(const CKModelConfig *cfg, const CKV2AlignInfo *align, CKDimKind kind)
static size_t buffer_bytes(const CKIRV2Buffer *buf, const CKModelConfig *cfg, const CKV2AlignInfo *align)
void ck_codegen_v2_emit_sections(FILE *out, const CKIRV2Graph *graph, const CKMemPlan *prefill_plan, const CKMemPlan *decode_plan, const CKMemPlan *backward_plan)
static int is_activation_role(CKBufferRole role)
static size_t emit_body_values(FILE *out, const CKIRV2Graph *graph, const CKMemPlan *plan, CKBufferRole role_filter, int activation_group)
static size_t plan_size(const CKMemPlan *plan, int idx)
static void emit_header_values(FILE *out, const CKIRV2Graph *graph, const CKMemPlan *plan, CKBufferRole role_filter, int activation_group, size_t *offset)
static void emit_body_fields(FILE *out, const CKIRV2Graph *graph, CKBufferRole role_filter, int activation_group)
static int is_footer_global(const char *name)
static void emit_span_value(FILE *out, const char *label, size_t offset, size_t size, int comma)
static size_t resolve_shape_elems(const CKModelConfig *cfg, const CKV2AlignInfo *align, const CKDimToken *shape)
static void emit_header_fields(FILE *out, const CKIRV2Graph *graph, CKBufferRole role_filter, int activation_group)
static void emit_span_field(FILE *out, const char *label)
static void emit_footer_values(FILE *out, const CKIRV2Graph *graph, const CKMemPlan *plan, CKBufferRole role_filter, int activation_group, size_t *offset)
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
#define CK_IR_V2_MAX_DIMS
Definition: ckernel_ir_v2.h:13
@ CK_ROLE_WEIGHT
@ CK_ROLE_SCRATCH
@ CK_ROLE_GRAD
@ CK_ROLE_ACTIVATION
@ CK_ROLE_INPUT
@ CK_ROLE_OUTPUT
@ CK_DIM_ALIGNED_INTERMEDIATE
@ CK_DIM_NUM_HEADS
@ CK_DIM_ALIGNED_EMBED
@ CK_DIM_TOKENS
@ CK_DIM_INTERMEDIATE
@ CK_DIM_ALIGNED_CTX
@ CK_DIM_END
@ CK_DIM_ALIGNED_HEAD
@ CK_DIM_HEAD_DIM
@ CK_DIM_NUM_KV_HEADS
@ CK_DIM_VOCAB
@ CK_DIM_EMBED
@ CK_SCOPE_LAYER
@ CK_SCOPE_GLOBAL
CKDimToken shape[4]
Definition: ckernel_ir_v2.h:29
CKBufferRole role
Definition: ckernel_ir_v2.h:27
CKBufferScope scope
Definition: ckernel_ir_v2.h:26
CKDataType dtype
Definition: ckernel_ir_v2.h:28
CKModelConfig config
Definition: ckernel_ir_v2.h:56
CKIRV2Buffer * buffers
Definition: ckernel_ir_v2.h:62
CKMemSpan * spans
size_t size_bytes
int context_window
Definition: ckernel_ir.h:30
int intermediate_size
Definition: ck_model_api.h:37