13 typedef int (*
init_t)(
const char *weights_path);
14 typedef int (*
embed_t)(
const int32_t *tokens,
int num_tokens);
20 typedef void* (*get_ptr_t)(void);
25 pthread_mutex_t mutex;
26 pthread_cond_t cond_task;
27 pthread_cond_t cond_done;
30 int32_t prompt_tokens[1024];
49 const char *weights_path;
54 SharedState *s = (SharedState *)arg;
56 printf(
"[Engine] Thread started. Initializing model...\n");
57 if (s->init(s->weights_path) != 0) {
58 fprintf(stderr,
"[Engine] Failed to init model\n");
61 s->kv_enable(s->get_context());
64 pthread_mutex_lock(&s->mutex);
65 while (!s->task_ready && !s->quit) {
66 pthread_cond_wait(&s->cond_task, &s->mutex);
69 pthread_mutex_unlock(&s->mutex);
74 int n_prompt = s->n_prompt;
76 memcpy(prompt, s->prompt_tokens, n_prompt *
sizeof(int32_t));
77 int max_gen = s->max_gen;
78 s->task_ready =
false;
79 pthread_mutex_unlock(&s->mutex);
82 s->embed(prompt, n_prompt);
84 int32_t next_token = s->sample();
87 pthread_mutex_lock(&s->mutex);
88 s->last_token = next_token;
89 s->token_ready =
true;
90 pthread_cond_signal(&s->cond_done);
91 pthread_mutex_unlock(&s->mutex);
94 for (
int i = 0; i < max_gen; i++) {
95 if (s->decode(next_token, NULL) != 0)
break;
96 next_token = s->sample();
98 pthread_mutex_lock(&s->mutex);
99 s->last_token = next_token;
100 s->token_ready =
true;
101 pthread_cond_signal(&s->cond_done);
102 pthread_mutex_unlock(&s->mutex);
104 if (next_token == 151643 || next_token == 151645)
break;
110 int main(
int argc,
char **argv) {
112 printf(
"Usage: %s <libmodel.so> <weights.bump>\n", argv[0]);
116 SharedState state = {0};
117 pthread_mutex_init(&state.mutex, NULL);
118 pthread_cond_init(&state.cond_task, NULL);
119 pthread_cond_init(&state.cond_done, NULL);
120 state.weights_path = argv[2];
122 void *handle = dlopen(argv[1], RTLD_NOW);
123 if (!handle) { fprintf(stderr,
"%s\n", dlerror());
return 1; }
125 state.init = dlsym(handle,
"ck_model_init");
126 state.embed = dlsym(handle,
"ck_model_embed_tokens");
127 state.forward = dlsym(handle,
"ck_model_forward");
128 state.kv_enable = dlsym(handle,
"ck_model_kv_cache_enable");
129 state.decode = dlsym(handle,
"ck_model_decode");
130 state.sample = dlsym(handle,
"ck_model_sample_argmax");
131 state.get_context = dlsym(handle,
"ck_model_get_context_window");
132 get_ptr_t get_offsets = dlsym(handle,
"ck_model_get_vocab_offsets");
133 get_ptr_t get_strings = dlsym(handle,
"ck_model_get_vocab_strings");
134 get_int_t get_vocab_size = dlsym(handle,
"ck_model_get_vocab_size");
135 get_int_t get_num_merges = dlsym(handle,
"ck_model_get_num_merges");
138 pthread_t engine_thread;
151 if (!fgets(input,
sizeof(input), stdin))
break;
152 if (strncmp(input,
"/exit", 5) == 0)
break;
159 pthread_mutex_lock(&state.mutex);
160 memcpy(state.prompt_tokens,
ids, n *
sizeof(int32_t));
163 state.task_ready =
true;
164 state.token_ready =
false;
165 pthread_cond_signal(&state.cond_task);
166 pthread_mutex_unlock(&state.mutex);
168 printf(
"Assistant: ");
173 pthread_mutex_lock(&state.mutex);
174 while (!state.token_ready && !state.quit) {
175 pthread_cond_wait(&state.cond_done, &state.mutex);
177 int32_t tok = state.last_token;
178 state.token_ready =
false;
179 pthread_mutex_unlock(&state.mutex);
181 if (tok == 151643 || tok == 151645)
break;
185 if ((
unsigned char)word[0] == 0xC4 && (
unsigned char)word[1] == 0xA0) {
186 printf(
" %s", word + 2);
196 pthread_mutex_lock(&state.mutex);
198 pthread_cond_signal(&state.cond_task);
199 pthread_mutex_unlock(&state.mutex);
200 pthread_join(engine_thread, NULL);
void *(* get_ptr_t)(void)
int(* init_t)(const char *weights_path)
void(* void_func_t)(void)
int(* embed_t)(const int32_t *tokens, int num_tokens)
int(* kv_enable_t)(int capacity)
int main(int argc, char **argv)
int(* forward_t)(float *logits_out)
int(* decode_t)(int32_t token, float *logits_out)
void * engine_thread_func(void *arg)
int(* sample_argmax_t)(void)
const char * ck_tokenizer_id_to_token(const CKTokenizer *tok, int32_t id)
int ck_tokenizer_encode(const CKTokenizer *tok, const char *text, int text_len, int32_t *ids, int max_ids)
CKTokenizer * ck_tokenizer_create(CKTokenizerType type)
int ck_tokenizer_load_binary(CKTokenizer *tok, int vocab_size, const int32_t *offsets, const char *strings, int num_merges, const int32_t *merges)