← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_orchestration.h
Go to the documentation of this file.
1 /**
2  * @file ckernel_orchestration.h
3  *
4  * ===========================================================================
5  * LEGACY HEADER - NOT USED IN v6.6
6  * ===========================================================================
7  *
8  * This header declares v6.5 orchestration functions that are NO LONGER USED.
9  * v6.6 uses IR Lower 3 + codegen instead of hardcoded orchestration.
10  *
11  * v6.6 Architecture (REPLACEMENT):
12  * - Kernel dispatch: version/v6.6/scripts/build_ir_v6_6.py + ckernel_codegen.c
13  * - Memory planning: version/v6.6/scripts/memory_planner_v6_6.py
14  * - Registry: version/v6.6/kernel_maps/KERNEL_REGISTRY.json
15  * - Kernel bindings: version/v6.6/kernel_maps/kernel_bindings.json
16  *
17  * Deprecated functions (NOT used in v6.6):
18  * - ck_layer_forward_rmsnorm_swiglu* -> IR Lower 3 + mega_fused_* kernels
19  * - ck_qkv_project_head_major* -> q_proj/k_proj/v_proj ops in IR
20  * - ck_attention_project_head_major* -> out_proj op in IR
21  * - ck_mlp_swiglu_forward* -> mlp_gate_up/mlp_down ops in IR
22  * - ck_gemm_nt_quant -> KERNEL_REGISTRY.json dispatch
23  * - ck_residual_add_token_major -> residual_add op in IR
24  *
25  * To remove completely:
26  * 1. Delete this header
27  * 2. Delete ckernel_orchestration.c
28  * 3. Remove from Makefile SRCS list
29  *
30  * Last used: v6.5
31  * Deprecated: v6.6 (2026-02)
32  * ===========================================================================
33  */
34 
35 #ifndef CKERNEL_ORCHESTRATION_H
36 #define CKERNEL_ORCHESTRATION_H
37 
38 #include <stddef.h>
39 #include "ckernel_dtype.h"
40 
41 #ifdef __cplusplus
42 extern "C" {
43 #endif
44 
45 typedef struct {
46  int tokens;
47  int embed_dim;
49  int num_heads;
51  int head_dim;
56  float eps;
58 
59  const float *input; /* [T x aligned_embed_dim] */
60  const float *ln1_gamma; /* [aligned_embed_dim] */
61  const float *ln2_gamma; /* [aligned_embed_dim] */
62 
63  const float *rope_cos; /* [max_seq_len x head_dim/2] */
64  const float *rope_sin; /* [max_seq_len x head_dim/2] */
65 
66  const float *wq; /* [num_heads x aligned_head_dim x aligned_embed_dim] */
67  const float *bq; /* [num_heads x aligned_head_dim] */
68  const float *wk; /* [num_kv_heads x aligned_head_dim x aligned_embed_dim] */
69  const float *bk; /* [num_kv_heads x aligned_head_dim] */
70  const float *wv; /* [num_kv_heads x aligned_head_dim x aligned_embed_dim] */
71  const float *bv; /* [num_kv_heads x aligned_head_dim] */
72 
73  const float *wo; /* [H x aligned_embed_dim x aligned_head_dim] */
74  const float *bo; /* [aligned_embed_dim] */
75 
76  const float *w1; /* [2*aligned_intermediate_dim x aligned_embed_dim] */
77  const float *b1; /* [2*aligned_intermediate_dim] */
78  const float *w2; /* [aligned_embed_dim x aligned_intermediate_dim] */
79  const float *b2; /* [aligned_embed_dim] */
80 
81  float *ln1_out; /* [T x aligned_embed_dim] */
82  float *ln1_rstd; /* [T] (optional) */
83  float *q; /* [num_heads x T x aligned_head_dim] */
84  float *k; /* [num_kv_heads x T x aligned_head_dim] */
85  float *v; /* [num_kv_heads x T x aligned_head_dim] */
86  float *scores; /* [num_heads x aligned_context_window x aligned_context_window] */
87  float *attn_out; /* [num_heads x T x aligned_head_dim] */
88  float *proj_tmp; /* [T x aligned_embed_dim] */
89  float *proj_scratch; /* [T x aligned_embed_dim], required if num_heads > 1 */
90  float *residual1; /* [T x aligned_embed_dim] */
91  float *ln2_out; /* [T x aligned_embed_dim] */
92  float *ln2_rstd; /* [T] (optional) */
93  float *fc1_out; /* [T x 2*aligned_intermediate_dim] */
94  float *swiglu_out;/* [T x aligned_intermediate_dim] */
95  float *mlp_out; /* [T x aligned_embed_dim] */
96  float *output; /* [T x aligned_embed_dim] */
98 
99 typedef struct {
100  int tokens;
105  int head_dim;
110  float eps;
112 
113  const float *input; /* [T x aligned_embed_dim] */
114  const float *ln1_gamma; /* [aligned_embed_dim] */
115  const float *ln2_gamma; /* [aligned_embed_dim] */
116  const float *ln1_out; /* [T x aligned_embed_dim] */
117  const float *ln1_rstd; /* [T] */
118  const float *ln2_out; /* [T x aligned_embed_dim] */
119  const float *ln2_rstd; /* [T] */
120 
121  const float *rope_cos; /* [max_seq_len x head_dim/2] */
122  const float *rope_sin; /* [max_seq_len x head_dim/2] */
123 
124  const float *wq; /* [num_heads x aligned_head_dim x aligned_embed_dim] */
125  const float *bq; /* [num_heads x aligned_head_dim] */
126  const float *wk; /* [num_kv_heads x aligned_head_dim x aligned_embed_dim] */
127  const float *bk; /* [num_kv_heads x aligned_head_dim] */
128  const float *wv; /* [num_kv_heads x aligned_head_dim x aligned_embed_dim] */
129  const float *bv; /* [num_kv_heads x aligned_head_dim] */
130 
131  const float *wo; /* [H x aligned_embed_dim x aligned_head_dim] */
132  const float *bo; /* [aligned_embed_dim] */
133 
134  const float *w1; /* [2*aligned_intermediate_dim x aligned_embed_dim] */
135  const float *b1; /* [2*aligned_intermediate_dim] */
136  const float *w2; /* [aligned_embed_dim x aligned_intermediate_dim] */
137  const float *b2; /* [aligned_embed_dim] */
138 
139  const float *q; /* [num_heads x T x aligned_head_dim] */
140  const float *k; /* [num_kv_heads x T x aligned_head_dim] */
141  const float *v; /* [num_kv_heads x T x aligned_head_dim] */
142  const float *scores; /* [num_heads x aligned_context_window x aligned_context_window] */
143  const float *attn_out; /* [num_heads x T x aligned_head_dim] */
144  const float *residual1; /* [T x aligned_embed_dim] */
145  const float *fc1_out; /* [T x 2*aligned_intermediate_dim] */
146  const float *swiglu_out;/* [T x aligned_intermediate_dim] */
147 
148  float *d_output; /* [T x aligned_embed_dim] */
149  float *d_input; /* [T x aligned_embed_dim] */
150  float *d_ln1_gamma; /* [aligned_embed_dim] */
151  float *d_ln2_gamma; /* [aligned_embed_dim] */
152  float *d_wq; /* [num_heads x aligned_head_dim x aligned_embed_dim] */
153  float *d_bq; /* [num_heads x aligned_head_dim] */
154  float *d_wk; /* [num_kv_heads x aligned_head_dim x aligned_embed_dim] */
155  float *d_bk; /* [num_kv_heads x aligned_head_dim] */
156  float *d_wv; /* [num_kv_heads x aligned_head_dim x aligned_embed_dim] */
157  float *d_bv; /* [num_kv_heads x aligned_head_dim] */
158  float *d_wo; /* [H x aligned_embed_dim x aligned_head_dim] */
159  float *d_bo; /* [aligned_embed_dim] */
160  float *d_w1; /* [2*aligned_intermediate_dim x aligned_embed_dim] */
161  float *d_b1; /* [2*aligned_intermediate_dim] */
162  float *d_w2; /* [aligned_embed_dim x aligned_intermediate_dim] */
163  float *d_b2; /* [aligned_embed_dim] */
164 
165  float *d_ln1_out; /* [T x aligned_embed_dim] */
166  float *d_q; /* [num_heads x T x aligned_head_dim] */
167  float *d_k; /* [num_kv_heads x T x aligned_head_dim] */
168  float *d_v; /* [num_kv_heads x T x aligned_head_dim] */
169  float *d_scores; /* [num_heads x aligned_context_window x aligned_context_window] */
170  float *d_attn_out; /* [num_heads x T x aligned_head_dim] */
171  float *d_proj_tmp; /* [T x aligned_embed_dim] */
172  float *d_residual1; /* [T x aligned_embed_dim] */
173  float *d_ln2_out; /* [T x aligned_embed_dim] */
174  float *d_fc1_out; /* [T x 2*aligned_intermediate_dim] */
175  float *d_swiglu_out; /* [T x aligned_intermediate_dim] */
176  float *d_mlp_out; /* [T x aligned_embed_dim] */
178 
/**
 * Residual addition over token-major activations.
 *
 * Adds two [tokens x aligned_embed_dim] activation buffers element-wise
 * into `out` (out = a + b).  NOTE(review): exact handling of the padded
 * tail of each row lives in ckernel_orchestration.c — assumed a plain
 * element-wise add over the full aligned row; confirm against the source.
 *
 * Deprecated in v6.6: replaced by the `residual_add` op in IR Lower 3
 * (see file header).
 *
 * @param a                  first operand,  [tokens x aligned_embed_dim]
 * @param b                  second operand, [tokens x aligned_embed_dim]
 * @param out                destination,    [tokens x aligned_embed_dim]
 * @param tokens             number of token rows (T)
 * @param aligned_embed_dim  padded embedding width (row stride in floats)
 */
void ck_residual_add_token_major(const float *a,
                                 const float *b,
                                 float *out,
                                 int tokens,
                                 int aligned_embed_dim);
184 
/**
 * Generic quantized GEMM dispatcher (NT layout).
 *
 * Computes C = A * B^T (+ bias) where A holds fp32 activations and B is a
 * quantized weight matrix whose storage format is selected by `dtype`
 * (see CKDataType in ckernel_dtype.h).  "NT" = A non-transposed, B
 * transposed, the usual activations-times-weight-rows layout.
 * NOTE(review): whether bias may be NULL is not visible from this header —
 * confirm against the implementation.
 *
 * Deprecated in v6.6: replaced by KERNEL_REGISTRY.json dispatch
 * (see file header).
 *
 * @param A     fp32 activations, [M x K]
 * @param B     quantized weights in `dtype` block format, logically [N x K]
 * @param bias  fp32 bias, [N]
 * @param C     fp32 output, [M x N]
 * @param M     rows of A / C
 * @param N     rows of B / columns of C
 * @param K     shared inner dimension
 * @param dtype quantization format of B
 */
void ck_gemm_nt_quant(const float *A,
                      const void *B,
                      const float *bias,
                      float *C,
                      int M, int N, int K,
                      CKDataType dtype);
192 
/**
 * Q/K/V projection producing head-major outputs for a batch of tokens.
 *
 * Projects token-major input activations through the query / key / value
 * weight matrices, writing per-head output tensors.  Supports grouped-query
 * attention: K/V use `num_kv_heads` which may be smaller than `num_heads`.
 *
 * Deprecated in v6.6: replaced by q_proj/k_proj/v_proj ops in IR
 * (see file header).
 *
 * @param input             token-major activations, [tokens x aligned_embed_dim]
 * @param wq                query weights, [num_heads x aligned_head_dim x aligned_embed_dim]
 * @param bq                query bias,    [num_heads x aligned_head_dim]
 * @param wk                key weights,   [num_kv_heads x aligned_head_dim x aligned_embed_dim]
 * @param bk                key bias,      [num_kv_heads x aligned_head_dim]
 * @param wv                value weights, [num_kv_heads x aligned_head_dim x aligned_embed_dim]
 * @param bv                value bias,    [num_kv_heads x aligned_head_dim]
 * @param q                 output, [num_heads x tokens x aligned_head_dim]
 * @param k                 output, [num_kv_heads x ... x aligned_head_dim]
 * @param v                 output, [num_kv_heads x ... x aligned_head_dim]
 * @param tokens            number of token rows (T)
 * @param kv_stride_tokens  per-head token stride of the k/v outputs —
 *                          presumably the KV-cache capacity when writing
 *                          into a cache; TODO confirm against the caller
 * @param aligned_embed_dim padded embedding width
 * @param num_heads         number of query heads
 * @param num_kv_heads      number of key/value heads (GQA)
 * @param aligned_head_dim  padded per-head width
 */
void ck_qkv_project_head_major(const float *input,
                               const float *wq, const float *bq,
                               const float *wk, const float *bk,
                               const float *wv, const float *bv,
                               float *q, float *k, float *v,
                               int tokens,
                               int kv_stride_tokens,
                               int aligned_embed_dim,
                               int num_heads,
                               int num_kv_heads,
                               int aligned_head_dim);
204 
/**
 * Single-token variant of ck_qkv_project_head_major (decode path).
 *
 * Projects one input row through the Q/K/V weights, writing one token's
 * worth of per-head outputs.  No token stride parameters are needed since
 * only one row is produced per head.
 *
 * Deprecated in v6.6: replaced by q_proj/k_proj/v_proj ops in IR
 * (see file header).
 *
 * @param input_row         one token's activations, [aligned_embed_dim]
 * @param wq                query weights, [num_heads x aligned_head_dim x aligned_embed_dim]
 * @param bq                query bias,    [num_heads x aligned_head_dim]
 * @param wk                key weights,   [num_kv_heads x aligned_head_dim x aligned_embed_dim]
 * @param bk                key bias,      [num_kv_heads x aligned_head_dim]
 * @param wv                value weights, [num_kv_heads x aligned_head_dim x aligned_embed_dim]
 * @param bv                value bias,    [num_kv_heads x aligned_head_dim]
 * @param q_token           output, [num_heads x aligned_head_dim]
 * @param k_token           output, [num_kv_heads x aligned_head_dim]
 * @param v_token           output, [num_kv_heads x aligned_head_dim]
 * @param aligned_embed_dim padded embedding width
 * @param num_heads         number of query heads
 * @param num_kv_heads      number of key/value heads (GQA)
 * @param aligned_head_dim  padded per-head width
 */
void ck_qkv_project_head_major_token(const float *input_row,
                                     const float *wq, const float *bq,
                                     const float *wk, const float *bk,
                                     const float *wv, const float *bv,
                                     float *q_token,
                                     float *k_token,
                                     float *v_token,
                                     int aligned_embed_dim,
                                     int num_heads,
                                     int num_kv_heads,
                                     int aligned_head_dim);
216 
/**
 * Attention output projection (Wo) from head-major input to token-major output.
 *
 * Projects per-head attention results back to the embedding dimension.
 *
 * Deprecated in v6.6: replaced by the `out_proj` op in IR (see file header).
 *
 * @param attn_out          head-major attention results,
 *                          [num_heads x tokens x aligned_head_dim]
 * @param wo                output projection weights,
 *                          [num_heads x aligned_embed_dim x aligned_head_dim]
 * @param bo                output projection bias, [aligned_embed_dim]
 * @param out               token-major output, [tokens x aligned_embed_dim]
 * @param scratch           workspace, [tokens x aligned_embed_dim] —
 *                          required if num_heads > 1 (per the
 *                          CKLayerForwardParams.proj_scratch contract)
 * @param tokens            number of token rows (T)
 * @param aligned_embed_dim padded embedding width
 * @param num_heads         number of attention heads
 * @param aligned_head_dim  padded per-head width
 */
void ck_attention_project_head_major(const float *attn_out,
                                     const float *wo,
                                     const float *bo,
                                     float *out,
                                     float *scratch,
                                     int tokens,
                                     int aligned_embed_dim,
                                     int num_heads,
                                     int aligned_head_dim);
226 
/**
 * Single-token attention output projection (decode path).
 *
 * Same computation as ck_attention_project_head_major but for one token,
 * and with no scratch buffer.  Takes both the logical `embed_dim` and the
 * padded `aligned_embed_dim` — presumably so only the logical prefix of
 * the output row is written/valid; TODO confirm against the implementation.
 *
 * Deprecated in v6.6: replaced by the `out_proj` op in IR (see file header).
 *
 * @param attn_token        one token's head-major attention result,
 *                          [num_heads x aligned_head_dim]
 * @param wo                output projection weights,
 *                          [num_heads x aligned_embed_dim x aligned_head_dim]
 * @param bo                output projection bias, [aligned_embed_dim]
 * @param out_token         output row, [aligned_embed_dim]
 * @param embed_dim         logical (unpadded) embedding width
 * @param aligned_embed_dim padded embedding width
 * @param num_heads         number of attention heads
 * @param aligned_head_dim  padded per-head width
 */
void ck_attention_project_head_major_decode_token(const float *attn_token,
                                                  const float *wo,
                                                  const float *bo,
                                                  float *out_token,
                                                  int embed_dim,
                                                  int aligned_embed_dim,
                                                  int num_heads,
                                                  int aligned_head_dim);
235 
/**
 * SwiGLU MLP forward pass over a batch of tokens.
 *
 * fc1 (gate+up combined in w1/b1) -> SwiGLU activation -> fc2 (down, w2/b2).
 * Intermediate buffers are exposed so the backward pass can reuse them.
 *
 * Deprecated in v6.6: replaced by mlp_gate_up/mlp_down ops in IR
 * (see file header).
 *
 * @param input                    token-major input, [tokens x aligned_embed_dim]
 * @param w1                       gate+up weights, [2*aligned_intermediate_dim x aligned_embed_dim]
 * @param b1                       gate+up bias,    [2*aligned_intermediate_dim]
 * @param w2                       down weights,    [aligned_embed_dim x aligned_intermediate_dim]
 * @param b2                       down bias,       [aligned_embed_dim]
 * @param fc1_out                  saved fc1 activations, [tokens x 2*aligned_intermediate_dim]
 * @param swiglu_out               saved SwiGLU activations, [tokens x aligned_intermediate_dim]
 * @param output                   result, [tokens x aligned_embed_dim]
 * @param tokens                   number of token rows (T)
 * @param aligned_embed_dim        padded embedding width
 * @param aligned_intermediate_dim padded MLP hidden width
 */
void ck_mlp_swiglu_forward(const float *input,
                           const float *w1,
                           const float *b1,
                           const float *w2,
                           const float *b2,
                           float *fc1_out,
                           float *swiglu_out,
                           float *output,
                           int tokens,
                           int aligned_embed_dim,
                           int aligned_intermediate_dim);
247 
/**
 * Single-token SwiGLU MLP with fused gate+up matvec (decode path).
 *
 * Same math as ck_mlp_swiglu_forward for one row, but the fc1 result is
 * not materialized — only the SwiGLU activations and the final output row
 * are written.  Inference-only (no fc1_out saved for backward).
 *
 * Deprecated in v6.6: replaced by mlp_gate_up/mlp_down ops in IR
 * (see file header).
 *
 * @param input_row                one token's activations, [aligned_embed_dim]
 * @param w1                       gate+up weights, [2*aligned_intermediate_dim x aligned_embed_dim]
 * @param b1                       gate+up bias,    [2*aligned_intermediate_dim]
 * @param w2                       down weights,    [aligned_embed_dim x aligned_intermediate_dim]
 * @param b2                       down bias,       [aligned_embed_dim]
 * @param swiglu_row               SwiGLU activations, [aligned_intermediate_dim]
 * @param output_row               result row, [aligned_embed_dim]
 * @param aligned_embed_dim        padded embedding width
 * @param aligned_intermediate_dim padded MLP hidden width
 */
void ck_mlp_swiglu_forward_fused_token(const float *input_row,
                                       const float *w1,
                                       const float *b1,
                                       const float *w2,
                                       const float *b2,
                                       float *swiglu_row,
                                       float *output_row,
                                       int aligned_embed_dim,
                                       int aligned_intermediate_dim);
257 
/**
 * Fully fused single-token SwiGLU MLP for decode.
 *
 * All three projections (gate, up, down) are fused into one kernel, so the
 * intermediate SwiGLU values never round-trip through DRAM — note there is
 * no swiglu_row output, unlike ck_mlp_swiglu_forward_fused_token.
 * Per the original note, this path performs best on AVX-512 systems with
 * many cores (24+).
 *
 * Deprecated in v6.6: replaced by mlp_gate_up/mlp_down ops in IR
 * (see file header).
 *
 * @param input_row                one token's activations, [aligned_embed_dim]
 * @param w1                       gate+up weights, [2*aligned_intermediate_dim x aligned_embed_dim]
 * @param b1                       gate+up bias,    [2*aligned_intermediate_dim]
 * @param w2                       down weights,    [aligned_embed_dim x aligned_intermediate_dim]
 * @param b2                       down bias,       [aligned_embed_dim]
 * @param output_row               result row, [aligned_embed_dim]
 * @param aligned_embed_dim        padded embedding width
 * @param aligned_intermediate_dim padded MLP hidden width
 */
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row,
                                             const float *w1,
                                             const float *b1,
                                             const float *w2,
                                             const float *b2,
                                             float *output_row,
                                             int aligned_embed_dim,
                                             int aligned_intermediate_dim);
270 
273 
274 // Decode-style layer forward for autoregressive generation.
275 //
276 // Computes only a single token at `token_index`, while attending over the
277 // KV-cache stored in `p->k`/`p->v` in head-major cache layout:
278 // k/v: [num_kv_heads, cache_capacity, aligned_head_dim]
279 //
280 // The caller is responsible for:
281 // - ensuring `p->k`/`p->v` already contain tokens [0..token_index-1]
282 // - setting `p->rope_pos_offset` to the absolute position for this token
283 // - passing a matching `cache_capacity` (usually model context_window)
285  int token_index,
286  int cache_capacity);
287 
288 // Decode-style layer forward using fused SwiGLU (gate+up) matvec.
289 // Inference-only fast path: produces the same outputs as the unfused decode path.
291  int token_index,
292  int cache_capacity);
293 
294 // Decode-style layer forward using fused attention (QKV+RoPE+KV+attention+Wo).
295 // Optionally pairs with fused SwiGLU via ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp.
297  int token_index,
298  int cache_capacity);
299 
301  int token_index,
302  int cache_capacity);
303 
304 /* ============================================================================
305  * Quantized (Q4_K / Q4_K_M) inference orchestration
306  *
307  * These entry points mirror the fp32 paths but accept weight matrices stored
308  * in GGML-compatible Q4_K blocks. Activations remain fp32 by default; the
309  * decode path can switch to Q8_K activations via CK_Q8K_ACTIVATIONS=1 (or
310  * auto-enable when strict parity is off).
311  *
312  * Design note:
313  * - If you enable Q4_K weights, ensure the relevant K dimensions are a
314  * multiple of 256 (QK_K). The engine keeps the quantized weights in their
315  * compact block form and dequantizes on-the-fly inside GEMM/GEMV kernels.
316  * ============================================================================ */
317 
318 typedef struct {
319  int tokens;
324  int head_dim;
329  float eps;
331 
332  const float *input; /* [T x aligned_embed_dim] */
333  const float *ln1_gamma; /* [aligned_embed_dim] */
334  const float *ln2_gamma; /* [aligned_embed_dim] */
335 
336  const float *rope_cos; /* [max_seq_len x head_dim/2] */
337  const float *rope_sin; /* [max_seq_len x head_dim/2] */
338 
339  const void *wq; /* Q4_K: [num_heads x aligned_head_dim x aligned_embed_dim] */
340  const float *bq; /* [num_heads x aligned_head_dim] */
341  const void *wk; /* Q4_K: [num_kv_heads x aligned_head_dim x aligned_embed_dim] */
342  const float *bk; /* [num_kv_heads x aligned_head_dim] */
343  const void *wv; /* Q4_K: [num_kv_heads x aligned_head_dim x aligned_embed_dim] */
344  const float *bv; /* [num_kv_heads x aligned_head_dim] */
345 
346  const void *wo; /* Q4_K: [aligned_embed_dim x (num_heads*aligned_head_dim)] */
347  const float *bo; /* [aligned_embed_dim] */
348 
349  const void *w1; /* Q4_K: [2*aligned_intermediate_dim x aligned_embed_dim] */
350  const float *b1; /* [2*aligned_intermediate_dim] */
351  const void *w2; /* Q4_K: [aligned_embed_dim x aligned_intermediate_dim] */
352  const float *b2; /* [aligned_embed_dim] */
353 
354  float *ln1_out; /* [T x aligned_embed_dim] */
355  float *ln1_rstd; /* [T] (optional) */
356  float *q; /* [num_heads x T x aligned_head_dim] */
357  float *k; /* [num_kv_heads x T x aligned_head_dim] */
358  float *v; /* [num_kv_heads x T x aligned_head_dim] */
359  float *scores; /* [num_heads x aligned_context_window x aligned_context_window] */
360  float *attn_out; /* [num_heads x T x aligned_head_dim] */
361  float *proj_tmp; /* [T x aligned_embed_dim] */
362  float *proj_scratch; /* [T x aligned_embed_dim], required (transpose buffer) */
363  float *residual1; /* [T x aligned_embed_dim] */
364  float *ln2_out; /* [T x aligned_embed_dim] */
365  float *ln2_rstd; /* [T] (optional) */
366  float *fc1_out; /* [T x 2*aligned_intermediate_dim] */
367  float *swiglu_out;/* [T x aligned_intermediate_dim] */
368  float *mlp_out; /* [T x aligned_embed_dim] */
369  float *output; /* [T x aligned_embed_dim] */
370 
378 
381  int token_index,
382  int cache_capacity);
385  int token_index,
386  int cache_capacity);
387 
/**
 * Backward pass of the residual addition.
 *
 * Since out = a + b, the upstream gradient flows unchanged into both
 * operands.  NOTE(review): whether d_a/d_b are overwritten or accumulated
 * into is not visible from this header — confirm against the source.
 *
 * Deprecated in v6.6 (see file header).
 *
 * @param d_out              upstream gradient, [tokens x aligned_embed_dim]
 * @param d_a                gradient w.r.t. first operand,  [tokens x aligned_embed_dim]
 * @param d_b                gradient w.r.t. second operand, [tokens x aligned_embed_dim]
 * @param tokens             number of token rows (T)
 * @param aligned_embed_dim  padded embedding width
 */
