← Back to C-Kernel-Engine Docs Doxygen Source Documentation
parallel_orchestration.c
Go to the documentation of this file.
1 /**
2  * @file parallel_orchestration.c
3  * @brief [LEGACY] Parallel decode orchestration prototype — NOT USED by v6.6
4  *
5  * This file was an early prototype demonstrating llama.cpp-style OpenMP
6  * parallelization patterns. It is NOT compiled into the v6.6 build and
7  * has no callers in the generated inference code path.
8  *
9  * v6.6 decode runs entirely through the generated code in:
10  * version/v6.6/src/generated/ck-kernel-inference.c
11  * → ck_model_decode_internal()
12  *
13  * Threading for v6.6 is handled by ck_threadpool (include/ck_threadpool.h),
14  * which replaces the OpenMP approach used here.
15  *
16  * Kept for reference only. See the original design notes below.
17  *
18  * Original design (OpenMP, superseded):
19  * - OpenMP parallel region at orchestration level
20  * - Each kernel receives (ith, nth) and processes its slice
21  * - Barriers between dependent operations
22  * - Key insight: amortize thread pool overhead over entire forward pass
23  */
24 
25 #include <omp.h>
26 #include <string.h>
27 #include <math.h>
28 
29 #include "ckernel_engine.h"
30 #include "ckernel_quant.h"
31 
32 /* ============================================================================
33  * PARALLEL KERNEL WRAPPERS
34  *
35  * These call the _parallel_simd versions with thread indices from OpenMP.
36  * Each wrapper receives (ith, nth) from the calling parallel region.
37  * ============================================================================ */
38 
39 /**
40  * Single-token decode with parallel SIMD kernels.
41  *
42  * This is the main decode function that processes one token through all layers.
43  * OpenMP parallel region is created ONCE at the top, and all kernels
44  * receive (ith, nth) to split their work.
45  *
46  * Pattern:
47  * #pragma omp parallel
48  * {
49  * int ith = omp_get_thread_num();
50  * int nth = omp_get_num_threads();
51  *
52  * // Each kernel processes only its slice
53  * gemv_q4_k_q8_k_parallel_simd(..., ith, nth);
54  * #pragma omp barrier
55  *
56  * rmsnorm_parallel(..., ith, nth); // (not implemented yet)
57  * #pragma omp barrier
58  * ...
59  * }
60  */
61 
/* Parallel residual add: out = a + b.
 * Work is split into contiguous chunks of ceil(n/nth) elements; thread ith
 * owns chunk ith. Safe to call with out aliasing a or b element-wise. */
static void residual_add_parallel(const float *a, const float *b,
 float *out, int n,
 int ith, int nth)
{
 const int chunk = (n + nth - 1) / nth; /* elements per thread (ceil) */
 const int begin = ith * chunk;

 /* Trailing threads may have no work when n < nth * chunk. */
 if (begin >= n) {
 return;
 }

 int end = begin + chunk;
 if (end > n) {
 end = n;
 }

 for (int i = begin; i != end; ++i) {
 out[i] = a[i] + b[i];
 }
}
77 
/* Parallel scale: y[i] *= scale.
 * Same contiguous chunking as the other element-wise helpers: thread ith
 * processes its ceil(n/nth)-sized slice. */
static void vec_scale_parallel(float *y, float scale, int n,
 int ith, int nth)
{
 const int chunk = (n + nth - 1) / nth;
 const int begin = ith * chunk;

 if (begin >= n) {
 return; /* no slice assigned to this thread */
 }

 int end = begin + chunk;
 if (end > n) {
 end = n;
 }

 for (int i = begin; i != end; ++i) {
 y[i] *= scale;
 }
}
92 
/* Parallel zero: clear y[0..n) to 0.0f.
 * Each thread memsets its own contiguous ceil(n/nth)-sized chunk, so no
 * two threads touch the same bytes. */
