← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ck_threadpool.c
Go to the documentation of this file.
1 /**
2  * @file ck_threadpool.c
3  * @brief Persistent pthread thread pool for CK-Engine inference
4  *
5  * Architecture:
6  * - N-1 worker pthreads created at startup, main thread is thread 0
7  * - Workers spin on atomic dispatch counter waiting for work
8  * - Barriers use atomic counter + spin-wait with _mm_pause()
9  * - Hybrid polling: spin CK_THREADPOOL_SPIN_COUNT rounds, then condvar
10  * - All atomics on separate cache lines to avoid false sharing
11  *
12  * Based on the ggml_threadpool design from llama.cpp, adapted for
13  * CK-Engine's kernel dispatch model.
14  */
15 
16 #include "ck_threadpool.h"
17 
18 #include <stdlib.h>
19 #include <string.h>
20 #include <stdio.h>
21 #include <errno.h>
22 
23 #ifdef __x86_64__
24 #include <immintrin.h>
25 #define CK_SPIN_PAUSE() _mm_pause()
26 #else
27 #define CK_SPIN_PAUSE() ((void)0)
28 #endif
29 
30 /* ============================================================================
31  * Internal Structures (cache-line aligned)
32  * ============================================================================ */
33 
34 /** Per-worker state */
35 typedef struct {
36  pthread_t thread;
37  int id; /* 0 = main, 1..n-1 = workers */
38  struct ck_threadpool *pool;
39 } ck_worker_t;
40 
41 /** Barrier state — all fields on separate cache lines */
42 typedef struct {
43  _Alignas(CK_CACHE_LINE) atomic_int n_arrived;
44  _Alignas(CK_CACHE_LINE) atomic_int n_phase;
45  int n_threads;
46  char _pad[CK_CACHE_LINE - sizeof(int)];
47 } ck_barrier_t;
48 
49 /** Thread pool (opaque) */
50 struct ck_threadpool {
51  /* Dispatch state — cache-line aligned */
52  _Alignas(CK_CACHE_LINE) atomic_int n_dispatch; /* bumped to wake workers */
53  _Alignas(CK_CACHE_LINE) atomic_int n_complete; /* workers signal completion */
54  _Alignas(CK_CACHE_LINE) ck_work_fn_t work_fn; /* current work function */
55  void *work_args; /* current work arguments */
56 
57  /* Barrier for intra-dispatch synchronization */
58  ck_barrier_t barrier;
59 
60  /* Worker management */
61  int n_threads; /* total threads (including main) */
62  ck_worker_t workers[CK_THREADPOOL_MAX_THREADS];
63 
64  /* Shutdown / pause signals */
65  _Alignas(CK_CACHE_LINE) atomic_int stop;
66  _Alignas(CK_CACHE_LINE) atomic_int paused;
67 
68  /* Condvar for sleep/wake (hybrid polling) */
69  pthread_mutex_t mutex;
70  pthread_cond_t cond_dispatch; /* workers wait here when sleeping */
71  pthread_cond_t cond_done; /* main waits here for completion */
72 };
73 
74 /* ============================================================================
75  * Barrier Implementation
76  * ============================================================================ */
77 
78 static void barrier_init(ck_barrier_t *b, int n_threads)
79 {
80  atomic_store(&b->n_arrived, 0);
81  atomic_store(&b->n_phase, 0);
82  b->n_threads = n_threads;
83 }
84 
/**
 * Spin-wait barrier. All b->n_threads threads must call this.
 *
 * Uses a generation (phase) counter so the barrier can be reused
 * back-to-back without an explicit reset: waiters spin on the phase they
 * observed on entry; the last arriver resets the arrival count and bumps
 * the phase, releasing everyone at once.
 */
