← Back to C-Kernel-Engine Docs Doxygen Source Documentation
parallel_orchestration.c File Reference

[LEGACY] Parallel decode orchestration prototype — NOT USED by v6.6 More...

#include <omp.h>
#include <string.h>
#include <math.h>
#include "ckernel_engine.h"
#include "ckernel_quant.h"

Go to the source code of this file.

Functions

void decode_layer_parallel (float *hidden, const void *ln1_weight, const void *ln2_weight, const void *WQ, const void *WK, const void *WV, const void *WO, const void *W_gate, const void *W_up, const void *W_down, float *k_cache, float *v_cache, int token_index, float *scratch, int embed_dim, int intermediate, int H, int H_kv, int head_dim, int max_seq, float eps, int num_threads)
 
int get_optimal_decode_threads (void)
 
void mlp_parallel (const void *ln2_q8, const void *W_gate, const void *W_up, const void *W_down, float *gate_buf, float *up_buf, float *swiglu_buf, void *down_q8, float *mlp_out, int intermediate, int embed_dim, int num_threads)
 
void qkv_projection_parallel (const void *ln1_q8, const void *WQ, const void *WK, const void *WV, float *q_out, float *k_out, float *v_out, int H, int H_kv, int head_dim, int embed_dim, int num_threads)
 
static void residual_add_parallel (const float *a, const float *b, float *out, int n, int ith, int nth)
 
static void vec_scale_parallel (float *y, float scale, int n, int ith, int nth)
 
static void vec_zero_parallel (float *y, int n, int ith, int nth)
 

Detailed Description

[LEGACY] Parallel decode orchestration prototype — NOT USED by v6.6

This file was an early prototype demonstrating llama.cpp-style OpenMP parallelization patterns. It is NOT compiled into the v6.6 build and has no callers in the generated inference code path.

v6.6 decode runs entirely through the generated code in: version/v6.6/src/generated/ck-kernel-inference.c → ck_model_decode_internal()

Threading for v6.6 is handled by ck_threadpool (include/ck_threadpool.h), which replaces the OpenMP approach used here.

Kept for reference only. See the original design notes below.

Original design (OpenMP, superseded):

  • OpenMP parallel region at orchestration level
  • Each kernel receives (ith, nth) and processes its slice
  • Barriers between dependent operations
  • Key insight: amortize thread pool overhead over entire forward pass

Definition in file parallel_orchestration.c.

Function Documentation

◆ decode_layer_parallel()

void decode_layer_parallel ( float *  hidden,
const void *  ln1_weight,
const void *  ln2_weight,
const void *  WQ,
const void *  WK,
const void *  WV,
const void *  WO,
const void *  W_gate,
const void *  W_up,
const void *  W_down,
float *  k_cache,
float *  v_cache,
int  token_index,
float *  scratch,
int  embed_dim,
int  intermediate,
int  H,
int  H_kv,
int  head_dim,
int  max_seq,
float  eps,
int  num_threads 
)

Process one transformer layer in parallel.

This demonstrates the full parallel pattern for a single layer. In production, this would be called in a loop for all layers.

Definition at line 254 of file parallel_orchestration.c.

