11 size_t aligned_intermediate;
12 size_t aligned_context;
17 if (align == 0)
return n;
18 return (n + align - 1) & ~(align - 1);
21 static size_t align_up_elems(
size_t elems,
size_t elem_bytes,
size_t align_bytes)
23 size_t bytes = elems * elem_bytes;
25 return bytes / elem_bytes;
29 const CKIRV2AlignInfo *align,
35 if (tokens_override >= 0) {
36 return (
size_t)tokens_override;
42 return align->aligned_embed;
46 return align->aligned_head;
52 return align->aligned_context;
56 return align->aligned_intermediate;
66 const CKIRV2AlignInfo *align,
75 size_t dim =
resolve_dim(cfg, align, shape[i].dim, tokens_override);
76 size_t mult = (size_t)(shape[i].mult > 0 ? shape[i].mult : 1);
77 size_t div = (size_t)(shape[i].div > 0 ? shape[i].div : 1);
78 if (div == 0) div = 1;
79 total = total * dim * mult / div;
102 int training_enabled)
104 if (!graph || !buf) {
123 if (!training_enabled) {
135 if (!graph || !name) {
148 size_t alignment_bytes,
149 int training_enabled,
152 if (!graph || !plan) {
155 memset(plan, 0,
sizeof(*plan));
156 if (alignment_bytes == 0) {
167 size_t elem_bytes =
sizeof(float);
168 CKIRV2AlignInfo align = {0};
171 elem_bytes, alignment_bytes);
173 elem_bytes, alignment_bytes);
206 arena_offsets[span->
arena] += aligned;
218 size_t alignment_bytes)
220 return build_plan(graph, plan, alignment_bytes, 0, -1);
225 size_t alignment_bytes)
227 return build_plan(graph, plan, alignment_bytes, 1, -1);
232 size_t alignment_bytes,
235 return build_plan(graph, plan, alignment_bytes, 0, tokens_override);
240 size_t alignment_bytes,
243 return build_plan(graph, plan, alignment_bytes, 1, tokens_override);
252 memset(plan, 0,
sizeof(*plan));
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
#define CK_IR_V2_MAX_DIMS
@ CK_DIM_ALIGNED_INTERMEDIATE
static size_t resolve_dim(const CKModelConfig *cfg, const CKIRV2AlignInfo *align, CKDimKind kind, int tokens_override)
int ck_mem_plan_build_training(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes)
static size_t align_up_bytes(size_t n, size_t align)
static size_t resolve_shape_elems(const CKModelConfig *cfg, const CKIRV2AlignInfo *align, const CKDimToken *shape, int tokens_override)
static size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)
static int build_plan(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int training_enabled, int tokens_override)
int ck_mem_plan_build_training_with_tokens(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int tokens_override)
static int find_buffer_by_name(const CKIRV2Graph *graph, const char *name)
static CKMemArenaKind arena_for_role(CKBufferRole role)
void ck_mem_plan_free(CKMemPlan *plan)
int ck_mem_plan_build_inference(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes)
int ck_mem_plan_build_inference_with_tokens(const CKIRV2Graph *graph, CKMemPlan *plan, size_t alignment_bytes, int tokens_override)
static int buffer_enabled(const CKIRV2Graph *graph, const CKIRV2Buffer *buf, int training_enabled)
@ CK_MEM_ARENA_ACTIVATIONS
#define CK_MEM_PLAN_DEFAULT_ALIGN
size_t total_bytes[CK_MEM_ARENA_COUNT]