24 #include <immintrin.h>
25 #define CK_SPIN_PAUSE() _mm_pause()
27 #define CK_SPIN_PAUSE() ((void)0)
// Back-pointer from a worker record to its owning pool (fragment of the
// worker struct; the enclosing definition is partially elided here).
38 struct ck_threadpool *pool;
// Shared pool state (definition shown only in part in this view).
50 struct ck_threadpool {
// Protects the two condition variables below. Workers re-check the atomic
// dispatch/stop state under this lock before sleeping (see worker_main).
69 pthread_mutex_t mutex;
// Broadcast when new work is published, on stop, and on resume.
70 pthread_cond_t cond_dispatch;
// Signaled by the last worker to finish a dispatch so the caller can wake.
71 pthread_cond_t cond_done;
// barrier_init fragment: reset the arrival counter and phase number, and
// record how many threads participate in each barrier_wait round.
80 atomic_store(&b->n_arrived, 0);
81 atomic_store(&b->n_phase, 0);
82 b->n_threads = n_threads;
// barrier_wait fragment: phase-counter (generation) barrier. Each thread
// snapshots the current phase, then increments the arrival count.
91 const int n = b->n_threads;
// Relaxed load is sufficient here: the acq_rel fetch_add below orders the
// snapshot against other threads' arrivals.
92 const int phase = atomic_load_explicit(&b->n_phase, memory_order_relaxed);
// fetch_add returns the OLD value, so the last arriver observes n - 1.
94 if (atomic_fetch_add_explicit(&b->n_arrived, 1, memory_order_acq_rel) == n - 1) {
// Last arriver: reset the counter for the next round BEFORE bumping the
// phase; the release store on n_phase publishes the reset to the spinners.
96 atomic_store_explicit(&b->n_arrived, 0, memory_order_relaxed);
97 atomic_store_explicit(&b->n_phase, phase + 1, memory_order_release);
// Everyone else spins until the phase advances; acquire pairs with the
// release store above. (Loop body with CK_SPIN_PAUSE is elided from view.)
101 while (atomic_load_explicit(&b->n_phase, memory_order_acquire) == phase) {
// worker_main fragment: per-worker loop. Workers detect new work by watching
// the monotonically increasing n_dispatch generation counter.
119 ck_worker_t *w = (ck_worker_t *)arg;
120 ck_threadpool_t *pool = w->pool;
// This worker's shard index, passed to the work function as `ith`.
121 const int ith = w->id;
// Last dispatch generation this worker has already executed.
122 int last_dispatch = 0;
// Acquire-load of the stop flag: exit path (body elided from view).
129 if (atomic_load_explicit(&pool->stop, memory_order_acquire)) {
// New work is signaled by n_dispatch advancing past what we last ran.
134 int current = atomic_load_explicit(&pool->n_dispatch, memory_order_acquire);
135 if (current != last_dispatch) {
136 last_dispatch = current;
// Sleep path: take the mutex, then RE-CHECK both the dispatch counter and
// the stop flag under the lock before waiting — this avoids the lost-wakeup
// race with the broadcaster, which also signals while holding this mutex.
145 pthread_mutex_lock(&pool->mutex);
147 current = atomic_load_explicit(&pool->n_dispatch, memory_order_acquire);
148 if (current == last_dispatch &&
149 !atomic_load_explicit(&pool->stop, memory_order_acquire)) {
150 pthread_cond_wait(&pool->cond_dispatch, &pool->mutex);
152 pthread_mutex_unlock(&pool->mutex);
// Execute this generation's work. NOTE(review): `fn` is presumably loaded
// from pool->work_fn on an elided line between 152 and 159 — confirm.
159 void *args = pool->work_args;
161 fn(ith, pool->n_threads, args);
// Completion accounting: n_threads - 1 workers participate (the main thread
// runs shard 0 itself and is not counted here). fetch_add returns the old
// value, so the last worker sees n_threads - 2 and wakes the dispatcher.
165 if (atomic_fetch_add_explicit(&pool->n_complete, 1, memory_order_acq_rel)
166 == pool->n_threads - 2) {
// Lock/unlock around the signal so the waiter cannot miss it between its
// own n_complete check and its cond_wait (see ck_threadpool_dispatch).
168 pthread_mutex_lock(&pool->mutex);
169 pthread_cond_signal(&pool->cond_done);
170 pthread_mutex_unlock(&pool->mutex);
// ck_threadpool_create fragment: clamp the requested thread count to [1, 8].
// (The body of the first `<= 0` branch is elided from this view.)
185 if (n_threads <= 0) {
187 if (n_threads <= 0) n_threads = 1;
189 if (n_threads > 8) n_threads = 8;
// Cache-line-aligned allocation to avoid false sharing on the hot atomics.
// NOTE(review): C11 aligned_alloc requires size to be a multiple of the
// alignment; this holds only if ck_threadpool_t is itself declared with
// CK_CACHE_LINE alignment — confirm against the struct definition.
195 ck_threadpool_t *pool = aligned_alloc(
    CK_CACHE_LINE,
    sizeof(ck_threadpool_t));
196 if (!pool)
    return NULL;
197 memset(pool, 0,
    sizeof(*pool));
199 pool->n_threads = n_threads;
// Initialize the dispatch/completion counters and control flags.
200 atomic_store(&pool->n_dispatch, 0);
201 atomic_store(&pool->n_complete, 0);
202 atomic_store(&pool->stop, 0);
203 atomic_store(&pool->paused, 0);
204 pool->work_fn = NULL;
205 pool->work_args = NULL;
// Synchronization primitives used for the sleep/wake paths.
209 pthread_mutex_init(&pool->mutex, NULL);
210 pthread_cond_init(&pool->cond_dispatch, NULL);
211 pthread_cond_init(&pool->cond_done, NULL);
// Slot 0 is the CALLING thread: it participates in every dispatch as shard 0
// and no pthread is created for it.
214 pool->workers[0].id = 0;
215 pool->workers[0].pool = pool;
216 pool->workers[0].thread = pthread_self();
// Spawn the remaining n_threads - 1 worker threads.
219 for (
    int i = 1; i < n_threads; i++) {
220 pool->workers[i].id = i;
221 pool->workers[i].pool = pool;
// (Remaining pthread_create arguments and the rc error branch are elided.)
223 int rc = pthread_create(&pool->workers[i].thread, NULL,
226 fprintf(stderr,
    "[CK threadpool] Failed to create worker %d: %s\n",
// Informational log, only when worker threads were actually created.
235 if (pool->n_threads > 1) {
236 fprintf(stderr,
    "[CK threadpool] Created %d threads (1 main + %d workers)\n",
237 pool->n_threads, pool->n_threads - 1);
// ck_threadpool_destroy fragment: publish the stop flag, then wake every
// sleeping worker so each can observe it and exit.
248 atomic_store_explicit(&pool->stop, 1, memory_order_release);
// Broadcast under the mutex so no worker misses the wakeup between its
// stop-flag re-check and its cond_wait (mirrors the worker sleep path).
251 pthread_mutex_lock(&pool->mutex);
252 pthread_cond_broadcast(&pool->cond_dispatch);
253 pthread_mutex_unlock(&pool->mutex);
// Join only indices 1..n-1: slot 0 is the calling thread, never created.
256 for (
    int i = 1; i < pool->n_threads; i++) {
257 pthread_join(pool->workers[i].thread, NULL);
// Tear down the sync primitives after all workers have exited.
260 pthread_cond_destroy(&pool->cond_dispatch);
261 pthread_cond_destroy(&pool->cond_done);
262 pthread_mutex_destroy(&pool->mutex);
// ck_threadpool_dispatch fragment: run fn(ith, nth, args) on all threads and
// block until every shard has completed. No-op on a null pool or fn.
273 if (!pool || !fn)
    return;
// Single-threaded pool: fast path (its body is elided from this view,
// presumably just fn(0, 1, args) inline — confirm).
276 if (pool->n_threads == 1) {
// Publish the work payload. NOTE(review): pool->work_fn is presumably
// assigned on an elided line between 276 and 286 — confirm.
286 pool->work_args = args;
// Reset the completion counter BEFORE advancing the dispatch generation;
// the release on n_dispatch below publishes both to the workers.
287 atomic_store_explicit(&pool->n_complete, 0, memory_order_release);
// Advancing n_dispatch is what workers poll for (see worker_main).
290 atomic_fetch_add_explicit(&pool->n_dispatch, 1, memory_order_release);
// Broadcast under the mutex so sleeping workers cannot miss the new
// generation (pairs with the worker's locked re-check before cond_wait).
293 pthread_mutex_lock(&pool->mutex);
294 pthread_cond_broadcast(&pool->cond_dispatch);
295 pthread_mutex_unlock(&pool->mutex);
// The calling thread runs shard 0 itself rather than sitting idle.
298 fn(0, pool->n_threads, args);
// Wait for the other n_threads - 1 workers to finish.
301 if (pool->n_threads > 1) {
// Spin phase first (loop body elided; presumably bounded with
// CK_SPIN_PAUSE before falling through to the blocking wait).
303 while (atomic_load_explicit(&pool->n_complete, memory_order_acquire)
304 < pool->n_threads - 1) {
// Blocking phase: re-check the count under the mutex before sleeping so the
// last worker's locked signal (worker_main) cannot be lost.
308 pthread_mutex_lock(&pool->mutex);
309 if (atomic_load_explicit(&pool->n_complete, memory_order_acquire)
310 < pool->n_threads - 1) {
311 pthread_cond_wait(&pool->cond_done, &pool->mutex);
313 pthread_mutex_unlock(&pool->mutex);
// ck_threadpool_pause fragment: nothing to pause for a null or
// single-threaded pool (the main thread is the only participant).
322 if (!pool || pool->n_threads <= 1)
    return;
// Release-publish the paused flag to the workers.
333 atomic_store_explicit(&pool->paused, 1, memory_order_release);
// ck_threadpool_resume fragment: clear the paused flag, then broadcast under
// the mutex so any worker sleeping on cond_dispatch re-checks state and
// resumes (locked broadcast prevents a lost wakeup).
339 atomic_store_explicit(&pool->paused, 0, memory_order_release);
342 pthread_mutex_lock(&pool->mutex);
343 pthread_cond_broadcast(&pool->cond_dispatch);
344 pthread_mutex_unlock(&pool->mutex);
// ck_threadpool_n_threads fragment: a null pool counts as single-threaded.
353 return pool ? pool->n_threads : 1;
// ck_threadpool_thread_id fragment: map the calling thread back to its
// worker slot index; -1 for a null pool (and, presumably, when the caller
// is not a pool thread — the fall-through return is elided from view).
358 if (!pool)
    return -1;
359 pthread_t
    self = pthread_self();
// Linear scan over all slots (slot 0 is the main/creator thread).
// pthread_t is opaque, so pthread_equal is required rather than ==.
360 for (
    int i = 0; i < pool->n_threads; i++) {
361 if (pthread_equal(
    self, pool->workers[i].thread)) {
void ck_threadpool_pause(ck_threadpool_t *pool)
static void barrier_init(ck_barrier_t *b, int n_threads)
static void barrier_wait(ck_barrier_t *b)
void ck_threadpool_resume(ck_threadpool_t *pool)
static pthread_once_t g_threadpool_once
void ck_threadpool_global_destroy(void)
static void global_pool_init(void)
ck_threadpool_t * ck_threadpool_create(int n_threads)
void ck_threadpool_destroy(ck_threadpool_t *pool)
static void * worker_main(void *arg)
void ck_threadpool_barrier(ck_threadpool_t *pool)
int ck_get_physical_cores(void)
static ck_threadpool_t * g_threadpool
void ck_threadpool_dispatch(ck_threadpool_t *pool, ck_work_fn_t fn, void *args)
int ck_threadpool_thread_id(const ck_threadpool_t *pool)
int ck_get_num_threads(void)
int ck_threadpool_n_threads(const ck_threadpool_t *pool)
ck_threadpool_t * ck_threadpool_global(void)
Persistent pthread thread pool for CK-Engine inference.
#define CK_THREADPOOL_MAX_THREADS
#define CK_THREADPOOL_SPIN_COUNT
void(* ck_work_fn_t)(int ith, int nth, void *args)