static void vec_zero_parallel(float *y, int n, int ith, int nth)
{
 const int chunk = (n + nth - 1) / nth;
 const int begin = ith * chunk;

 if (begin >= n) {
 return; /* this thread has nothing to clear */
 }

 const int end = (begin + chunk > n) ? n : (begin + chunk);
 memset(y + begin, 0, (size_t)(end - begin) * sizeof(float));
}
104 
105 /* ============================================================================
106  * EXAMPLE: Parallel Q/K/V Projection
107  *
108  * This shows how to parallelize the QKV projections in one OpenMP region.
109  * In practice, this would be integrated into the full decode function.
110  * ============================================================================ */
111 
/**
 * Parallel Q/K/V projection for single-token decode.
 *
 * All three GEMVs run inside one OpenMP parallel region; their outputs are
 * disjoint and they all read the same quantized input, so no barriers are
 * needed between them. The implicit barrier at the region's end is the only
 * synchronization point.
 *
 * @param ln1_q8      Input: RMSNorm output quantized to Q8_K [aligned_embed]
 * @param WQ          Q weights [H*head_dim, aligned_embed] in Q4_K
 * @param WK          K weights [H_kv*head_dim, aligned_embed] in Q4_K
 * @param WV          V weights [H_kv*head_dim, aligned_embed] in Q4_K
 * @param q_out       Output: Q vectors [H, head_dim]
 * @param k_out       Output: K vector [H_kv, head_dim]
 * @param v_out       Output: V vector [H_kv, head_dim]
 * @param H           Number of query heads
 * @param H_kv        Number of KV heads (GQA)
 * @param head_dim    Head dimension
 * @param embed_dim   Embedding dimension
 * @param num_threads Number of threads to use (0 = auto)
 */
void qkv_projection_parallel(const void *ln1_q8,
 const void *WQ,
 const void *WK,
 const void *WV,
 float *q_out,
 float *k_out,
 float *v_out,
 int H, int H_kv,
 int head_dim, int embed_dim,
 int num_threads)
{
 const int rows_q = H * head_dim;
 const int rows_kv = H_kv * head_dim;

 /* Round the reduction dimension up to a multiple of QK_K (256). */
 const int padded_embed = 256 * ((embed_dim + 255) / 256);

 if (num_threads <= 0) {
 num_threads = omp_get_max_threads();
 }

 /* One region for all three projections amortizes fork/join cost. */
 #pragma omp parallel num_threads(num_threads)
 {
 const int tid = omp_get_thread_num();
 const int nthreads = omp_get_num_threads();

 /* Q projection: the largest matvec, gains the most from splitting. */
 gemv_q4_k_q8_k_parallel_simd(q_out, WQ, ln1_q8, rows_q, padded_embed, tid, nthreads);

 /* K and V projections: smaller (GQA), still worth parallelizing. */
 gemv_q4_k_q8_k_parallel_simd(k_out, WK, ln1_q8, rows_kv, padded_embed, tid, nthreads);
 gemv_q4_k_q8_k_parallel_simd(v_out, WV, ln1_q8, rows_kv, padded_embed, tid, nthreads);
 } /* implicit barrier at end of parallel region */
}
167 
/**
 * Parallel MLP (gate/up + SwiGLU + down projection).
 *
 * One OpenMP region covers the whole block:
 *   gate/up GEMVs -> barrier -> element-wise SwiGLU -> barrier ->
 *   Q8_K quantize (omp single) -> down GEMV.
 *
 * @param ln2_q8       Input: RMSNorm output quantized to Q8_K
 * @param W_gate       Gate weights [intermediate, embed] in Q4_K
 * @param W_up         Up weights [intermediate, embed] in Q4_K
 * @param W_down       Down weights [embed, intermediate] in Q4_K
 * @param gate_buf     Scratch: gate output [intermediate]
 * @param up_buf       Scratch: up output [intermediate]
 * @param swiglu_buf   Scratch: SwiGLU output [intermediate]
 * @param down_q8      Scratch: down input quantized [intermediate Q8_K blocks]
 * @param mlp_out      Output: MLP output [embed]
 * @param intermediate Intermediate dimension
 * @param embed_dim    Embedding dimension
 * @param num_threads  Number of threads (0 = auto)
 */