281 {
282  if (num_threads <= 0) {
283  num_threads = omp_get_max_threads();
284  }
285 
286  /* Align dimensions */
287  const int aligned_embed = ((embed_dim + 255) / 256) * 256;
288  const int aligned_inter = ((intermediate + 255) / 256) * 256;
289  const int aligned_head = ((head_dim + 31) / 32) * 32;
290 
291  /* Partition scratch buffer */
292  float *ln1_out = scratch;
293  float *q_vec = ln1_out + aligned_embed;
294  float *k_vec = q_vec + H * aligned_head;
295  float *v_vec = k_vec + H_kv * aligned_head;
296  float *attn_out = v_vec + H_kv * aligned_head;
297  float *o_out = attn_out + H * aligned_head;
298  float *ln2_out = o_out + aligned_embed;
299  float *gate_buf = ln2_out + aligned_embed;
300  float *up_buf = gate_buf + aligned_inter;
301  float *swiglu_buf = up_buf + aligned_inter;
302  float *mlp_out = swiglu_buf + aligned_inter;
303 
304  /* Q8_K buffers for quantized input */
305  const size_t q8_embed_bytes = ((aligned_embed + 255) / 256) * 292;
306  const size_t q8_inter_bytes = ((aligned_inter + 255) / 256) * 292;
307  uint8_t *ln1_q8 = (uint8_t *)(mlp_out + aligned_embed);
308  uint8_t *ln2_q8 = ln1_q8 + q8_embed_bytes;
309  uint8_t *down_q8 = ln2_q8 + q8_embed_bytes;
310 
311  #pragma omp parallel num_threads(num_threads)
312  {
313  const int ith = omp_get_thread_num();
314  const int nth = omp_get_num_threads();
315 
316  /* ============================================================
317  * ATTENTION BLOCK
318  * ============================================================ */
319 
320  /* Step 1: RMSNorm (TODO: parallelize reduction) */
321  #pragma omp single
322  {
323  rmsnorm(hidden, (const float *)ln1_weight, ln1_out, embed_dim, eps);
324  quantize_row_q8_k(ln1_out, ln1_q8, aligned_embed);
325  }
326  /* Implicit barrier after single */
327 
328  /* Step 2: QKV projections (parallel) */
329  gemv_q4_k_q8_k_parallel_simd(q_vec, WQ, ln1_q8, H * head_dim, aligned_embed, ith, nth);
330  gemv_q4_k_q8_k_parallel_simd(k_vec, WK, ln1_q8, H_kv * head_dim, aligned_embed, ith, nth);
331  gemv_q4_k_q8_k_parallel_simd(v_vec, WV, ln1_q8, H_kv * head_dim, aligned_embed, ith, nth);
332 
333  #pragma omp barrier
334 
335  /* Step 3: RoPE, attention, etc. would go here (single-threaded for now) */
336  #pragma omp single
337  {
338  /* Copy K/V to cache */
339  const int kv_head_stride = max_seq * aligned_head;
340  for (int h = 0; h < H_kv; h++) {
341  memcpy(k_cache + h * kv_head_stride + token_index * aligned_head,
342  k_vec + h * head_dim, head_dim * sizeof(float));
343  memcpy(v_cache + h * kv_head_stride + token_index * aligned_head,
344  v_vec + h * head_dim, head_dim * sizeof(float));
345  }
346 
347  /* TODO: RoPE, attention decode, etc. */
348  /* For now, just copy q_vec to attn_out as placeholder */
349  memcpy(attn_out, q_vec, H * head_dim * sizeof(float));
350  }
351 
352  #pragma omp barrier
353 
354  /* Step 4: Output projection (parallel) */
355  /* First quantize attention output */
356  #pragma omp single
357  {
358  quantize_row_q8_k(attn_out, ln1_q8, H * aligned_head); /* Reuse ln1_q8 */
359  }
360 
361  gemv_q4_k_q8_k_parallel_simd(o_out, WO, ln1_q8, aligned_embed, H * head_dim, ith, nth);
362 
363  #pragma omp barrier
364 
365  /* Step 5: Residual add (parallel) */
366  residual_add_parallel(hidden, o_out, hidden, embed_dim, ith, nth);
367 
368  #pragma omp barrier
369 
370  /* ============================================================
371  * MLP BLOCK
372  * ============================================================ */
373 
374  /* Step 6: RMSNorm */
375  #pragma omp single
376  {
377  rmsnorm(hidden, (const float *)ln2_weight, ln2_out, embed_dim, eps);
378  quantize_row_q8_k(ln2_out, ln2_q8, aligned_embed);
379  }
380 
381  #pragma omp barrier
382 
383  /* Step 7: Gate and Up projections (parallel) */
384  gemv_q4_k_q8_k_parallel_simd(gate_buf, W_gate, ln2_q8, aligned_inter, aligned_embed, ith, nth);
385  gemv_q4_k_q8_k_parallel_simd(up_buf, W_up, ln2_q8, aligned_inter, aligned_embed, ith, nth);
386 
387  #pragma omp barrier
388 
389  /* Step 8: SwiGLU (parallel element-wise) */
390  const int dr = (intermediate + nth - 1) / nth;
391  const int r0 = dr * ith;
392  const int r1 = (r0 + dr < intermediate) ? (r0 + dr) : intermediate;
393 
394  for (int i = r0; i < r1; i++) {
395  float g = gate_buf[i];
396  float silu_g = g / (1.0f + expf(-g));
397  swiglu_buf[i] = silu_g * up_buf[i];
398  }
399 
400  #pragma omp barrier
401 
402  /* Step 9: Down projection */
403  #pragma omp single
404  {
405  quantize_row_q8_k(swiglu_buf, down_q8, aligned_inter);
406  }
407 
408  gemv_q4_k_q8_k_parallel_simd(mlp_out, W_down, down_q8, aligned_embed, aligned_inter, ith, nth);
409 
410  #pragma omp barrier
411 
412  /* Step 10: Final residual add */
413  residual_add_parallel(hidden, mlp_out, hidden, embed_dim, ith, nth);
414  }
415 }
void quantize_row_q8_k(const float *x, void *y, int k)
void gemv_q4_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
static void residual_add_parallel(const float *a, const float *b, float *out, int n, int ith, int nth)