void ck_residual_add_backward(const float *d_out,
                              float *d_a,
                              float *d_b,
                              int tokens,
                              int aligned_embed_dim);
393 
/**
 * Backward pass of the attention output projection (Wo).
 *
 * Given the upstream gradient of the projected output, computes gradients
 * w.r.t. the head-major attention input, the projection weights, and the
 * projection bias.
 *
 * Deprecated in v6.6 (see file header).
 *
 * @param d_out             upstream gradient, [tokens x aligned_embed_dim]
 * @param attn_out          saved forward input (head-major attention results),
 *                          [num_heads x tokens x aligned_head_dim]
 * @param wo                projection weights,
 *                          [num_heads x aligned_embed_dim x aligned_head_dim]
 * @param d_attn_out        gradient w.r.t. attn_out, [num_heads x tokens x aligned_head_dim]
 * @param d_wo              gradient w.r.t. wo, [num_heads x aligned_embed_dim x aligned_head_dim]
 * @param d_bo              gradient w.r.t. bo, [aligned_embed_dim]
 * @param tokens            number of token rows (T)
 * @param aligned_embed_dim padded embedding width
 * @param num_heads         number of attention heads
 * @param aligned_head_dim  padded per-head width
 */
void ck_attention_project_head_major_backward(const float *d_out,
                                              const float *attn_out,
                                              const float *wo,
                                              float *d_attn_out,
                                              float *d_wo,
                                              float *d_bo,
                                              int tokens,
                                              int aligned_embed_dim,
                                              int num_heads,
                                              int aligned_head_dim);
