← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_mem_plan.c File Reference
#include "ckernel_mem_plan.h"
#include "ckernel_dtype.h"
#include <stdlib.h>
#include <string.h>

Go to the source code of this file.

Functions

static size_t align_up_bytes (size_t n, size_t align)
 
static size_t align_up_elems (size_t elems, size_t elem_bytes, size_t align_bytes)
 
static CKMemArenaKind arena_for_role (CKBufferRole role)
 
static int buffer_enabled (const CKIRV2Graph *graph, const CKIRV2Buffer *buf, int training_enabled)
 
static int build_plan (const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int training_enabled, int tokens_override)
 
int ck_mem_plan_build_inference (const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes)
 
int ck_mem_plan_build_inference_with_tokens (const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int tokens_override)
 
int ck_mem_plan_build_training (const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes)
 
int ck_mem_plan_build_training_with_tokens (const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int tokens_override)
 
void ck_mem_plan_free (CKMemPlan *plan)
 
static int find_buffer_by_name (const CKIRV2Graph *graph, const char *name)
 
static size_t resolve_dim (const CKModelConfig *cfg, const CKIRV2AlignInfo *align, CKDimKind kind, int tokens_override)
 
static size_t resolve_shape_elems (const CKModelConfig *cfg, const CKIRV2AlignInfo *align, const CKDimToken *shape, int tokens_override)
 

Function Documentation

◆ align_up_bytes()

static size_t align_up_bytes ( size_t  n,
size_t  align 
)
static

Definition at line 15 of file ckernel_mem_plan.c.

16 {
17  if (align == 0) return n;
18  return (n + align - 1) & ~(align - 1);
19 }

Referenced by align_up_elems(), and build_plan().

◆ align_up_elems()

static size_t align_up_elems ( size_t  elems,
size_t  elem_bytes,
size_t  align_bytes 
)
static

Definition at line 21 of file ckernel_mem_plan.c.

22 {
23  size_t bytes = elems * elem_bytes;
24  bytes = align_up_bytes(bytes, align_bytes);
25  return bytes / elem_bytes;
26 }
static size_t align_up_bytes(size_t n, size_t align)

References align_up_bytes().

Referenced by build_plan().

◆ arena_for_role()

static CKMemArenaKind arena_for_role ( CKBufferRole  role)
static

Definition at line 84 of file ckernel_mem_plan.c.

85 {
86  switch (role) {
87  case CK_ROLE_WEIGHT:
88  return CK_MEM_ARENA_WEIGHTS;
89  case CK_ROLE_GRAD:
90  return CK_MEM_ARENA_GRADS;
91  case CK_ROLE_SCRATCH:
92  case CK_ROLE_ACTIVATION:
93  case CK_ROLE_INPUT:
94  case CK_ROLE_OUTPUT:
95  default:
96  return CK_MEM_ARENA_ACTIVATIONS;
97  }
98 }
@ CK_ROLE_WEIGHT
@ CK_ROLE_SCRATCH
@ CK_ROLE_GRAD
@ CK_ROLE_ACTIVATION
@ CK_ROLE_INPUT
@ CK_ROLE_OUTPUT
@ CK_MEM_ARENA_GRADS
@ CK_MEM_ARENA_WEIGHTS
@ CK_MEM_ARENA_ACTIVATIONS

References CK_MEM_ARENA_ACTIVATIONS, CK_MEM_ARENA_GRADS, CK_MEM_ARENA_WEIGHTS, CK_ROLE_ACTIVATION, CK_ROLE_GRAD, CK_ROLE_INPUT, CK_ROLE_OUTPUT, CK_ROLE_SCRATCH, and CK_ROLE_WEIGHT.

Referenced by build_plan().

◆ buffer_enabled()

static int buffer_enabled ( const CKIRV2Graph *  graph,
const CKIRV2Buffer *  buf,
int  training_enabled 
)
static

Definition at line 100 of file ckernel_mem_plan.c.