static void barrier_wait(ck_barrier_t *b)
{
    const int n = b->n_threads;
    /* Snapshot the phase BEFORE arriving; waiters spin until it changes. */
    const int phase = atomic_load_explicit(&b->n_phase, memory_order_relaxed);

    /* acq_rel: the increment both publishes this thread's prior writes and
     * observes the other arrivers' increments. */
    if (atomic_fetch_add_explicit(&b->n_arrived, 1, memory_order_acq_rel) == n - 1) {
        /* Last thread to arrive — reset and advance phase. The reset happens
         * first; waiters only watch n_phase, so they cannot re-enter and
         * re-increment n_arrived until the phase store below lands. */
        atomic_store_explicit(&b->n_arrived, 0, memory_order_relaxed);
        atomic_store_explicit(&b->n_phase, phase + 1, memory_order_release);
    } else {
        /* Spin until phase advances */
        int spins = 0;
        while (atomic_load_explicit(&b->n_phase, memory_order_acquire) == phase) {
            CK_SPIN_PAUSE();
            spins++;
            /* After many spins, yield to avoid wasting CPU on oversubscribed systems */
            if (spins > CK_THREADPOOL_SPIN_COUNT * 16) {
                sched_yield();
                spins = 0;
            }
        }
    }
}
112 
113 /* ============================================================================
114  * Worker Thread
115  * ============================================================================ */
116 
/**
 * Worker thread entry point (threads 1..n-1; the main thread never runs this).
 *
 * Loop: hybrid spin/condvar wait for a new dispatch generation, run the
 * work function with this worker's thread index, then bump n_complete and
 * wake the main thread when this worker is the last to finish.
 * Exits only when pool->stop is set.
 */
static void *worker_main(void *arg)
{
    ck_worker_t *w = (ck_worker_t *)arg;
    ck_threadpool_t *pool = w->pool;
    const int ith = w->id;
    /* Last dispatch generation this worker has serviced; pool->n_dispatch
     * starts at 0, so 0 here means "no work seen yet". */
    int last_dispatch = 0;

    for (;;) {
        /* Spin-wait for new dispatch */
        int spins = 0;
        for (;;) {
            /* Check shutdown */
            if (atomic_load_explicit(&pool->stop, memory_order_acquire)) {
                return NULL;
            }

            /* Check for new work: any change of the dispatch generation.
             * The acquire pairs with the release bump in dispatch, making
             * work_fn/work_args visible before we use them below. */
            int current = atomic_load_explicit(&pool->n_dispatch, memory_order_acquire);
            if (current != last_dispatch) {
                last_dispatch = current;
                break;
            }

            CK_SPIN_PAUSE();
            spins++;

            /* After spinning, fall back to condvar sleep */
            if (spins >= CK_THREADPOOL_SPIN_COUNT) {
                pthread_mutex_lock(&pool->mutex);
                /* Re-check under lock to avoid missed wakeup: dispatch
                 * broadcasts cond_dispatch while holding this mutex. */
                current = atomic_load_explicit(&pool->n_dispatch, memory_order_acquire);
                if (current == last_dispatch &&
                    !atomic_load_explicit(&pool->stop, memory_order_acquire)) {
                    pthread_cond_wait(&pool->cond_dispatch, &pool->mutex);
                }
                pthread_mutex_unlock(&pool->mutex);
                spins = 0;
            }
        }

        /* Execute work */
        ck_work_fn_t fn = pool->work_fn;
        void *args = pool->work_args;
        if (fn) {
            fn(ith, pool->n_threads, args);
        }

        /* Signal completion: n_threads-1 workers participate, so the
         * fetch_add that returns n_threads-2 belongs to the last worker. */
        if (atomic_fetch_add_explicit(&pool->n_complete, 1, memory_order_acq_rel)
                == pool->n_threads - 2) {
            /* Last worker done — wake main thread if it's waiting */
            pthread_mutex_lock(&pool->mutex);
            pthread_cond_signal(&pool->cond_done);
            pthread_mutex_unlock(&pool->mutex);
        }
    }

    /* Not reached; exits happen via the stop check above. */
    return NULL;
}
176 
177 /* ============================================================================
178  * Lifecycle
179  * ============================================================================ */
180 
181 extern int ck_get_physical_cores(void);
182 
183 ck_threadpool_t *ck_threadpool_create(int n_threads)
184 {
185  if (n_threads <= 0) {
186  n_threads = ck_get_physical_cores();
187  if (n_threads <= 0) n_threads = 1;
188  /* Cap at reasonable default for memory-bound workloads */
189  if (n_threads > 8) n_threads = 8;
190  }
191  if (n_threads > CK_THREADPOOL_MAX_THREADS) {
192  n_threads = CK_THREADPOOL_MAX_THREADS;
193  }
194 
195  ck_threadpool_t *pool = aligned_alloc(CK_CACHE_LINE, sizeof(ck_threadpool_t));
196  if (!pool) return NULL;
197  memset(pool, 0, sizeof(*pool));
198 
199  pool->n_threads = n_threads;
200  atomic_store(&pool->n_dispatch, 0);
201  atomic_store(&pool->n_complete, 0);
202  atomic_store(&pool->stop, 0);
203  atomic_store(&pool->paused, 0);
204  pool->work_fn = NULL;
205  pool->work_args = NULL;
206 
207  barrier_init(&pool->barrier, n_threads);
208 
209  pthread_mutex_init(&pool->mutex, NULL);
210  pthread_cond_init(&pool->cond_dispatch, NULL);
211  pthread_cond_init(&pool->cond_done, NULL);
212 
213  /* Thread 0 = main thread (no pthread created) */
214  pool->workers[0].id = 0;
215  pool->workers[0].pool = pool;
216  pool->workers[0].thread = pthread_self();
217 
218  /* Spawn N-1 worker threads */
219  for (int i = 1; i < n_threads; i++) {
220  pool->workers[i].id = i;
221  pool->workers[i].pool = pool;
222 
223  int rc = pthread_create(&pool->workers[i].thread, NULL,
224  worker_main, &pool->workers[i]);
225  if (rc != 0) {
226  fprintf(stderr, "[CK threadpool] Failed to create worker %d: %s\n",
227  i, strerror(rc));
228  /* Reduce thread count to what we managed to create */
229  pool->n_threads = i;
230  barrier_init(&pool->barrier, i);
231  break;
232  }
233  }
234 
235  if (pool->n_threads > 1) {
236  fprintf(stderr, "[CK threadpool] Created %d threads (1 main + %d workers)\n",
237  pool->n_threads, pool->n_threads - 1);
238  }
239 
240  return pool;
241 }
242 
243 void ck_threadpool_destroy(ck_threadpool_t *pool)
244 {
245  if (!pool) return;
246 
247  /* Signal shutdown */
248  atomic_store_explicit(&pool->stop, 1, memory_order_release);
249 
250  /* Wake all sleeping workers */
251  pthread_mutex_lock(&pool->mutex);
252  pthread_cond_broadcast(&pool->cond_dispatch);
253  pthread_mutex_unlock(&pool->mutex);
254 
255  /* Join all worker threads */
256  for (int i = 1; i < pool->n_threads; i++) {
257  pthread_join(pool->workers[i].thread, NULL);
258  }
259 
260  pthread_cond_destroy(&pool->cond_dispatch);
261  pthread_cond_destroy(&pool->cond_done);
262  pthread_mutex_destroy(&pool->mutex);
263 
264  free(pool);
265 }
266 
267 /* ============================================================================
268  * Dispatch & Synchronization
269  * ============================================================================ */
270 
/**
 * Run @p fn on all pool threads and block until every thread has finished.
 *
 * The calling (main) thread participates as ith=0; workers run ith=1..n-1.
 * The work descriptor and the n_complete reset are published before the
 * release bump of n_dispatch, which is what workers acquire-load to detect
 * new work. Not reentrant: only one dispatch may be in flight per pool.
 *
 * @param pool  Pool to dispatch on (no-op if NULL).
 * @param fn    Work function, called as fn(ith, n_threads, args) (no-op if NULL).
 * @param args  Opaque argument forwarded to every invocation of fn.
 */