404 
/**
 * Backward pass of the Q/K/V projection (grouped-query aware).
 *
 * Given the upstream gradients of the head-major q/k/v outputs and the
 * saved forward input, computes the gradient w.r.t. the token-major input
 * plus the gradients of all six weight/bias tensors.
 *
 * Deprecated in v6.6 (see file header).
 *
 * @param d_q               upstream gradient, [num_heads x tokens x aligned_head_dim]
 * @param d_k               upstream gradient, [num_kv_heads x tokens x aligned_head_dim]
 * @param d_v               upstream gradient, [num_kv_heads x tokens x aligned_head_dim]
 * @param input             saved forward input, [tokens x aligned_embed_dim]
 * @param wq,bq             query weights/bias (forward values)
 * @param wk,bk             key weights/bias (forward values)
 * @param wv,bv             value weights/bias (forward values)
 * @param d_input           gradient w.r.t. input, [tokens x aligned_embed_dim]
 * @param d_wq,d_bq         gradients w.r.t. query weights/bias
 * @param d_wk,d_bk         gradients w.r.t. key weights/bias
 * @param d_wv,d_bv         gradients w.r.t. value weights/bias
 * @param scratch           caller-provided workspace; required size is not
 *                          visible from this header — see the implementation
 * @param tokens            number of token rows (T)
 * @param aligned_embed_dim padded embedding width
 * @param num_heads         number of query heads
 * @param num_kv_heads      number of key/value heads (GQA)
 * @param aligned_head_dim  padded per-head width
 * @param num_threads       worker thread count for the parallelized backward
 */