103 {
104  if (!graph || !buf) {
105  return 0;
106  }
107  if (buf->condition && strcmp(buf->condition, "rope_theta") == 0) {
108  if (graph->config.rope_theta <= 0.0f) {
109  return 0;
110  }
111  }
112  if (buf->condition && strcmp(buf->condition, "rope_disabled") == 0) {
113  if (graph->config.rope_theta > 0.0f) {
114  return 0;
115  }
116  }
117  if (buf->condition && strcmp(buf->condition, "has_pos_emb") == 0) {
118  if (graph->has_pos_emb == 0) {
119  return 0;
120  }
121  }
122  if (buf->condition && strcmp(buf->condition, "training_enabled") == 0) {
123  if (!training_enabled) {
124  return 0;
125  }
126  }
127  if (buf->role == CK_ROLE_GRAD && !training_enabled) {
128  return 0;
129  }
130  return 1;
131 }
CKBufferRole role
Definition: ckernel_ir_v2.h:27
char * condition
Definition: ckernel_ir_v2.h:32
CKModelConfig config
Definition: ckernel_ir_v2.h:56
float rope_theta
Definition: ckernel_ir.h:32

References CK_ROLE_GRAD, CKIRV2Buffer::condition, CKIRV2Graph::config, CKIRV2Graph::has_pos_emb, CKIRV2Buffer::role, and CKModelConfig::rope_theta.

Referenced by build_plan().

◆ build_plan()

static int build_plan ( const CKIRV2Graph *  graph,
CKMemPlan *  plan,
size_t  alignment_bytes,
int  training_enabled,
int  tokens_override 
)
static

Definition at line 146 of file ckernel_mem_plan.c.

151 {
152  if (!graph || !plan) {
153  return -1;
154  }
155  memset(plan, 0, sizeof(*plan));
156  if (alignment_bytes == 0) {
157  alignment_bytes = CK_MEM_PLAN_DEFAULT_ALIGN;
158  }
159  plan->alignment_bytes = alignment_bytes;
160 
161  plan->num_spans = graph->num_buffers;
162  plan->spans = (CKMemSpan *)calloc((size_t)graph->num_buffers, sizeof(CKMemSpan));
163  if (!plan->spans) {
164  return -1;
165  }
166 
167  size_t elem_bytes = sizeof(float);
168  CKIRV2AlignInfo align = {0};
169  align.aligned_embed = align_up_elems((size_t)graph->config.hidden_size, elem_bytes, alignment_bytes);
170  align.aligned_head = align_up_elems((size_t)(graph->config.hidden_size / graph->config.num_heads),
171  elem_bytes, alignment_bytes);
172  align.aligned_intermediate = align_up_elems((size_t)graph->config.intermediate_size,
173  elem_bytes, alignment_bytes);
174  align.aligned_context = align_up_elems((size_t)graph->config.context_window, elem_bytes, alignment_bytes);
175 
176  size_t arena_offsets[CK_MEM_ARENA_COUNT] = {0};
177 
178  for (int i = 0; i < graph->num_buffers; ++i) {
179  const CKIRV2Buffer *buf = &graph->buffers[i];
180  CKMemSpan *span = &plan->spans[i];
181  span->buffer_id = i;
182  span->arena = arena_for_role(buf->role);
183  span->offset_bytes = 0;
184  span->size_bytes = 0;
185 
186  if (!buffer_enabled(graph, buf, training_enabled)) {
187  continue;
188  }
189 
190  if (buf->alias_of) {
191  int alias_id = find_buffer_by_name(graph, buf->alias_of);
192  if (alias_id >= 0) {
193  span->offset_bytes = plan->spans[alias_id].offset_bytes;
194  span->size_bytes = plan->spans[alias_id].size_bytes;
195  span->arena = plan->spans[alias_id].arena;
196  continue;
197  }
198  }
199 
200  size_t n_elems = resolve_shape_elems(&graph->config, &align, buf->shape, tokens_override);
201  size_t bytes = ck_dtype_row_bytes(buf->dtype, n_elems);
202  size_t aligned = align_up_bytes(bytes, alignment_bytes);
203 
204  span->offset_bytes = arena_offsets[span->arena];
205  span->size_bytes = bytes;
206  arena_offsets[span->arena] += aligned;
207  }
208 
209  for (int i = 0; i < CK_MEM_ARENA_COUNT; ++i) {
210  plan->total_bytes[i] = arena_offsets[i];
211  }
212 
213  return 0;
214 }
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
static size_t resolve_shape_elems(const CKModelConfig *cfg, const CKIRV2AlignInfo *align, const CKDimToken *shape, int tokens_override)
static size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)
static int find_buffer_by_name(const CKIRV2Graph *graph, const char *name)
static CKMemArenaKind arena_for_role(CKBufferRole role)
static int buffer_enabled(const CKIRV2Graph *graph, const CKIRV2Buffer *buf, int training_enabled)
@ CK_MEM_ARENA_COUNT
#define CK_MEM_PLAN_DEFAULT_ALIGN
CKDimToken shape[4]
Definition: ckernel_ir_v2.h:29
char * alias_of
Definition: ckernel_ir_v2.h:31
CKDataType dtype
Definition: ckernel_ir_v2.h:28
CKIRV2Buffer * buffers
Definition: ckernel_ir_v2.h:62
CKMemSpan * spans
size_t alignment_bytes
size_t total_bytes[CK_MEM_ARENA_COUNT]
size_t size_bytes
CKMemArenaKind arena
size_t offset_bytes
int context_window
Definition: ckernel_ir.h:30
int intermediate_size
Definition: ck_model_api.h:37