void mlp_parallel(const void *ln2_q8,
 const void *W_gate,
 const void *W_up,
 const void *W_down,
 float *gate_buf,
 float *up_buf,
 float *swiglu_buf,
 void *down_q8,
 float *mlp_out,
 int intermediate,
 int embed_dim,
 int num_threads)
{
 /* Round both dimensions up to a multiple of QK_K (256) for the
  * quantized kernels. */
 const int aligned_embed = ((embed_dim + 255) / 256) * 256;
 const int aligned_inter = ((intermediate + 255) / 256) * 256;

 if (num_threads <= 0) {
 num_threads = omp_get_max_threads();
 }

 #pragma omp parallel num_threads(num_threads)
 {
 const int ith = omp_get_thread_num();
 const int nth = omp_get_num_threads();

 /* Gate and Up projections (independent, no barrier between them). */
 gemv_q4_k_q8_k_parallel_simd(gate_buf, W_gate, ln2_q8, aligned_inter, aligned_embed, ith, nth);
 gemv_q4_k_q8_k_parallel_simd(up_buf, W_up, ln2_q8, aligned_inter, aligned_embed, ith, nth);

 #pragma omp barrier /* Wait for gate and up to complete */

 /* SwiGLU: swiglu = silu(gate) * up, element-wise over this thread's
  * slice of [0, aligned_inter).
  *
  * BUGFIX: the alignment-padding tail [intermediate, aligned_inter) is
  * now explicitly zeroed. Previously the loop stopped at `intermediate`,
  * leaving the tail of swiglu_buf uninitialized even though the Q8_K
  * quantization below reads all aligned_inter elements. */
 const int dr = (aligned_inter + nth - 1) / nth;
 const int r0 = dr * ith;
 const int r1 = (r0 + dr < aligned_inter) ? (r0 + dr) : aligned_inter;

 for (int i = r0; i < r1; i++) {
 if (i < intermediate) {
 float g = gate_buf[i];
 float silu_g = g / (1.0f + expf(-g)); /* SiLU activation */
 swiglu_buf[i] = silu_g * up_buf[i];
 } else {
 swiglu_buf[i] = 0.0f; /* zero the padding tail */
 }
 }

 #pragma omp barrier /* Wait for SwiGLU to complete */

 /* Quantize SwiGLU output for the down projection.
  * Only thread 0 quantizes (single-threaded for now).
  * TODO: Add parallel quantization. */
 #pragma omp single
 {
 quantize_row_q8_k(swiglu_buf, down_q8, aligned_inter);
 }
 /* Implicit barrier after omp single */

 /* Down projection */
 gemv_q4_k_q8_k_parallel_simd(mlp_out, W_down, down_q8, aligned_embed, aligned_inter, ith, nth);
 }
}
241 
242 /* ============================================================================
243  * FULL LAYER DECODE (parallel)
244  *
245  * Processes one transformer layer with all operations parallelized.
246  * ============================================================================ */
247 
/**
 * Process one transformer layer in parallel.
 *
 * Demonstrates the full parallel pattern for a single layer: one OpenMP
 * region spans the entire layer; dependent steps are separated by explicit
 * barriers, and the still-serial parts (RMSNorm, Q8_K quantization, the
 * KV-cache copy, and the placeholder attention) run under `omp single`.
 * In production, this would be called in a loop for all layers.
 *
 * NOTE(review): the signature line below was reconstructed from the file's
 * generated declaration index; confirm it matches the header declaration.
 */
