← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_mem_plan.c
Go to the documentation of this file.
1 #include "ckernel_mem_plan.h"
2 
3 #include "ckernel_dtype.h"
4 
5 #include <stdlib.h>
6 #include <string.h>
7 
/*
 * Element counts obtained by padding selected model dimensions so that a
 * float row of that length is rounded up to the plan alignment
 * (filled in by build_plan via align_up_elems).
 */
typedef struct CKIRV2AlignInfo {
    size_t aligned_embed;        /* padded hidden_size */
    size_t aligned_head;         /* padded hidden_size / num_heads */
    size_t aligned_intermediate; /* padded intermediate_size */
    size_t aligned_context;      /* padded context_window */
} CKIRV2AlignInfo;
14 
/*
 * Round n up to the next multiple of align (both in bytes).
 *
 * align == 0 means "no alignment" and returns n unchanged. Power-of-two
 * alignments use the usual bit mask; any other alignment (e.g. a
 * caller-supplied 48) falls back to division, because the mask trick only
 * produces a multiple of align when align is a power of two.
 */
static size_t align_up_bytes(size_t n, size_t align)
{
    if (align == 0) {
        return n;
    }
    if ((align & (align - 1)) == 0) {
        return (n + align - 1) & ~(align - 1);
    }
    return ((n + align - 1) / align) * align;
}
20 
/*
 * Pad a row of `elems` elements (each `elem_bytes` bytes) up to a multiple of
 * `align_bytes` bytes and return the padded length in elements.
 *
 * The conversion back to elements truncates toward zero when align_bytes is
 * not a multiple of elem_bytes. A zero elem_bytes would make that conversion
 * a division by zero, so the element count is returned unchanged instead.
 */