References CKIRV2Buffer::alias_of, align_up_bytes(), align_up_elems(), CKMemPlan::alignment_bytes, CKMemSpan::arena, arena_for_role(), buffer_enabled(), CKMemSpan::buffer_id, CKIRV2Graph::buffers, ck_dtype_row_bytes(), CK_MEM_ARENA_COUNT, CK_MEM_PLAN_DEFAULT_ALIGN, CKIRV2Graph::config, CKModelConfig::context_window, CKIRV2Buffer::dtype, find_buffer_by_name(), CKModelConfig::hidden_size, CKModelConfig::intermediate_size, CKIRV2Graph::num_buffers, CKModelConfig::num_heads, CKMemPlan::num_spans, CKMemSpan::offset_bytes, resolve_shape_elems(), CKIRV2Buffer::role, CKIRV2Buffer::shape, CKMemSpan::size_bytes, CKMemPlan::spans, and CKMemPlan::total_bytes.

Referenced by ck_mem_plan_build_inference(), ck_mem_plan_build_inference_with_tokens(), ck_mem_plan_build_training(), and ck_mem_plan_build_training_with_tokens().

◆ ck_mem_plan_build_inference()

int ck_mem_plan_build_inference ( const CKIRV2Graph *  graph,
CKMemPlan *  plan,
size_t  alignment_bytes 
)

Definition at line 216 of file ckernel_mem_plan.c.

219 {
220  return build_plan(graph, plan, alignment_bytes, 0, -1);
221 }
static int build_plan(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int training_enabled, int tokens_override)

References build_plan().

◆ ck_mem_plan_build_inference_with_tokens()

int ck_mem_plan_build_inference_with_tokens ( const CKIRV2Graph *  graph,
CKMemPlan *  plan,
size_t  alignment_bytes,
int  tokens_override 
)

Definition at line 230 of file ckernel_mem_plan.c.

234 {
235  return build_plan(graph, plan, alignment_bytes, 0, tokens_override);
236 }

References build_plan().

Referenced by ck_codegen_v2_emit_runtime(), and ck_ir_v2_lower_graph().

◆ ck_mem_plan_build_training()

int ck_mem_plan_build_training ( const CKIRV2Graph *  graph,
CKMemPlan *  plan,
size_t  alignment_bytes 
)

Definition at line 223 of file ckernel_mem_plan.c.

226 {
227  return build_plan(graph, plan, alignment_bytes, 1, -1);
228 }

References build_plan().

◆ ck_mem_plan_build_training_with_tokens()

int ck_mem_plan_build_training_with_tokens ( const CKIRV2Graph *  graph,
CKMemPlan *  plan,
size_t  alignment_bytes,
int  tokens_override 
)

Definition at line 238 of file ckernel_mem_plan.c.

242 {
243  return build_plan(graph, plan, alignment_bytes, 1, tokens_override);
244 }

References build_plan().

Referenced by ck_codegen_v2_emit_runtime(), and ck_ir_v2_lower_graph().

◆ ck_mem_plan_free()

void ck_mem_plan_free ( CKMemPlan *  plan)

Definition at line 246 of file ckernel_mem_plan.c.

247 {
248  if (!plan) {
249  return;
250  }
251  free(plan->spans);
252  memset(plan, 0, sizeof(*plan));
253 }

References CKMemPlan::spans.

Referenced by ck_codegen_v2_emit_runtime(), ck_ir_v2_lower_emit_json(), and ck_ir_v2_lower_graph().

◆ find_buffer_by_name()

static int find_buffer_by_name ( const CKIRV2Graph *  graph,
const char *  name 
)
static

Definition at line 133 of file ckernel_mem_plan.c.