References gemv_q4_k_q8_k_parallel_simd(), quantize_row_q8_k(), and residual_add_parallel().

◆ get_optimal_decode_threads()

int get_optimal_decode_threads ( void  )

Definition at line 422 of file parallel_orchestration.c.

int get_optimal_decode_threads(void)
{
    /* Decode is memory-bandwidth bound: past ~4 threads the extra cores
     * mostly contend for DRAM with diminishing returns.
     * See MEMORY_BANDWIDTH_ANALYSIS.md for details. */
    const int available = omp_get_max_threads();
    return (available >= 4) ? 4 : available;
}

◆ mlp_parallel()

void mlp_parallel ( const void *  ln2_q8,
const void *  W_gate,
const void *  W_up,
const void *  W_down,
float *  gate_buf,
float *  up_buf,
float *  swiglu_buf,
void *  down_q8,
float *  mlp_out,
int  intermediate,
int  embed_dim,
int  num_threads 
)

Parallel MLP (gate/up + SwiGLU + down projection).

Parameters
ln2_q8Input: RMSNorm output quantized to Q8_K
W_gateGate weights [intermediate, embed] in Q4_K
W_upUp weights [intermediate, embed] in Q4_K
W_downDown weights [embed, intermediate] in Q4_K
gate_bufScratch: gate output [intermediate]
up_bufScratch: up output [intermediate]
swiglu_bufScratch: SwiGLU output [intermediate]
down_q8Scratch: down input quantized [intermediate Q8_K blocks]
mlp_outOutput: MLP output [embed]
intermediateIntermediate dimension
embed_dimEmbedding dimension
num_threadsNumber of threads (0 = auto)

Definition at line 184 of file parallel_orchestration.c.

/**
 * Parallel MLP: gate/up projections, SwiGLU activation, down projection.
 *
 * All work runs inside one OpenMP region; each kernel slices its rows via
 * (ith, nth), with barriers between the dependent stages.
 *
 * @param ln2_q8       Input already quantized to Q8_K [aligned_embed].
 * @param down_q8      Scratch for the quantized down-projection input;
 *                     must hold aligned_inter worth of Q8_K blocks.
 * @param num_threads  Thread count; <= 0 selects omp_get_max_threads().
 */