void decode_layer_parallel(
 /* Inputs */
 float *hidden, /* [embed_dim] - modified in place */
 const void *ln1_weight, /* RMSNorm weights */
 const void *ln2_weight, /* RMSNorm weights */
 const void *WQ, /* Q4_K weights */
 const void *WK,
 const void *WV,
 const void *WO,
 const void *W_gate,
 const void *W_up,
 const void *W_down,
 /* KV cache */
 float *k_cache, /* [H_kv, max_seq, head_dim] */
 float *v_cache,
 int token_index, /* Current position in sequence */
 /* Scratch buffers */
 float *scratch, /* Aligned scratch space */
 /* Model config */
 int embed_dim,
 int intermediate,
 int H, int H_kv,
 int head_dim,
 int max_seq,
 float eps, /* RMSNorm epsilon */
 /* Threading */
 int num_threads)
{
 if (num_threads <= 0) {
 num_threads = omp_get_max_threads();
 }

 /* Align dimensions: embed/intermediate to QK_K (256) for the quantized
  * GEMVs, head_dim to 32. */
 const int aligned_embed = ((embed_dim + 255) / 256) * 256;
 const int aligned_inter = ((intermediate + 255) / 256) * 256;
 const int aligned_head = ((head_dim + 31) / 32) * 32;

 /* Partition the caller-provided scratch region into named sub-buffers.
  * Layout (floats): ln1_out | q | k | v | attn | o | ln2_out | gate |
  * up | swiglu | mlp_out, followed by the Q8_K byte buffers below. */
 float *ln1_out = scratch;
 float *q_vec = ln1_out + aligned_embed;
 float *k_vec = q_vec + H * aligned_head;
 float *v_vec = k_vec + H_kv * aligned_head;
 float *attn_out = v_vec + H_kv * aligned_head;
 float *o_out = attn_out + H * aligned_head;
 float *ln2_out = o_out + aligned_embed;
 float *gate_buf = ln2_out + aligned_embed;
 float *up_buf = gate_buf + aligned_inter;
 float *swiglu_buf = up_buf + aligned_inter;
 float *mlp_out = swiglu_buf + aligned_inter;

 /* Q8_K buffers for quantized input.
  * 292 bytes per 256-value block — presumably sizeof(block_q8_K);
  * TODO confirm against ckernel_quant.h. */
 const size_t q8_embed_bytes = ((aligned_embed + 255) / 256) * 292;
 const size_t q8_inter_bytes = ((aligned_inter + 255) / 256) * 292;
 uint8_t *ln1_q8 = (uint8_t *)(mlp_out + aligned_embed);
 uint8_t *ln2_q8 = ln1_q8 + q8_embed_bytes;
 uint8_t *down_q8 = ln2_q8 + q8_embed_bytes;

 #pragma omp parallel num_threads(num_threads)
 {
 const int ith = omp_get_thread_num();
 const int nth = omp_get_num_threads();

 /* ============================================================
  * ATTENTION BLOCK
  * ============================================================ */

 /* Step 1: RMSNorm + quantize (serial; TODO: parallelize reduction) */
 #pragma omp single
 {
 rmsnorm(hidden, (const float *)ln1_weight, ln1_out, embed_dim, eps);
 quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed);
 }
 /* Implicit barrier after single */

 /* Step 2: QKV projections (parallel; disjoint outputs, no barriers
  * needed between the three GEMVs) */
 gemv_q4_k_q8_k_parallel_simd(q_vec, WQ, ln1_q8, H * head_dim, aligned_embed, ith, nth);
 gemv_q4_k_q8_k_parallel_simd(k_vec, WK, ln1_q8, H_kv * head_dim, aligned_embed, ith, nth);
 gemv_q4_k_q8_k_parallel_simd(v_vec, WV, ln1_q8, H_kv * head_dim, aligned_embed, ith, nth);

 #pragma omp barrier

 /* Step 3: RoPE, attention, etc. would go here (single-threaded for now) */
 #pragma omp single
 {
 /* Copy K/V to cache. Cache rows are strided by aligned_head per
  * position; k_vec/v_vec were written densely with head_dim per head. */
 const int kv_head_stride = max_seq * aligned_head;
 for (int h = 0; h < H_kv; h++) {
 memcpy(k_cache + h * kv_head_stride + token_index * aligned_head,
 k_vec + h * head_dim, head_dim * sizeof(float));
 memcpy(v_cache + h * kv_head_stride + token_index * aligned_head,
 v_vec + h * head_dim, head_dim * sizeof(float));
 }

 /* TODO: RoPE, attention decode, etc. */
 /* For now, just copy q_vec to attn_out as placeholder */
 memcpy(attn_out, q_vec, H * head_dim * sizeof(float));
 }

 #pragma omp barrier

 /* Step 4: Output projection (parallel) */
 /* First quantize attention output.
  * NOTE(review): quantizes H * aligned_head values but the GEMV below
  * uses K = H * head_dim — these disagree unless head_dim is already a
  * multiple of 32 (aligned_head == head_dim). Also, when H * aligned_head
  * exceeds aligned_embed the reused ln1_q8 buffer is too small. Confirm
  * both assumptions hold for the target models. */
 #pragma omp single
 {
 quantize_row_q8_k(attn_out, ln1_q8, H * aligned_head); /* Reuse ln1_q8 */
 }

 gemv_q4_k_q8_k_parallel_simd(o_out, WO, ln1_q8, aligned_embed, H * head_dim, ith, nth);

 #pragma omp barrier

 /* Step 5: Residual add (parallel; hidden += o_out in place) */
 residual_add_parallel(hidden, o_out, hidden, embed_dim, ith, nth);

 #pragma omp barrier

 /* ============================================================
  * MLP BLOCK
  * ============================================================ */

 /* Step 6: RMSNorm + quantize (serial) */
 #pragma omp single
 {
 rmsnorm(hidden, (const float *)ln2_weight, ln2_out, embed_dim, eps);
 quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed);
 }

 #pragma omp barrier

 /* Step 7: Gate and Up projections (parallel, independent) */
 gemv_q4_k_q8_k_parallel_simd(gate_buf, W_gate, ln2_q8, aligned_inter, aligned_embed, ith, nth);
 gemv_q4_k_q8_k_parallel_simd(up_buf, W_up, ln2_q8, aligned_inter, aligned_embed, ith, nth);

 #pragma omp barrier

 /* Step 8: SwiGLU (parallel element-wise): swiglu = silu(gate) * up.
  * NOTE(review): only [0, intermediate) is written here, but Step 9
  * quantizes aligned_inter elements — the padding tail of swiglu_buf
  * may be read uninitialized unless the scratch buffer is pre-zeroed
  * by the caller. Verify, or zero the tail here. */
 const int dr = (intermediate + nth - 1) / nth;
 const int r0 = dr * ith;
 const int r1 = (r0 + dr < intermediate) ? (r0 + dr) : intermediate;

 for (int i = r0; i < r1; i++) {
 float g = gate_buf[i];
 float silu_g = g / (1.0f + expf(-g));
 swiglu_buf[i] = silu_g * up_buf[i];
 }

 #pragma omp barrier

 /* Step 9: Down projection (quantize serially, then parallel GEMV) */
 #pragma omp single
 {
 quantize_row_q8_k(swiglu_buf, down_q8, aligned_inter);
 }

 gemv_q4_k_q8_k_parallel_simd(mlp_out, W_down, down_q8, aligned_embed, aligned_inter, ith, nth);

 #pragma omp barrier

 /* Step 10: Final residual add (hidden += mlp_out in place) */
 residual_add_parallel(hidden, mlp_out, hidden, embed_dim, ith, nth);
 }
}
416 
417 /* ============================================================================
418  * CONFIGURATION
419  * ============================================================================ */
420 
/* Optimal thread count for decode on memory-bound systems.
 * Decode is dominated by weight streaming: beyond ~4 threads the extra
 * workers just contend for memory bandwidth with diminishing returns.
 * See MEMORY_BANDWIDTH_ANALYSIS.md for details.
 * (Signature reconstructed from the file's declaration index.) */