void ck_threadpool_dispatch(ck_threadpool_t *pool, ck_work_fn_t fn, void *args)
{
    if (!pool || !fn) return;

    /* Single-thread fast path: just call directly */
    if (pool->n_threads == 1) {
        fn(0, 1, args);
        return;
    }

    /* Reset barrier phase for this dispatch */
    barrier_init(&pool->barrier, pool->n_threads);

    /* Set work descriptor; these plain stores are ordered before the
     * release increment of n_dispatch below. */
    pool->work_fn = fn;
    pool->work_args = args;
    atomic_store_explicit(&pool->n_complete, 0, memory_order_release);

    /* Wake workers by bumping dispatch counter */
    atomic_fetch_add_explicit(&pool->n_dispatch, 1, memory_order_release);

    /* Also signal condvar for sleeping workers; taken under the mutex so a
     * worker that re-checked n_dispatch under the lock cannot miss it. */
    pthread_mutex_lock(&pool->mutex);
    pthread_cond_broadcast(&pool->cond_dispatch);
    pthread_mutex_unlock(&pool->mutex);

    /* Main thread (ith=0) does its share */
    fn(0, pool->n_threads, args);

    /* Wait for all workers to complete: spin first, then fall back to
     * cond_done, re-checking the predicate under the mutex. */
    if (pool->n_threads > 1) {
        int spins = 0;
        while (atomic_load_explicit(&pool->n_complete, memory_order_acquire)
                   < pool->n_threads - 1) {
            CK_SPIN_PAUSE();
            spins++;
            if (spins >= CK_THREADPOOL_SPIN_COUNT) {
                pthread_mutex_lock(&pool->mutex);
                if (atomic_load_explicit(&pool->n_complete, memory_order_acquire)
                        < pool->n_threads - 1) {
                    pthread_cond_wait(&pool->cond_done, &pool->mutex);
                }
                pthread_mutex_unlock(&pool->mutex);
                spins = 0;
            }
        }
    }
}
319 
320 void ck_threadpool_barrier(ck_threadpool_t *pool)
321 {
322  if (!pool || pool->n_threads <= 1) return;
323  barrier_wait(&pool->barrier);
324 }
325 
326 /* ============================================================================
327  * Power Management
328  * ============================================================================ */
329 
330 void ck_threadpool_pause(ck_threadpool_t *pool)
331 {
332  if (!pool) return;
333  atomic_store_explicit(&pool->paused, 1, memory_order_release);
334 }
335 
336 void ck_threadpool_resume(ck_threadpool_t *pool)
337 {
338  if (!pool) return;
339  atomic_store_explicit(&pool->paused, 0, memory_order_release);
340 
341  /* Wake sleeping workers */
342  pthread_mutex_lock(&pool->mutex);
343  pthread_cond_broadcast(&pool->cond_dispatch);
344  pthread_mutex_unlock(&pool->mutex);
345 }
346 
347 /* ============================================================================
348  * Queries
349  * ============================================================================ */
350 
351 int ck_threadpool_n_threads(const ck_threadpool_t *pool)
352 {
353  return pool ? pool->n_threads : 1;
354 }
355 
356 int ck_threadpool_thread_id(const ck_threadpool_t *pool)
357 {
358  if (!pool) return -1;
359  pthread_t self = pthread_self();
360  for (int i = 0; i < pool->n_threads; i++) {
361  if (pthread_equal(self, pool->workers[i].thread)) {
362  return i;
363  }
364  }
365  return -1;
366 }
367 
368 /* ============================================================================
369  * Global Thread Pool
370  * ============================================================================ */
371 
372 static ck_threadpool_t *g_threadpool = NULL;
373 static pthread_once_t g_threadpool_once = PTHREAD_ONCE_INIT;
374 
375 extern int ck_get_num_threads(void);
376 
377 static void global_pool_init(void)
378 {
379  int n = ck_get_num_threads();
381 }
382 
383 ck_threadpool_t *ck_threadpool_global(void)
384 {
385  pthread_once(&g_threadpool_once, global_pool_init);
386  return g_threadpool;
387 }
388 
390 {
391  if (g_threadpool) {
393  g_threadpool = NULL;
394  /* Reset once control so pool can be re-created if needed */
395  g_threadpool_once = PTHREAD_ONCE_INIT;
396  }
397 }
void ck_threadpool_pause(ck_threadpool_t *pool)
static void barrier_init(ck_barrier_t *b, int n_threads)
Definition: ck_threadpool.c:78
static void barrier_wait(ck_barrier_t *b)
Definition: ck_threadpool.c:89
void ck_threadpool_resume(ck_threadpool_t *pool)
static pthread_once_t g_threadpool_once
void ck_threadpool_global_destroy(void)
static void global_pool_init(void)
ck_threadpool_t * ck_threadpool_create(int n_threads)
void ck_threadpool_destroy(ck_threadpool_t *pool)
static void * worker_main(void *arg)
void ck_threadpool_barrier(ck_threadpool_t *pool)
int ck_get_physical_cores(void)
static ck_threadpool_t * g_threadpool
void ck_threadpool_dispatch(ck_threadpool_t *pool, ck_work_fn_t fn, void *args)
int ck_threadpool_thread_id(const ck_threadpool_t *pool)
#define CK_SPIN_PAUSE()
Definition: ck_threadpool.c:27
int ck_get_num_threads(void)
int ck_threadpool_n_threads(const ck_threadpool_t *pool)
ck_threadpool_t * ck_threadpool_global(void)
Persistent pthread thread pool for CK-Engine inference.
#define CK_THREADPOOL_MAX_THREADS
Definition: ck_threadpool.h:48
#define CK_CACHE_LINE
Definition: ck_threadpool.h:54
#define CK_THREADPOOL_SPIN_COUNT
Definition: ck_threadpool.h:51
void(* ck_work_fn_t)(int ith, int nth, void *args)
Definition: ck_threadpool.h:68
int32_t id
Definition: tokenizer.h:315