void mlp_parallel(const void *ln2_q8,
                  const void *W_gate, const void *W_up, const void *W_down,
                  float *gate_buf, float *up_buf, float *swiglu_buf,
                  void *down_q8, float *mlp_out,
                  int intermediate, int embed_dim, int num_threads)
{
    /* Pad both dimensions to the Q4_K/Q8_K super-block size (256). */
    const int aligned_embed = ((embed_dim + 255) / 256) * 256;
    const int aligned_inter = ((intermediate + 255) / 256) * 256;

    if (num_threads <= 0) {
        num_threads = omp_get_max_threads();
    }

    #pragma omp parallel num_threads(num_threads)
    {
        const int ith = omp_get_thread_num();
        const int nth = omp_get_num_threads();

        /* Gate and Up projections (independent, no dependency between them). */
        gemv_q4_k_q8_k_parallel_simd(gate_buf, W_gate, ln2_q8, aligned_inter, aligned_embed, ith, nth);
        gemv_q4_k_q8_k_parallel_simd(up_buf, W_up, ln2_q8, aligned_inter, aligned_embed, ith, nth);

        #pragma omp barrier /* Wait for gate and up to complete */

        /* SwiGLU: swiglu[i] = silu(gate[i]) * up[i], element-wise.
         *
         * FIX: the loop now covers the full aligned_inter range. The original
         * guard (i < intermediate) left swiglu_buf[intermediate, aligned_inter)
         * uninitialized, yet quantize_row_q8_k below reads aligned_inter
         * elements — an indeterminate-value read. Both gemvs above wrote all
         * aligned_inter rows, so every input consumed here is defined. */
        const int dr = (aligned_inter + nth - 1) / nth;
        const int r0 = dr * ith;
        const int r1 = (r0 + dr < aligned_inter) ? (r0 + dr) : aligned_inter;

        for (int i = r0; i < r1; i++) {
            float g = gate_buf[i];
            float silu_g = g / (1.0f + expf(-g)); /* SiLU activation */
            swiglu_buf[i] = silu_g * up_buf[i];
        }

        #pragma omp barrier /* Wait for SwiGLU to complete */

        /* Quantize on a single thread (TODO: parallel quantization);
         * `omp single` ends with an implicit barrier. */
        #pragma omp single
        {
            quantize_row_q8_k(swiglu_buf, down_q8, aligned_inter);
        }

        /* Down projection back to the embedding dimension. */
        gemv_q4_k_q8_k_parallel_simd(mlp_out, W_down, down_q8, aligned_embed, aligned_inter, ith, nth);
    }
}

References gemv_q4_k_q8_k_parallel_simd(), and quantize_row_q8_k().

◆ qkv_projection_parallel()

void qkv_projection_parallel ( const void *  ln1_q8,
const void *  WQ,
const void *  WK,
const void *  WV,
float *  q_out,
float *  k_out,
float *  v_out,
int  H,
int  H_kv,
int  head_dim,
int  embed_dim,
int  num_threads 
)

Parallel Q/K/V projection for single token decode.

Parameters
ln1_q8Input: RMSNorm output quantized to Q8_K [aligned_embed]
WQQ weights [H*head_dim, aligned_embed] in Q4_K
WKK weights [H_kv*head_dim, aligned_embed] in Q4_K
WVV weights [H_kv*head_dim, aligned_embed] in Q4_K
q_outOutput: Q vectors [H, head_dim]
k_outOutput: K vector [H_kv, head_dim]
v_outOutput: V vector [H_kv, head_dim]
HNumber of query heads
H_kvNumber of KV heads (GQA)
head_dimHead dimension
embed_dimEmbedding dimension
num_threadsNumber of threads to use (0 = auto)

Definition at line 128 of file parallel_orchestration.c.

/**
 * Parallel Q/K/V projection for single-token decode.
 *
 * Runs all three projections inside one OpenMP region; every gemv call
 * splits its output rows across (tid, ntid) so no projection is serialized
 * behind another thread spawn.
 */
