8 #define CK_V2_SECTION_ALIGN 64
13 size_t aligned_intermediate;
14 size_t aligned_context;
22 return (n + align - 1) & ~(align - 1);
25 static size_t align_up_elems(
size_t elems,
size_t elem_bytes,
size_t align_bytes)
27 size_t bytes = elems * elem_bytes;
29 return bytes / elem_bytes;
34 CKV2AlignInfo info = {0};
38 size_t elem_bytes =
sizeof(float);
42 info.aligned_intermediate =
44 info.aligned_context =
50 const CKV2AlignInfo *align,
59 return align ? align->aligned_embed : 0;
63 return align ? align->aligned_head : 0;
69 return align ? align->aligned_context : 0;
73 return align ? align->aligned_intermediate : 0;
83 const CKV2AlignInfo *align,
92 size_t mult = (size_t)(shape[i].mult > 0 ? shape[i].mult : 1);
93 size_t div = (size_t)(shape[i].div > 0 ? shape[i].div : 1);
94 if (div == 0) div = 1;
95 total = total * dim * mult / div;
102 const CKV2AlignInfo *align)
113 if (strncmp(name,
"final_", 6) == 0) {
116 if (strcmp(name,
"lm_head_weight") == 0) {
119 if (strcmp(name,
"logits") == 0) {
122 if (strncmp(name,
"d_final_", 8) == 0) {
125 if (strcmp(name,
"d_logits") == 0) {
133 fprintf(out,
" CKV2Span %s;\n", label);
136 static void emit_span_value(FILE *out,
const char *label,
size_t offset,
size_t size,
int comma)
138 fprintf(out,
" .%s = { %zu, %zu }%s\n", label, offset, size, comma ?
"," :
"");
150 int activation_group)
157 if (activation_group) {
161 }
else if (buf->
role != role_filter) {
174 int activation_group)
181 if (activation_group) {
185 }
else if (buf->
role != role_filter) {
195 int activation_group)
202 if (activation_group) {
206 }
else if (buf->
role != role_filter) {
228 int activation_group,
236 if (activation_group) {
240 }
else if (buf->
role != role_filter) {
257 int activation_group)
259 size_t layer_offset = 0;
265 if (activation_group) {
269 }
else if (buf->
role != role_filter) {
275 layer_offset += bytes;
284 int activation_group,
292 if (activation_group) {
296 }
else if (buf->
role != role_filter) {
315 if (!out || !graph) {
327 fprintf(out,
"typedef struct {\n");
329 fprintf(out,
"} CKV2HeaderWeights;\n\n");
331 fprintf(out,
"typedef struct {\n");
333 fprintf(out,
"} CKV2LayerWeights;\n\n");
335 fprintf(out,
"typedef struct {\n");
337 fprintf(out,
"} CKV2FooterWeights;\n\n");
339 fprintf(out,
"typedef struct {\n");
341 fprintf(out,
"} CKV2HeaderActivations;\n\n");
343 fprintf(out,
"typedef struct {\n");
345 fprintf(out,
"} CKV2LayerActivations;\n\n");
347 fprintf(out,
"typedef struct {\n");
349 fprintf(out,
"} CKV2FooterActivations;\n\n");
351 fprintf(out,
"typedef struct {\n");
353 fprintf(out,
"} CKV2HeaderGrads;\n\n");
355 fprintf(out,
"typedef struct {\n");
357 fprintf(out,
"} CKV2LayerGrads;\n\n");
359 fprintf(out,
"typedef struct {\n");
361 fprintf(out,
"} CKV2FooterGrads;\n\n");
365 " CKV2HeaderWeights header;\n"
366 " CKV2LayerWeights body;\n"
367 " CKV2FooterWeights footer;\n"
368 " size_t layer_stride_bytes;\n"
369 " size_t total_bytes;\n"
370 "} CKV2WeightLayout;\n\n");
374 " CKV2HeaderActivations header;\n"
375 " CKV2LayerActivations body;\n"
376 " CKV2FooterActivations footer;\n"
377 " size_t layer_stride_bytes;\n"
378 " size_t total_bytes;\n"
379 "} CKV2ActivationLayout;\n\n");
383 " CKV2HeaderGrads header;\n"
384 " CKV2LayerGrads body;\n"
385 " CKV2FooterGrads footer;\n"
386 " size_t layer_stride_bytes;\n"
387 " size_t total_bytes;\n"
388 "} CKV2GradLayout;\n\n");
393 " int context_window;\n"
394 " int hidden_size;\n"
395 " int intermediate_size;\n"
397 " int num_kv_heads;\n"
400 "} CKV2ModelConfig;\n\n");
404 " CKV2WeightLayout weights;\n"
405 " CKV2ActivationLayout activations;\n"
406 " CKV2GradLayout grads;\n"
407 "} CKV2SectionLayout;\n\n");
411 " CKV2SectionLayout prefill;\n"
412 " CKV2SectionLayout decode;\n"
413 " CKV2SectionLayout backward;\n"
414 "} CKV2SectionLayouts;\n\n");
418 " CKV2ModelConfig config;\n"
419 " CKV2SectionLayouts decoder;\n"
420 "} CKV2RuntimeLayout;\n\n");
422 fprintf(out,
"static const CKV2RuntimeLayout ck_v2_layout = {\n");
423 fprintf(out,
" .config = {\n");
430 fprintf(out,
" .head_dim = %d,\n",
433 fprintf(out,
" },\n");
434 fprintf(out,
" .decoder = {\n");
436 const CKMemPlan *plans[3] = { prefill_plan, decode_plan, backward_plan };
437 const char *modes[3] = {
"prefill",
"decode",
"backward" };
439 for (
int mode = 0; mode < 3; ++mode) {
441 fprintf(out,
" .%s = {\n", modes[mode]);
444 size_t header_end = 0;
445 size_t layer_stride = 0;
446 size_t footer_base = 0;
448 fprintf(out,
" .weights = {\n");
449 fprintf(out,
" .header = {\n");
452 fprintf(out,
" },\n");
454 fprintf(out,
" .body = {\n");
456 fprintf(out,
" },\n");
458 footer_base = header_end + layer_stride * (size_t)L;
459 offset = footer_base;
460 fprintf(out,
" .footer = {\n");
462 fprintf(out,
" },\n");
464 fprintf(out,
" .layer_stride_bytes = %zu,\n", layer_stride);
465 fprintf(out,
" .total_bytes = %zu\n", offset);
466 fprintf(out,
" },\n");
473 fprintf(out,
" .activations = {\n");
474 fprintf(out,
" .header = {\n");
477 fprintf(out,
" },\n");
479 fprintf(out,
" .body = {\n");
481 fprintf(out,
" },\n");
483 footer_base = header_end + layer_stride * (size_t)L;
484 offset = footer_base;
485 fprintf(out,
" .footer = {\n");
487 fprintf(out,
" },\n");
489 fprintf(out,
" .layer_stride_bytes = %zu,\n", layer_stride);
490 fprintf(out,
" .total_bytes = %zu\n", offset);
491 fprintf(out,
" },\n");
498 fprintf(out,
" .grads = {\n");
499 fprintf(out,
" .header = {\n");
502 fprintf(out,
" },\n");
504 fprintf(out,
" .body = {\n");
506 fprintf(out,
" },\n");
508 footer_base = header_end + layer_stride * (size_t)L;
509 offset = footer_base;
510 fprintf(out,
" .footer = {\n");
512 fprintf(out,
" },\n");
514 fprintf(out,
" .layer_stride_bytes = %zu,\n", layer_stride);
515 fprintf(out,
" .total_bytes = %zu\n", offset);
516 fprintf(out,
" }\n");
518 fprintf(out,
" }%s\n", mode == 2 ?
"" :
",");
521 fprintf(out,
" }\n");
522 fprintf(out,
"};\n\n");
static CKV2AlignInfo compute_align(const CKModelConfig *cfg)
static void emit_footer_fields(FILE *out, const CKIRV2Graph *graph, CKBufferRole role_filter, int activation_group)
#define CK_V2_SECTION_ALIGN
static size_t align_up_bytes(size_t n, size_t align)
static size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)
static size_t resolve_dim(const CKModelConfig *cfg, const CKV2AlignInfo *align, CKDimKind kind)
static size_t buffer_bytes(const CKIRV2Buffer *buf, const CKModelConfig *cfg, const CKV2AlignInfo *align)
void ck_codegen_v2_emit_sections(FILE *out, const CKIRV2Graph *graph, const CKMemPlan *prefill_plan, const CKMemPlan *decode_plan, const CKMemPlan *backward_plan)
static int is_activation_role(CKBufferRole role)
static size_t emit_body_values(FILE *out, const CKIRV2Graph *graph, const CKMemPlan *plan, CKBufferRole role_filter, int activation_group)
static size_t plan_size(const CKMemPlan *plan, int idx)
static void emit_header_values(FILE *out, const CKIRV2Graph *graph, const CKMemPlan *plan, CKBufferRole role_filter, int activation_group, size_t *offset)
static void emit_body_fields(FILE *out, const CKIRV2Graph *graph, CKBufferRole role_filter, int activation_group)
static int is_footer_global(const char *name)
static void emit_span_value(FILE *out, const char *label, size_t offset, size_t size, int comma)
static size_t resolve_shape_elems(const CKModelConfig *cfg, const CKV2AlignInfo *align, const CKDimToken *shape)
static void emit_header_fields(FILE *out, const CKIRV2Graph *graph, CKBufferRole role_filter, int activation_group)
static void emit_span_field(FILE *out, const char *label)
static void emit_footer_values(FILE *out, const CKIRV2Graph *graph, const CKMemPlan *plan, CKBufferRole role_filter, int activation_group, size_t *offset)
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
#define CK_IR_V2_MAX_DIMS
@ CK_DIM_ALIGNED_INTERMEDIATE