int get_optimal_decode_threads(void)
{
 const int avail = omp_get_max_threads();

 /* Cap at 4 when that many hardware threads are available. */
 return (avail >= 4) ? 4 : avail;
}
void quantize_row_q8_k(const float *x, void *y, int k)
void gemv_q4_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
Quantization block structures for weight-only quantization.
static void vec_zero_parallel(float *y, int n, int ith, int nth)
static void vec_scale_parallel(float *y, float scale, int n, int ith, int nth)
int get_optimal_decode_threads(void)
void qkv_projection_parallel(const void *ln1_q8, const void *WQ, const void *WK, const void *WV, float *q_out, float *k_out, float *v_out, int H, int H_kv, int head_dim, int embed_dim, int num_threads)
static void residual_add_parallel(const float *a, const float *b, float *out, int n, int ith, int nth)
void mlp_parallel(const void *ln2_q8, const void *W_gate, const void *W_up, const void *W_down, float *gate_buf, float *up_buf, float *swiglu_buf, void *down_q8, float *mlp_out, int intermediate, int embed_dim, int num_threads)
void decode_layer_parallel(float *hidden, const void *ln1_weight, const void *ln2_weight, const void *WQ, const void *WK, const void *WV, const void *WO, const void *W_gate, const void *W_up, const void *W_down, float *k_cache, float *v_cache, int token_index, float *scratch, int embed_dim, int intermediate, int H, int H_kv, int head_dim, int max_seq, float eps, int num_threads)