void ck_qkv_project_head_major_backward(const float *d_q,
                                        const float *d_k,
                                        const float *d_v,
                                        const float *input,
                                        const float *wq,
                                        const float *bq,
                                        const float *wk,
                                        const float *bk,
                                        const float *wv,
                                        const float *bv,
                                        float *d_input,
                                        float *d_wq,
                                        float *d_bq,
                                        float *d_wk,
                                        float *d_bk,
                                        float *d_wv,
                                        float *d_bv,
                                        float *scratch,
                                        int tokens,
                                        int aligned_embed_dim,
                                        int num_heads,
                                        int num_kv_heads,
                                        int aligned_head_dim,
                                        int num_threads);
429 
431 
432 #ifdef __cplusplus
433 } // extern "C"
434 #endif
435 
436 #endif /* CKERNEL_ORCHESTRATION_H */
CKDataType
Supported data types in C-Kernel-Engine.
Definition: ckernel_dtype.h:27
void ck_layer_backward_rmsnorm_swiglu(const CKLayerBackwardParams *p)
void ck_mlp_swiglu_forward_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *swiglu_row, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_attention_project_head_major(const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_mlp_swiglu_forward(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_residual_add_backward(const float *d_out, float *d_a, float *d_b, int tokens, int aligned_embed_dim)
void ck_layer_forward_rmsnorm_swiglu_quant(const CKLayerForwardParamsQ4K *p)
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_layer_forward_rmsnorm_swiglu_decode(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_layer_forward_rmsnorm_swiglu_q4_k(const CKLayerForwardParamsQ4K *p)
void ck_qkv_project_head_major(const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_gemm_nt_quant(const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
void ck_qkv_project_head_major_backward(const float *d_q, const float *d_k, const float *d_v, const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *d_input, float *d_wq, float *d_bq, float *d_wk, float *d_bk, float *d_wv, float *d_bv, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
void ck_qkv_project_head_major_token(const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_layer_forward_rmsnorm_swiglu(const CKLayerForwardParams *p)
void ck_layer_forward_rmsnorm_swiglu_decode_fused(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_layer_forward_rmsnorm_swiglu_decode_q4_k(const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
void ck_layer_forward_rmsnorm_swiglu_ref(const CKLayerForwardParams *p)
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_attention_project_head_major_backward(const float *d_out, const float *attn_out, const float *wo, float *d_attn_out, float *d_wo, float *d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_attention_project_head_major_decode_token(const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_layer_forward_rmsnorm_swiglu_decode_quant(const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
#define C(color)
Definition: show_config.c:39