static size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)
{
    if (elem_bytes == 0) {
        return elems;
    }
    size_t padded_bytes = align_up_bytes(elems * elem_bytes, align_bytes);
    return padded_bytes / elem_bytes;
}
27 
28 static size_t resolve_dim(const CKModelConfig *cfg,
29  const CKIRV2AlignInfo *align,
30  CKDimKind kind,
31  int tokens_override)
32 {
33  switch (kind) {
34  case CK_DIM_TOKENS:
35  if (tokens_override >= 0) {
36  return (size_t)tokens_override;
37  }
38  return (size_t)cfg->context_window;
39  case CK_DIM_EMBED:
40  return (size_t)cfg->hidden_size;
42  return align->aligned_embed;
43  case CK_DIM_HEAD_DIM:
44  return (size_t)(cfg->hidden_size / cfg->num_heads);
46  return align->aligned_head;
47  case CK_DIM_NUM_HEADS:
48  return (size_t)cfg->num_heads;
50  return (size_t)cfg->num_kv_heads;
51  case CK_DIM_ALIGNED_CTX:
52  return align->aligned_context;
54  return (size_t)cfg->intermediate_size;
56  return align->aligned_intermediate;
57  case CK_DIM_VOCAB:
58  return (size_t)cfg->vocab_size;
59  case CK_DIM_END:
60  default:
61  return 0;
62  }
63 }
64 
65 static size_t resolve_shape_elems(const CKModelConfig *cfg,
66  const CKIRV2AlignInfo *align,
67  const CKDimToken *shape,
68  int tokens_override)
69 {
70  size_t total = 1;
71  for (int i = 0; i < CK_IR_V2_MAX_DIMS; ++i) {
72  if (shape[i].dim == CK_DIM_END) {
73  break;
74  }
75  size_t dim = resolve_dim(cfg, align, shape[i].dim, tokens_override);
76  size_t mult = (size_t)(shape[i].mult > 0 ? shape[i].mult : 1);
77  size_t div = (size_t)(shape[i].div > 0 ? shape[i].div : 1);
78  if (div == 0) div = 1;
79  total = total * dim * mult / div;
80  }
81  return total;
82 }
83 
85 {
86  switch (role) {
87  case CK_ROLE_WEIGHT:
88  return CK_MEM_ARENA_WEIGHTS;
89  case CK_ROLE_GRAD:
90  return CK_MEM_ARENA_GRADS;
91  case CK_ROLE_SCRATCH:
92  case CK_ROLE_ACTIVATION:
93  case CK_ROLE_INPUT:
94  case CK_ROLE_OUTPUT:
95  default:
97  }
98 }
99 
100 static int buffer_enabled(const CKIRV2Graph *graph,
101  const CKIRV2Buffer *buf,
102  int training_enabled)
103 {
104  if (!graph || !buf) {
105  return 0;
106  }
107  if (buf->condition && strcmp(buf->condition, "rope_theta") == 0) {
108  if (graph->config.rope_theta <= 0.0f) {
109  return 0;
110  }
111  }
112  if (buf->condition && strcmp(buf->condition, "rope_disabled") == 0) {
113  if (graph->config.rope_theta > 0.0f) {
114  return 0;
115  }
116  }
117  if (buf->condition && strcmp(buf->condition, "has_pos_emb") == 0) {
118  if (graph->has_pos_emb == 0) {
119  return 0;
120  }
121  }
122  if (buf->condition && strcmp(buf->condition, "training_enabled") == 0) {
123  if (!training_enabled) {
124  return 0;
125  }
126  }
127  if (buf->role == CK_ROLE_GRAD && !training_enabled) {
128  return 0;
129  }
130  return 1;
131 }
132 
133 static int find_buffer_by_name(const CKIRV2Graph *graph, const char *name)
134 {
135  if (!graph || !name) {
136  return -1;
137  }
138  for (int i = 0; i < graph->num_buffers; ++i) {
139  if (graph->buffers[i].name && strcmp(graph->buffers[i].name, name) == 0) {
140  return i;
141  }
142  }
143  return -1;
144 }
145 
146 static int build_plan(const CKIRV2Graph *graph,
147  CKMemPlan *plan,
148  size_t alignment_bytes,
149  int training_enabled,
150  int tokens_override)
151 {
152  if (!graph || !plan) {
153  return -1;
154  }
155  memset(plan, 0, sizeof(*plan));
156  if (alignment_bytes == 0) {
157  alignment_bytes = CK_MEM_PLAN_DEFAULT_ALIGN;
158  }
159  plan->alignment_bytes = alignment_bytes;
160 
161  plan->num_spans = graph->num_buffers;
162  plan->spans = (CKMemSpan *)calloc((size_t)graph->num_buffers, sizeof(CKMemSpan));
163  if (!plan->spans) {
164  return -1;
165  }
166 
167  size_t elem_bytes = sizeof(float);
168  CKIRV2AlignInfo align = {0};
169  align.aligned_embed = align_up_elems((size_t)graph->config.hidden_size, elem_bytes, alignment_bytes);
170  align.aligned_head = align_up_elems((size_t)(graph->config.hidden_size / graph->config.num_heads),
171  elem_bytes, alignment_bytes);
172  align.aligned_intermediate = align_up_elems((size_t)graph->config.intermediate_size,
173  elem_bytes, alignment_bytes);
174  align.aligned_context = align_up_elems((size_t)graph->config.context_window, elem_bytes, alignment_bytes);
175 
176  size_t arena_offsets[CK_MEM_ARENA_COUNT] = {0};
177 
178  for (int i = 0; i < graph->num_buffers; ++i) {
179  const CKIRV2Buffer *buf = &graph->buffers[i];
180  CKMemSpan *span = &plan->spans[i];
181  span->buffer_id = i;
182  span->arena = arena_for_role(buf->role);
183  span->offset_bytes = 0;
184  span->size_bytes = 0;
185 
186  if (!buffer_enabled(graph, buf, training_enabled)) {
187  continue;
188  }
189 
190  if (buf->alias_of) {
191  int alias_id = find_buffer_by_name(graph, buf->alias_of);
192  if (alias_id >= 0) {
193  span->offset_bytes = plan->spans[alias_id].offset_bytes;
194  span->size_bytes = plan->spans[alias_id].size_bytes;
195  span->arena = plan->spans[alias_id].arena;
196  continue;
197  }
198  }
199 
200  size_t n_elems = resolve_shape_elems(&graph->config, &align, buf->shape, tokens_override);
201  size_t bytes = ck_dtype_row_bytes(buf->dtype, n_elems);
202  size_t aligned = align_up_bytes(bytes, alignment_bytes);
203 
204  span->offset_bytes = arena_offsets[span->arena];
205  span->size_bytes = bytes;
206  arena_offsets[span->arena] += aligned;
207  }
208 
209  for (int i = 0; i < CK_MEM_ARENA_COUNT; ++i) {
210  plan->total_bytes[i] = arena_offsets[i];
211  }
212 
213  return 0;
214 }
215 
217  CKMemPlan *plan,
218  size_t alignment_bytes)
219 {
220  return build_plan(graph, plan, alignment_bytes, 0, -1);
221 }
222 
224  CKMemPlan *plan,
225  size_t alignment_bytes)
226 {
227  return build_plan(graph, plan, alignment_bytes, 1, -1);
228 }
229 
231  CKMemPlan *plan,
232  size_t alignment_bytes,
233  int tokens_override)
234 {
235  return build_plan(graph, plan, alignment_bytes, 0, tokens_override);
236 }
237 
239  CKMemPlan *plan,
240  size_t alignment_bytes,
241  int tokens_override)
242 {
243  return build_plan(graph, plan, alignment_bytes, 1, tokens_override);
244 }
245 
247 {
248  if (!plan) {
249  return;
250  }
251  free(plan->spans);
252  memset(plan, 0, sizeof(*plan));
253 }
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
#define CK_IR_V2_MAX_DIMS
Definition: ckernel_ir_v2.h:13
@ CK_ROLE_WEIGHT
@ CK_ROLE_SCRATCH
@ CK_ROLE_GRAD
@ CK_ROLE_ACTIVATION
@ CK_ROLE_INPUT
@ CK_ROLE_OUTPUT
@ CK_DIM_ALIGNED_INTERMEDIATE
@ CK_DIM_NUM_HEADS
@ CK_DIM_ALIGNED_EMBED
@ CK_DIM_TOKENS
@ CK_DIM_INTERMEDIATE
@ CK_DIM_ALIGNED_CTX
@ CK_DIM_END
@ CK_DIM_ALIGNED_HEAD
@ CK_DIM_HEAD_DIM
@ CK_DIM_NUM_KV_HEADS
@ CK_DIM_VOCAB
@ CK_DIM_EMBED
static size_t resolve_dim(const CKModelConfig *cfg, const CKIRV2AlignInfo *align, CKDimKind kind, int tokens_override)
int ck_mem_plan_build_training(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes)
static size_t align_up_bytes(size_t n, size_t align)
static size_t resolve_shape_elems(const CKModelConfig *cfg, const CKIRV2AlignInfo *align, const CKDimToken *shape, int tokens_override)
static size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)
static int build_plan(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int training_enabled, int tokens_override)
int ck_mem_plan_build_training_with_tokens(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int tokens_override)
static int find_buffer_by_name(const CKIRV2Graph *graph, const char *name)
static CKMemArenaKind arena_for_role(CKBufferRole role)
void ck_mem_plan_free(CKMemPlan *plan)
int ck_mem_plan_build_inference(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes)
int ck_mem_plan_build_inference_with_tokens(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int tokens_override)
static int buffer_enabled(const CKIRV2Graph *graph, const CKIRV2Buffer *buf, int training_enabled)
CKMemArenaKind
@ CK_MEM_ARENA_GRADS
@ CK_MEM_ARENA_COUNT
@ CK_MEM_ARENA_WEIGHTS
@ CK_MEM_ARENA_ACTIVATIONS
#define CK_MEM_PLAN_DEFAULT_ALIGN
CKDimToken shape[4]
Definition: ckernel_ir_v2.h:29
CKBufferRole role
Definition: ckernel_ir_v2.h:27
char * condition
Definition: ckernel_ir_v2.h:32
char * alias_of
Definition: ckernel_ir_v2.h:31
CKDataType dtype
Definition: ckernel_ir_v2.h:28
CKModelConfig config
Definition: ckernel_ir_v2.h:56
CKIRV2Buffer * buffers
Definition: ckernel_ir_v2.h:62
CKMemSpan * spans
size_t alignment_bytes
size_t total_bytes[CK_MEM_ARENA_COUNT]
size_t size_bytes
CKMemArenaKind arena
size_t offset_bytes
int context_window
Definition: ckernel_ir.h:30
int intermediate_size
Definition: ck_model_api.h:37
float rope_theta
Definition: ckernel_ir.h:32