134 {
135  if (!graph || !name) {
136  return -1;
137  }
138  for (int i = 0; i < graph->num_buffers; ++i) {
139  if (graph->buffers[i].name && strcmp(graph->buffers[i].name, name) == 0) {
140  return i;
141  }
142  }
143  return -1;
144 }

References CKIRV2Graph::buffers, CKIRV2Buffer::name, and CKIRV2Graph::num_buffers.

Referenced by build_plan().

◆ resolve_dim()

static size_t resolve_dim ( const CKModelConfig *  cfg,
const CKIRV2AlignInfo *  align,
CKDimKind  kind,
int  tokens_override 
)
static

Definition at line 28 of file ckernel_mem_plan.c.

32 {
33  switch (kind) {
34  case CK_DIM_TOKENS:
35  if (tokens_override >= 0) {
36  return (size_t)tokens_override;
37  }
38  return (size_t)cfg->context_window;
39  case CK_DIM_EMBED:
40  return (size_t)cfg->hidden_size;
41  case CK_DIM_ALIGNED_EMBED:
42  return align->aligned_embed;
43  case CK_DIM_HEAD_DIM:
44  return (size_t)(cfg->hidden_size / cfg->num_heads);
45  case CK_DIM_ALIGNED_HEAD:
46  return align->aligned_head;
47  case CK_DIM_NUM_HEADS:
48  return (size_t)cfg->num_heads;
49  case CK_DIM_NUM_KV_HEADS:
50  return (size_t)cfg->num_kv_heads;
51  case CK_DIM_ALIGNED_CTX:
52  return align->aligned_context;
53  case CK_DIM_INTERMEDIATE:
54  return (size_t)cfg->intermediate_size;
55  case CK_DIM_ALIGNED_INTERMEDIATE:
56  return align->aligned_intermediate;
57  case CK_DIM_VOCAB:
58  return (size_t)cfg->vocab_size;
59  case CK_DIM_END:
60  default:
61  return 0;
62  }
63 }
@ CK_DIM_ALIGNED_INTERMEDIATE
@ CK_DIM_NUM_HEADS
@ CK_DIM_ALIGNED_EMBED
@ CK_DIM_TOKENS
@ CK_DIM_INTERMEDIATE
@ CK_DIM_ALIGNED_CTX
@ CK_DIM_END
@ CK_DIM_ALIGNED_HEAD
@ CK_DIM_HEAD_DIM
@ CK_DIM_NUM_KV_HEADS
@ CK_DIM_VOCAB
@ CK_DIM_EMBED

References CK_DIM_ALIGNED_CTX, CK_DIM_ALIGNED_EMBED, CK_DIM_ALIGNED_HEAD, CK_DIM_ALIGNED_INTERMEDIATE, CK_DIM_EMBED, CK_DIM_END, CK_DIM_HEAD_DIM, CK_DIM_INTERMEDIATE, CK_DIM_NUM_HEADS, CK_DIM_NUM_KV_HEADS, CK_DIM_TOKENS, CK_DIM_VOCAB, CKModelConfig::context_window, CKModelConfig::hidden_size, CKModelConfig::intermediate_size, CKModelConfig::num_heads, CKModelConfig::num_kv_heads, and CKModelConfig::vocab_size.

Referenced by resolve_shape_elems().

◆ resolve_shape_elems()

static size_t resolve_shape_elems ( const CKModelConfig *  cfg,
const CKIRV2AlignInfo *  align,
const CKDimToken *  shape,
int  tokens_override 
)
static

Definition at line 65 of file ckernel_mem_plan.c.

69 {
70  size_t total = 1;
71  for (int i = 0; i < CK_IR_V2_MAX_DIMS; ++i) {
72  if (shape[i].dim == CK_DIM_END) {
73  break;
74  }
75  size_t dim = resolve_dim(cfg, align, shape[i].dim, tokens_override);
76  size_t mult = (size_t)(shape[i].mult > 0 ? shape[i].mult : 1);
77  size_t div = (size_t)(shape[i].div > 0 ? shape[i].div : 1);
78  if (div == 0) div = 1;
79  total = total * dim * mult / div;
80  }
81  return total;
82 }
#define CK_IR_V2_MAX_DIMS
Definition: ckernel_ir_v2.h:13
static size_t resolve_dim(const CKModelConfig *cfg, const CKIRV2AlignInfo *align, CKDimKind kind, int tokens_override)

References CK_DIM_END, CK_IR_V2_MAX_DIMS, and resolve_dim().

Referenced by build_plan().