112 printf(
"Usage: %s <libmodel.so> <weights.bump>\n", argv[0]);
116 SharedState state = {0};
117 pthread_mutex_init(&state.mutex, NULL);
118 pthread_cond_init(&state.cond_task, NULL);
119 pthread_cond_init(&state.cond_done, NULL);
120 state.weights_path = argv[2];
122 void *handle = dlopen(argv[1], RTLD_NOW);
123 if (!handle) { fprintf(stderr,
"%s\n", dlerror());
return 1; }
125 state.init = dlsym(handle,
"ck_model_init");
126 state.embed = dlsym(handle,
"ck_model_embed_tokens");
127 state.forward = dlsym(handle,
"ck_model_forward");
128 state.kv_enable = dlsym(handle,
"ck_model_kv_cache_enable");
129 state.decode = dlsym(handle,
"ck_model_decode");
130 state.sample = dlsym(handle,
"ck_model_sample_argmax");
131 state.get_context = dlsym(handle,
"ck_model_get_context_window");
132 get_ptr_t get_offsets = dlsym(handle,
"ck_model_get_vocab_offsets");
133 get_ptr_t get_strings = dlsym(handle,
"ck_model_get_vocab_strings");
134 get_int_t get_vocab_size = dlsym(handle,
"ck_model_get_vocab_size");
135 get_int_t get_num_merges = dlsym(handle,
"ck_model_get_num_merges");
138 pthread_t engine_thread;
151 if (!fgets(input,
sizeof(input), stdin))
break;
152 if (strncmp(input,
"/exit", 5) == 0)
break;
159 pthread_mutex_lock(&state.mutex);
160 memcpy(state.prompt_tokens,
ids, n *
sizeof(int32_t));
163 state.task_ready =
true;
164 state.token_ready =
false;
165 pthread_cond_signal(&state.cond_task);
166 pthread_mutex_unlock(&state.mutex);
168 printf(
"Assistant: ");
173 pthread_mutex_lock(&state.mutex);
174 while (!state.token_ready && !state.quit) {
175 pthread_cond_wait(&state.cond_done, &state.mutex);
177 int32_t tok = state.last_token;
178 state.token_ready =
false;
179 pthread_mutex_unlock(&state.mutex);
181 if (tok == 151643 || tok == 151645)
break;
185 if ((
unsigned char)word[0] == 0xC4 && (
unsigned char)word[1] == 0xA0) {
186 printf(
" %s", word + 2);
196 pthread_mutex_lock(&state.mutex);
198 pthread_cond_signal(&state.cond_task);
199 pthread_mutex_unlock(&state.mutex);
200 pthread_join(engine_thread, NULL);
void *(* get_ptr_t)(void)
void * engine_thread_func(void *arg)
const char * ck_tokenizer_id_to_token(const CKTokenizer *tok, int32_t id)
int ck_tokenizer_encode(const CKTokenizer *tok, const char *text, int text_len, int32_t *ids, int max_ids)
CKTokenizer * ck_tokenizer_create(CKTokenizerType type)
int ck_tokenizer_load_binary(CKTokenizer *tok, int vocab_size, const int32_t *offsets, const char *strings, int num_merges, const int32_t *merges)