void qkv_projection_parallel(const void *ln1_q8,
                             const void *WQ, const void *WK, const void *WV,
                             float *q_out, float *k_out, float *v_out,
                             int H, int H_kv, int head_dim,
                             int embed_dim, int num_threads)
{
    const int rows_q = H * head_dim;      /* query rows: largest projection */
    const int rows_kv = H_kv * head_dim;  /* K and V rows under GQA */

    /* Input length padded to the QK_K super-block for the quantized matmul. */
    const int padded_embed = ((embed_dim + 255) / 256) * 256;

    if (num_threads <= 0) {
        num_threads = omp_get_max_threads();
    }

    /* One parallel region covers all three projections. */
    #pragma omp parallel num_threads(num_threads)
    {
        const int tid = omp_get_thread_num();
        const int ntid = omp_get_num_threads();

        gemv_q4_k_q8_k_parallel_simd(q_out, WQ, ln1_q8, rows_q, padded_embed, tid, ntid);
        gemv_q4_k_q8_k_parallel_simd(k_out, WK, ln1_q8, rows_kv, padded_embed, tid, ntid);
        gemv_q4_k_q8_k_parallel_simd(v_out, WV, ln1_q8, rows_kv, padded_embed, tid, ntid);

        /* Implicit barrier at the closing brace of the parallel region. */
    }
}

References gemv_q4_k_q8_k_parallel_simd().

◆ residual_add_parallel()

static void residual_add_parallel ( const float *  a,
const float *  b,
float *  out,
int  n,
int  ith,
int  nth 
)
static

Single-token decode with parallel SIMD kernels. (Note: this paragraph and the pattern sketch below describe the file-level parallel decode design, not residual_add_parallel itself — a doc-comment placement artifact. The helper simply computes out[i] = a[i] + b[i] over this thread's slice.)

This is the main decode function that processes one token through all layers. OpenMP parallel region is created ONCE at the top, and all kernels receive (ith, nth) to split their work.

Pattern: #pragma omp parallel { int ith = omp_get_thread_num(); int nth = omp_get_num_threads();

// Each kernel processes only its slice gemv_q4_k_q8_k_parallel_simd(..., ith, nth); #pragma omp barrier

rmsnorm_parallel(..., ith, nth); // (not implemented yet) #pragma omp barrier ... }

Definition at line 63 of file parallel_orchestration.c.

/* Element-wise residual add over this thread's contiguous slice:
 * out[i] = a[i] + b[i] for i in [start, stop). Work is split into
 * ceil(n / nth)-sized chunks; thread ith owns chunk ith. */
static void residual_add_parallel(const float *a, const float *b,
                                  float *out, int n, int ith, int nth)
{
    const int chunk = (n + nth - 1) / nth;
    const int start = ith * chunk;

    if (start >= n) {
        return; /* more threads than chunks: nothing for this thread */
    }

    int stop = start + chunk;
    if (stop > n) {
        stop = n; /* last chunk may be short */
    }

    for (int i = start; i < stop; ++i) {
        out[i] = a[i] + b[i];
    }
}

Referenced by decode_layer_parallel().

◆ vec_scale_parallel()

static void vec_scale_parallel ( float *  y,
float  scale,
int  n,
int  ith,
int  nth 
)
static

Definition at line 79 of file parallel_orchestration.c.

/* Scale this thread's contiguous slice of y in place: y[i] *= scale for
 * i in [start, stop). Same chunking scheme as the other (ith, nth) helpers. */
static void vec_scale_parallel(float *y, float scale, int n, int ith, int nth)
{
    const int chunk = (n + nth - 1) / nth;
    const int start = ith * chunk;

    if (start >= n) {
        return; /* this thread drew an empty slice */
    }

    int stop = start + chunk;
    if (stop > n) {
        stop = n;
    }

    for (int i = start; i < stop; ++i) {
        y[i] *= scale;
    }
}

◆ vec_zero_parallel()

static void vec_zero_parallel ( float *  y,
int  n,
int  ith,
int  nth 
)
static

Definition at line 94 of file parallel_orchestration.c.

/* Zero this thread's contiguous slice of y ([start, stop)) with a single
 * memset; all-zero bytes are the float 0.0f representation. */
static void vec_zero_parallel(float *y, int n, int ith, int nth)
{
    const int chunk = (n + nth - 1) / nth;
    const int start = ith * chunk;

    if (start >= n) {
        return; /* empty slice for this thread */
    }

    int stop = start + chunk;
    if (stop > n) {
        stop = n;
    }

    memset(y + start, 0, (size_t)(stop - start) * sizeof(float));
}