33 #include <readline/readline.h>
34 #include <readline/history.h>
/* Version string; printed in the "C-Kernel-Engine v%s" banner line. */
40 #define CK_CLI_VERSION "6.6.0"
/* Presumably the generation limit used when --max-tokens is not given;
 * the assignment is not visible in this excerpt -- confirm. */
41 #define CK_CLI_DEFAULT_MAX_TOKENS 256
/* Presumably the capacity of the EOS token-id list (eos_ids) -- TODO confirm
 * against the CLIOptions definition. */
42 #define CK_CLI_EOS_MAX 8
/* NOTE(review): looks like the size of the buffer handed to the
 * output_append()/output_flush() helpers -- confirm at the declaration. */
43 #define CK_CLI_OUTPUT_BUF_SIZE 4096
/* Presumably an upper bound on the context length; no use is visible in
 * this excerpt -- confirm. */
44 #define CK_CLI_MAX_CONTEXT 32768
/* History file name. NOTE(review): presumably joined with $HOME to build the
 * path given to read_history()/write_history(); the path construction is not
 * visible here. */
45 #define CK_CLI_HISTORY_FILE ".ck_cli_history"
/* Pointer to a function that takes a weights file path and returns int.
 * Presumably the type of the "ck_model_init" entry point; main() treats a
 * nonzero result from that call as an initialization failure. */
69 typedef int (*
init_t)(
const char *weights_path);
/* Pointer to a function that takes an array of token ids and its length and
 * returns int. Presumably the type of "ck_model_embed_tokens"; run_prompt()
 * treats a nonzero result from that call as a failure. */
70 typedef int (*
embed_t)(
const int32_t *tokens,
int num_tokens);
/* Pointer to a no-argument function returning a float pointer. Presumably
 * the type of "ck_model_get_logits" -- confirm against ModelAPI. */
76 typedef float *(*get_logits_t)(void);
/* Pointer to a no-argument function returning an untyped pointer.
 * NOTE(review): presumably used for the vocab accessors (offsets, strings,
 * merges), whose results are cast at the call site -- confirm against
 * ModelAPI. */
78 typedef void *(*get_ptr_t)(void);
117 const char *system_prefix;
118 const char *system_suffix;
119 const char *user_prefix;
120 const char *user_suffix;
121 const char *assistant_prefix;
122 const char *assistant_suffix;
128 .system_prefix =
"", .system_suffix =
"\n",
129 .user_prefix =
"", .user_suffix =
"\n",
130 .assistant_prefix =
"", .assistant_suffix =
"",
134 .system_prefix =
"<|im_start|>system\n",
135 .system_suffix =
"<|im_end|>\n",
136 .user_prefix =
"<|im_start|>user\n",
137 .user_suffix =
"<|im_end|>\n",
138 .assistant_prefix =
"<|im_start|>assistant\n",
139 .assistant_suffix =
"<|im_end|>",
143 .system_prefix =
"[INST] <<SYS>>\n",
144 .system_suffix =
"\n<</SYS>>\n\n",
146 .user_suffix =
" [/INST]",
147 .assistant_prefix =
" ",
148 .assistant_suffix =
" </s><s>[INST] ",
152 .system_prefix =
"<|im_start|>system\n",
153 .system_suffix =
"<|im_end|>\n",
154 .user_prefix =
"<|im_start|>user\n",
155 .user_suffix =
"<|im_end|>\n",
156 .assistant_prefix =
"<|im_start|>assistant\n",
157 .assistant_suffix =
"<|im_end|>",
162 .system_suffix =
"\n\n",
163 .user_prefix =
"[INST] ",
164 .user_suffix =
" [/INST]",
165 .assistant_prefix =
"",
166 .assistant_suffix =
"</s> ",
175 const char *model_name;
176 const char *lib_path;
177 const char *weights_path;
178 const char *prompt_once;
179 const char *system_prompt;
181 int context_override;
188 bool no_chat_template;
199 static char cache_path[4096];
200 const char *home = getenv(
"HOME");
201 if (!home) home =
"/tmp";
202 snprintf(cache_path,
sizeof(cache_path),
"%s/.cache/ck-engine-v6.6/models", home);
/*
 * Search the model cache directory for a usable model.
 *
 * Walks the entries of cache_dir, skipping names that begin with '.', and
 * considers every entry whose name contains model_name. The match is by
 * substring (strstr), not by exact name. A matching entry is accepted only
 * if both "<entry>/ck-kernel-inference.so" and "<entry>/weights.bump"
 * stat() successfully; their paths are then copied into lib_out and
 * weights_out.
 *
 * model_name   substring to look for in cache entry names
 * lib_out      caller-provided buffer that receives the .so path
 * weights_out  caller-provided buffer that receives the weights path
 * out_size     bound used for both copies; presumably the size of each
 *              output buffer -- confirm at the call site
 *
 * Returns false when the cache directory cannot be opened. The remaining
 * return statements are outside this excerpt; presumably true on the first
 * accepted entry and false when nothing matches -- confirm.
 */
206 static bool find_model_in_cache(
const char *model_name,
char *lib_out,
char *weights_out,
size_t out_size) {
208 DIR *dir = opendir(cache_dir);
209 if (!dir)
return false;
211 struct dirent *entry;
212 while ((entry = readdir(dir)) != NULL) {
/* Skip ".", ".." and any other dot-prefixed entry. */
213 if (entry->d_name[0] ==
'.')
continue;
216 if (strstr(entry->d_name, model_name) != NULL) {
217 char model_dir[4096];
218 snprintf(model_dir,
sizeof(model_dir),
"%s/%s", cache_dir, entry->d_name);
221 char so_path[4096], bump_path[4096];
/* The snprintf() return values are not checked, so an over-long path is
 * truncated silently and the stat() below then tests the truncated name. */
222 snprintf(so_path,
sizeof(so_path),
"%s/ck-kernel-inference.so", model_dir);
223 snprintf(bump_path,
sizeof(bump_path),
"%s/weights.bump", model_dir);
/* Both the compiled library and the weights file must exist. */
226 if (stat(so_path, &st) == 0 && stat(bump_path, &st) == 0) {
/* NOTE(review): strncpy() with a bound of out_size - 1 writes no terminator
 * when the source is that long or longer. Confirm that the lines following
 * this excerpt set lib_out[out_size - 1] and weights_out[out_size - 1]
 * to '\0'. */
227 strncpy(lib_out, so_path, out_size - 1);
228 strncpy(weights_out, bump_path, out_size - 1);
243 if (!weights_path || !opt)
return false;
246 char vocab_path[4096];
247 const char *slash = strrchr(weights_path,
'/');
248 if (!slash)
return false;
250 size_t dir_len = (size_t)(slash - weights_path);
251 if (dir_len + 12 >=
sizeof(vocab_path))
return false;
253 memcpy(vocab_path, weights_path, dir_len);
254 vocab_path[dir_len] =
'\0';
255 strcat(vocab_path,
"/vocab.json");
257 FILE *f = fopen(vocab_path,
"r");
258 if (!f)
return false;
262 size_t n = fread(buf, 1,
sizeof(buf) - 1, f);
267 const char *st = strstr(buf,
"\"special_tokens\"");
268 if (!st)
return false;
271 const char *
eos = strstr(st,
"\"eos\"");
273 const char *colon = strchr(
eos,
':');
275 int eos_id = atoi(colon + 1);
277 opt->eos_ids[0] = eos_id;
284 const char *
bos = strstr(st,
"\"bos\"");
286 const char *colon = strchr(
bos,
':');
288 int bos_id = atoi(colon + 1);
289 if (bos_id > 0 && bos_id != opt->eos_ids[0]) {
290 opt->eos_ids[opt->eos_count++] = bos_id;
295 return opt->eos_count > 0;
300 DIR *dir = opendir(cache_dir);
302 fprintf(stderr,
"No models found in %s\n", cache_dir);
306 printf(
"Available models in %s:\n", cache_dir);
307 struct dirent *entry;
309 while ((entry = readdir(dir)) != NULL) {
310 if (entry->d_name[0] ==
'.')
continue;
312 char model_dir[4096];
313 snprintf(model_dir,
sizeof(model_dir),
"%s/%s", cache_dir, entry->d_name);
316 snprintf(so_path,
sizeof(so_path),
"%s/ck-kernel-inference.so", model_dir);
319 if (stat(so_path, &st) == 0) {
320 printf(
" - %s\n", entry->d_name);
327 printf(
" (none found)\n");
336 if (temperature <= 0.0f || top_p <= 0.0f) {
339 float best_val = logits[0];
341 if (logits[i] > best_val) {
342 best_val = logits[i];
350 float max_logit = logits[0];
352 if (logits[i] > max_logit) max_logit = logits[i];
357 logits[i] = expf((logits[i] - max_logit) / temperature);
369 float threshold = (float)rand() / (float)RAND_MAX * top_p;
372 int *indices = (
int *)malloc(
vocab_size *
sizeof(
int));
373 float *probs = (
float *)malloc(
vocab_size *
sizeof(
float));
376 probs[i] = logits[i];
382 if (probs[j] > probs[i]) {
383 float tmp_p = probs[i]; probs[i] = probs[j]; probs[j] = tmp_p;
384 int tmp_i = indices[i]; indices[i] = indices[j]; indices[j] = tmp_i;
388 if (cumsum >= top_p)
break;
392 float r = (float)rand() / (float)RAND_MAX * cumsum;
394 int result = indices[0];
395 for (
int i = 0; cumsum > 0 && i <
vocab_size; i++) {
401 if (acc >= cumsum)
break;
429 if (!
token || max <= 0)
return 0;
431 const unsigned char *src = (
const unsigned char *)
token;
434 while (*src &&
out_len < max - 1) {
435 unsigned int codepoint;
439 if ((src[0] & 0x80) == 0) {
443 }
else if ((src[0] & 0xE0) == 0xC0 && (src[1] & 0xC0) == 0x80) {
445 codepoint = ((src[0] & 0x1F) << 6) | (src[1] & 0x3F);
447 }
else if ((src[0] & 0xF0) == 0xE0 && (src[1] & 0xC0) == 0x80 && (src[2] & 0xC0) == 0x80) {
449 codepoint = ((src[0] & 0x0F) << 12) | ((src[1] & 0x3F) << 6) | (src[2] & 0x3F);
451 }
else if ((src[0] & 0xF8) == 0xF0 && (src[1] & 0xC0) == 0x80 &&
452 (src[2] & 0xC0) == 0x80 && (src[3] & 0xC0) == 0x80) {
454 codepoint = ((src[0] & 0x07) << 18) | ((src[1] & 0x3F) << 12) |
455 ((src[2] & 0x3F) << 6) | (src[3] & 0x3F);
465 if (codepoint >= 0x100 && codepoint <= 0x120) {
467 out[
out_len++] = (char)(codepoint - 0x100);
468 }
else if (codepoint >= 0x17F && codepoint <= 0x1A0) {
470 out[
out_len++] = (char)(codepoint - 0x100);
471 }
else if (codepoint < 0x80) {
473 out[
out_len++] = (char)codepoint;
474 }
else if (codepoint == 0x2581) {
479 for (
int i = 0; i < bytes &&
out_len < max - 1; i++) {
492 if (*len == 0)
return;
493 fwrite(buf, 1, *len, stdout);
499 size_t n = strlen(
text);
504 fwrite(
text, 1, n, stdout);
507 memcpy(buf + *len,
text, n);
/*
 * Look up one exported symbol in a dlopen()ed library.
 *
 * handle    library handle, passed straight to dlsym()
 * name      symbol name to resolve
 * out_ptr   if non-NULL, receives the dlsym() result on the path shown
 *           (NULL when an optional symbol is absent)
 * required  when true, a missing symbol is reported on stderr
 *
 * The return statements are outside this excerpt. load_model_api() treats a
 * false result for a required symbol as a load failure and ignores the
 * result for optional ones, so presumably this returns false only when a
 * required symbol is missing -- confirm.
 */
526 static bool resolve_symbol(
void *handle,
const char *name,
void **out_ptr,
bool required) {
527 void *sym = dlsym(handle, name);
528 if (!sym && required) {
529 fprintf(stderr,
"Error: missing symbol %s\n", name);
532 if (out_ptr) *out_ptr = sym;
537 if (!lib_path || !api)
return false;
538 memset(api, 0,
sizeof(*api));
539 api->handle = dlopen(lib_path, RTLD_NOW);
541 fprintf(stderr,
"Error: dlopen failed: %s\n", dlerror());
545 if (!
resolve_symbol(api->handle,
"ck_model_init", (
void **)&api->init,
true))
return false;
546 if (!
resolve_symbol(api->handle,
"ck_model_embed_tokens", (
void **)&api->embed,
true))
return false;
547 if (!
resolve_symbol(api->handle,
"ck_model_forward", (
void **)&api->forward,
true))
return false;
548 if (!
resolve_symbol(api->handle,
"ck_model_decode", (
void **)&api->decode,
true))
return false;
549 resolve_symbol(api->handle,
"ck_model_sample_argmax", (
void **)&api->sample,
false);
550 resolve_symbol(api->handle,
"ck_model_get_logits", (
void **)&api->get_logits,
false);
551 resolve_symbol(api->handle,
"ck_model_get_logits_stride", (
void **)&api->get_logits_stride,
false);
552 resolve_symbol(api->handle,
"ck_model_kv_cache_enable", (
void **)&api->kv_enable,
false);
553 resolve_symbol(api->handle,
"ck_model_kv_cache_reset", (
void **)&api->kv_reset,
false);
554 resolve_symbol(api->handle,
"ck_model_get_context_window", (
void **)&api->get_context,
false);
555 resolve_symbol(api->handle,
"ck_model_get_vocab_size", (
void **)&api->get_vocab_size,
false);
556 resolve_symbol(api->handle,
"ck_model_get_num_merges", (
void **)&api->get_num_merges,
false);
557 resolve_symbol(api->handle,
"ck_model_get_vocab_strings_size", (
void **)&api->get_vocab_bytes,
false);
558 resolve_symbol(api->handle,
"ck_model_get_active_tokens", (
void **)&api->get_active_tokens,
false);
559 resolve_symbol(api->handle,
"ck_model_get_vocab_offsets", (
void **)&api->get_offsets,
false);
560 resolve_symbol(api->handle,
"ck_model_get_vocab_strings", (
void **)&api->get_strings,
false);
561 resolve_symbol(api->handle,
"ck_model_get_vocab_merges", (
void **)&api->get_merges,
false);
562 resolve_symbol(api->handle,
"ck_model_free", (
void **)&api->free_fn,
false);
564 if (!api->get_vocab_size || !api->get_offsets || !api->get_strings) {
565 fprintf(stderr,
"Error: vocab accessors missing from model\n");
580 strncpy(lower, model_name,
sizeof(lower) - 1);
581 for (
char *p = lower; *p; p++) *p = (*p >=
'A' && *p <=
'Z') ? *p + 32 : *p;
592 if (system && *system) {
593 needed += strlen(tmpl->system_prefix) + strlen(system) + strlen(tmpl->system_suffix);
595 needed += strlen(tmpl->user_prefix) + strlen(user) + strlen(tmpl->user_suffix);
596 needed += strlen(tmpl->assistant_prefix);
599 char *result = (
char *)malloc(needed);
600 if (!result)
return NULL;
603 if (system && *system) {
604 strcat(result, tmpl->system_prefix);
605 strcat(result, system);
606 strcat(result, tmpl->system_suffix);
608 strcat(result, tmpl->user_prefix);
609 strcat(result, user);
610 strcat(result, tmpl->user_suffix);
611 strcat(result, tmpl->assistant_prefix);
621 if (!opt || opt->ignore_eos)
return false;
622 for (
int i = 0; i < opt->eos_count; i++) {
623 if (opt->eos_ids[i] ==
token)
return true;
637 #define EOS_PATTERN_BUF_SIZE 64
638 #define EOS_PENDING_MAX 8
645 const char *target_pattern;
646 const char *partial_prefix;
654 for (
int i = 0; i <
g_eos_state.pending_count; i++) {
688 size_t tlen = strlen(
token);
693 if (target_len == 0)
return false;
699 memcpy(temp + plen,
token, tlen);
700 temp[plen + tlen] =
'\0';
704 size_t temp_len = plen + tlen;
707 for (
size_t i = 0; i < temp_len; i++) {
708 size_t remaining = temp_len - i;
709 if (remaining > target_len) remaining = target_len;
710 if (strncmp(temp + i, target, remaining) == 0) {
728 void (*output_fn)(
char*,
size_t*,
const char*),
732 if (token_text && output_fn) output_fn(out_buf,
out_len, token_text);
737 size_t tlen = strlen(token_text);
762 for (
int i = 0; i <
g_eos_state.pending_count; i++) {
771 if (output_fn) output_fn(out_buf,
out_len, token_text);
776 if (!arg || !opt)
return false;
781 long v = strtol(p, &
end, 10);
783 opt->eos_ids[opt->eos_count++] = (int)v;
787 return opt->eos_count > 0;
/*
 * Run one prompt through the model.
 *
 * Returns -1 immediately if api, tokenizer, opt or input is NULL. The other
 * return values are outside this excerpt.
 *
 * The opening section shown here picks a context size, reports prompt
 * formatting, buffer allocation and tokenizer failures on stderr, and
 * limits the prompt to ctx - max_tokens tokens.
 */
794 static int run_prompt(ModelAPI *api, CKTrueBPE *tokenizer, CLIOptions *opt,
const char *input) {
795 if (!api || !tokenizer || !opt || !input)
return -1;
/* Context size: the explicit override first, then the model's reported
 * window when that accessor was resolved, then a fallback of 4096. */
798 int ctx = opt->context_override;
799 if (ctx <= 0 && api->get_context) ctx = api->get_context();
800 if (ctx <= 0) ctx = 4096;
809 fprintf(stderr,
"Error: failed to format prompt\n");
814 printf(
"[DEBUG] Formatted prompt:\n%s\n", formatted);
/* Token id buffer with room for ctx ids. */
817 int32_t *
ids = (int32_t *)malloc((
size_t)ctx *
sizeof(int32_t));
819 fprintf(stderr,
"Error: failed to allocate token buffer\n");
828 fprintf(stderr,
"[Tokenizer] failed to encode prompt\n");
/* Keep the prompt within ctx - max_tokens tokens.
 * NOTE(review): when max_tokens >= ctx this sets n to zero or a negative
 * value. Where max_tokens comes from, and whether it is clamped against
 * ctx, is not visible in this excerpt -- confirm before n reaches embed(). */
832 if (n > ctx - max_tokens) {
833 n = ctx - max_tokens;
835 printf(
"[DEBUG] Truncated prompt to %d tokens\n", n);
844 if (api->kv_reset) api->kv_reset();
846 if (api->embed(
ids, n) != 0) {
847 fprintf(stderr,
"[Model] embed failed\n");
852 struct timespec t0, t1;
853 clock_gettime(CLOCK_MONOTONIC, &t0);
854 if (api->forward(NULL) != 0) {
855 fprintf(stderr,
"[Model] forward failed\n");
859 clock_gettime(CLOCK_MONOTONIC, &t1);
861 (t1.tv_nsec - t0.tv_nsec) / 1000000.0;
864 int vocab_size = api->get_vocab_size ? api->get_vocab_size() : 0;
867 #define SAMPLE_NEXT_TOKEN() do { \
868 if (api->get_logits && vocab_size > 0) { \
869 float *logits = api->get_logits(); \
871 int stride = api->get_logits_stride ? api->get_logits_stride() : vocab_size; \
872 int active = api->get_active_tokens ? api->get_active_tokens() : 1; \
873 float *last_logits = logits; \
875 if (active < 1) active = 1; \
876 last_logits = logits + (size_t)(active - 1) * (size_t)stride; \
878 float *logits_copy = (float *)malloc(vocab_size * sizeof(float)); \
879 memcpy(logits_copy, last_logits, vocab_size * sizeof(float)); \
880 next_token = sample_top_p(logits_copy, vocab_size, opt->temperature, opt->top_p); \
882 } else if (api->sample) { \
883 next_token = api->sample(); \
887 } else if (api->sample) { \
888 next_token = api->sample(); \
907 if (next_token < 0)
break;
911 fprintf(stderr,
"[DEBUG] Token %d: %d (%s)\n", generated, next_token, tok_str ? tok_str :
"NULL");
916 fprintf(stderr,
"[DEBUG] EOS detected (token ID), stopping\n");
924 if (!opt->ignore_eos &&
927 fprintf(stderr,
"[DEBUG] EOS detected (text pattern), stopping\n");
940 if (generated + 1 >= max_tokens)
break;
942 clock_gettime(CLOCK_MONOTONIC, &t0);
943 if (api->decode(next_token, NULL) != 0) {
944 fprintf(stderr,
"\n[Model] decode failed\n");
947 clock_gettime(CLOCK_MONOTONIC, &t1);
949 (t1.tv_nsec - t0.tv_nsec) / 1000000.0;
956 #undef SAMPLE_NEXT_TOKEN
969 printf(
"decode: %3d tok / %7.1f ms (%5.1f tok/s, %5.1f ms/tok)\033[0m\n",
984 printf(
" \033[1;36mC-Kernel-Engine v%s\033[0m\n",
CK_CLI_VERSION);
985 printf(
" Native inference CLI with true-BPE tokenization\n");
991 fprintf(stderr,
"Usage:\n");
992 fprintf(stderr,
" %s --model <name> Auto-discover model from cache\n", prog);
993 fprintf(stderr,
" %s <libmodel.so> <weights.bump> Direct paths\n", prog);
994 fprintf(stderr,
" %s --lib <.so> --weights <.bump> Named arguments\n", prog);
995 fprintf(stderr,
"\nOptions:\n");
996 fprintf(stderr,
" --model, -m NAME Model name (searches in cache)\n");
997 fprintf(stderr,
" --lib PATH Path to compiled model .so\n");
998 fprintf(stderr,
" --weights PATH Path to weights .bump file\n");
999 fprintf(stderr,
" --prompt, -p TEXT Run single prompt (non-interactive)\n");
1000 fprintf(stderr,
" --system, -S TEXT System prompt\n");
1002 fprintf(stderr,
" --context, -c N Override context/KV cache size\n");
1003 fprintf(stderr,
" --temperature, -T F Sampling temperature (default: 0.0 = greedy)\n");
1004 fprintf(stderr,
" --top-p F Nucleus sampling top-p (default: 0.9)\n");
1005 fprintf(stderr,
" --stream, -s Stream tokens as generated\n");
1006 fprintf(stderr,
" --timing, -t Show timing breakdown\n");
1007 fprintf(stderr,
" --no-chat-template Disable chat template formatting\n");
1008 fprintf(stderr,
" --eos IDS Comma-separated EOS token IDs\n");
1009 fprintf(stderr,
" --ignore-eos Do not stop on EOS tokens\n");
1010 fprintf(stderr,
" --list List available models\n");
1011 fprintf(stderr,
" --verbose, -v Verbose output\n");
1012 fprintf(stderr,
" --help, -h Show this help\n");
1013 fprintf(stderr,
"\nREPL Commands:\n");
1014 fprintf(stderr,
" /exit, /quit Exit the REPL\n");
1015 fprintf(stderr,
" /reset Reset KV cache\n");
1016 fprintf(stderr,
" /timing Toggle timing display\n");
1017 fprintf(stderr,
" /temp <value> Set temperature\n");
1018 fprintf(stderr,
" /system <text> Set system prompt\n");
1019 fprintf(stderr,
" /help Show help\n");
1023 if (!opt)
return false;
1024 memset(opt, 0,
sizeof(*opt));
1026 opt->temperature = 0.0f;
1031 opt->eos_ids[0] = 151643;
1032 opt->eos_ids[1] = 151645;
1033 opt->eos_ids[2] = 151644;
1036 for (
int i = 1; i < argc; i++) {
1037 const char *arg = argv[i];
1039 if (!strcmp(arg,
"--help") || !strcmp(arg,
"-h")) {
1042 }
else if (!strcmp(arg,
"--list")) {
1045 }
else if ((!strcmp(arg,
"--model") || !strcmp(arg,
"-m")) && i + 1 < argc) {
1046 opt->model_name = argv[++i];
1047 }
else if (!strcmp(arg,
"--lib") && i + 1 < argc) {
1048 opt->lib_path = argv[++i];
1049 }
else if (!strcmp(arg,
"--weights") && i + 1 < argc) {
1050 opt->weights_path = argv[++i];
1051 }
else if ((!strcmp(arg,
"--prompt") || !strcmp(arg,
"-p")) && i + 1 < argc) {
1052 opt->prompt_once = argv[++i];
1053 }
else if ((!strcmp(arg,
"--system") || !strcmp(arg,
"-S")) && i + 1 < argc) {
1054 opt->system_prompt = argv[++i];
1055 }
else if ((!strcmp(arg,
"--max-tokens") || !strcmp(arg,
"-n")) && i + 1 < argc) {
1056 opt->max_tokens = atoi(argv[++i]);
1057 }
else if ((!strcmp(arg,
"--context") || !strcmp(arg,
"-c")) && i + 1 < argc) {
1058 opt->context_override = atoi(argv[++i]);
1059 }
else if ((!strcmp(arg,
"--temperature") || !strcmp(arg,
"-T")) && i + 1 < argc) {
1060 opt->temperature = (float)atof(argv[++i]);
1061 }
else if (!strcmp(arg,
"--top-p") && i + 1 < argc) {
1062 opt->top_p = (float)atof(argv[++i]);
1063 }
else if (!strcmp(arg,
"--stream") || !strcmp(arg,
"-s")) {
1065 }
else if (!strcmp(arg,
"--no-stream")) {
1066 opt->stream =
false;
1067 }
else if (!strcmp(arg,
"--timing") || !strcmp(arg,
"-t")) {
1069 }
else if (!strcmp(arg,
"--no-timing")) {
1070 opt->timing =
false;
1071 }
else if (!strcmp(arg,
"--no-chat-template")) {
1072 opt->no_chat_template =
true;
1073 }
else if (!strcmp(arg,
"--eos") && i + 1 < argc) {
1075 }
else if (!strcmp(arg,
"--ignore-eos")) {
1076 opt->ignore_eos =
true;
1077 }
else if (!strcmp(arg,
"--verbose") || !strcmp(arg,
"-v")) {
1078 opt->verbose =
true;
1079 }
else if (arg[0] !=
'-') {
1080 if (!opt->lib_path) opt->lib_path = arg;
1081 else if (!opt->weights_path) opt->weights_path = arg;
1083 fprintf(stderr,
"Unknown argument: %s\n", arg);
1087 fprintf(stderr,
"Unknown option: %s\n", arg);
1093 if (opt->model_name && (!opt->lib_path || !opt->weights_path)) {
1094 static char lib_buf[4096], weights_buf[4096];
1096 opt->lib_path = lib_buf;
1097 opt->weights_path = weights_buf;
1099 fprintf(stderr,
"Error: model '%s' not found in cache\n", opt->model_name);
1100 fprintf(stderr,
"Run with --list to see available models\n");
1105 if (!opt->lib_path || !opt->weights_path) {
1111 const char *name_for_template = opt->model_name ? opt->model_name : opt->lib_path;
1117 printf(
"[DEBUG] Loaded %d EOS tokens: ", opt->eos_count);
1118 for (
int i = 0; i < opt->eos_count; i++) {
1119 printf(
"%d ", opt->eos_ids[i]);
1133 if (!line || line[0] !=
'/')
return false;
1135 if (!strncmp(line,
"/exit", 5) || !strncmp(line,
"/quit", 5)) {
1139 if (!strncmp(line,
"/help", 5)) {
1140 printf(
"REPL Commands:\n");
1141 printf(
" /exit, /quit Exit\n");
1142 printf(
" /reset Reset KV cache\n");
1143 printf(
" /timing Toggle timing display\n");
1144 printf(
" /temp <value> Set temperature (0 = greedy)\n");
1145 printf(
" /top-p <value> Set top-p\n");
1146 printf(
" /system <text> Set system prompt\n");
1147 printf(
" /clear Clear system prompt\n");
1148 printf(
" /verbose Toggle verbose mode\n");
1151 if (!strncmp(line,
"/reset", 6)) {
1152 if (api->kv_reset) {
1154 printf(
"[KV cache reset]\n");
1158 if (!strncmp(line,
"/timing", 7)) {
1159 opt->timing = !opt->timing;
1160 printf(
"[Timing %s]\n", opt->timing ?
"enabled" :
"disabled");
1163 if (!strncmp(line,
"/verbose", 8)) {
1164 opt->verbose = !opt->verbose;
1165 printf(
"[Verbose %s]\n", opt->verbose ?
"enabled" :
"disabled");
1168 if (!strncmp(line,
"/temp ", 6)) {
1169 opt->temperature = (float)atof(line + 6);
1170 printf(
"[Temperature set to %.2f]\n", opt->temperature);
1173 if (!strncmp(line,
"/top-p ", 7)) {
1174 opt->top_p = (float)atof(line + 7);
1175 printf(
"[Top-p set to %.2f]\n", opt->top_p);
/* "/system <text>": store a heap copy of the text that follows the command.
 * NOTE(review): the strdup() result is not checked for NULL, and the
 * previous system_prompt is not released before being overwritten. That
 * pointer may refer to an argv string (parse_args() assigns argv[++i] to
 * it), so it cannot simply be free()d without tracking ownership. */
1178 if (!strncmp(line,
"/system ", 8)) {
1179 opt->system_prompt = strdup(line + 8);
1180 printf(
"[System prompt set]\n");
/* "/clear": drop the system prompt.
 * NOTE(review): a prompt allocated by "/system" is not freed before the
 * pointer is set to NULL. */
1183 if (!strncmp(line,
"/clear", 6)) {
1184 opt->system_prompt = NULL;
1185 printf(
"[System prompt cleared]\n");
1189 printf(
"Unknown command: %s\n", line);
1199 srand((
unsigned int)time(NULL));
1207 printf(
"Loading: %s\n", opt.lib_path);
1214 printf(
"Initializing model...\n");
1215 if (api.init(opt.weights_path) != 0) {
1216 fprintf(stderr,
"Error: ck_model_init failed\n");
1220 int ctx = opt.context_override;
1221 if (ctx <= 0 && api.get_context) ctx = api.get_context();
1222 if (api.kv_enable && ctx > 0) {
1228 fprintf(stderr,
"[Tokenizer] failed to create\n");
1232 int vocab_size = api.get_vocab_size ? api.get_vocab_size() : 0;
1233 int vocab_bytes = api.get_vocab_bytes ? api.get_vocab_bytes() : 0;
1234 int num_merges = api.get_num_merges ? api.get_num_merges() : 0;
1235 const int32_t *
offsets = (
const int32_t *)api.get_offsets();
1236 const char *
strings = (
const char *)api.get_strings();
1237 const int32_t *
merges = api.get_merges ? (
const int32_t *)api.get_merges() : NULL;
1240 fprintf(stderr,
"[Tokenizer] missing vocab data in model\n");
1246 fprintf(stderr,
"[Tokenizer] failed to load vocab\n");
1258 static const char *special_tokens[] = {
1260 "<|im_start|>",
"<|im_end|>",
"<|endoftext|>",
1262 "<|eot_id|>",
"<|begin_of_text|>",
"<|end_of_text|>",
1263 "<|start_header_id|>",
"<|end_header_id|>",
1265 "</s>",
"<s>",
"<pad>",
"<unk>",
1269 for (
int i = 0; special_tokens[i] != NULL; i++) {
1273 if (check && strcmp(check, special_tokens[i]) == 0) {
1277 printf(
"[Tokenizer] Registered special: %s -> %d\n", special_tokens[i],
id);
1282 printf(
"[Tokenizer] Registered %d special tokens for pre-BPE matching\n", registered);
1286 printf(
"Ready! Vocab: %d, Context: %d, Template: %s\n",
1288 opt.no_chat_template ?
"none" :
1295 printf(
"[Hardware] %s | Vector: %d-bit | FMA: %s | AI Accel: %s | Kernel: %s\n",
1299 printf(
"Type /help for commands, Ctrl+C to stop generation\n\n");
1301 setvbuf(stdout, NULL, _IOFBF, 1 << 20);
1303 if (opt.prompt_once) {
1304 run_prompt(&api, tokenizer, &opt, opt.prompt_once);
1307 #ifdef HAVE_READLINE
1308 char *home = getenv(
"HOME");
1309 char history_path[4096];
1312 read_history(history_path);
1317 #ifdef HAVE_READLINE
1318 char *line = readline(
"\033[1;32mYou:\033[0m ");
1320 if (*line) add_history(line);
1322 printf(
"\033[1;32mYou:\033[0m ");
1324 char line_buf[4096];
1325 if (!fgets(line_buf,
sizeof(line_buf), stdin)) {
1327 if (errno == EINTR)
break;
1331 size_t len = strlen(line_buf);
1332 if (len > 0 && line_buf[len-1] ==
'\n') line_buf[len-1] =
'\0';
1333 char *line = line_buf;
1336 if (line[0] ==
'\0') {
1337 #ifdef HAVE_READLINE
1343 if (line[0] ==
'/') {
1345 #ifdef HAVE_READLINE
1351 printf(
"\033[1;34mAssistant:\033[0m ");
1355 #ifdef HAVE_READLINE
1360 #ifdef HAVE_READLINE
1362 write_history(history_path);
1368 if (api.free_fn) api.free_fn();
1369 if (api.handle) dlclose(api.handle);
1371 printf(
"\nGoodbye!\n");
void *(* get_ptr_t)(void)
static bool parse_eos_ids(const char *arg, CLIOptions *opt)
int(* init_t)(const char *weights_path)
static bool resolve_symbol(void *handle, const char *name, void **out_ptr, bool required)
static double g_decode_time_ms
static void handle_sigint(int sig)
int(* embed_t)(const int32_t *tokens, int num_tokens)
static char * apply_chat_template(const ChatTemplate *tmpl, const char *system, const char *user)
int(* kv_enable_t)(int capacity)
static int sample_top_p(float *logits, int vocab_size, float temperature, float top_p)
static bool load_eos_from_vocab_json(const char *weights_path, CLIOptions *opt)
static double g_prefill_time_ms
static bool find_model_in_cache(const char *model_name, char *lib_out, char *weights_out, size_t out_size)
int main(int argc, char **argv)
static void print_help(const char *prog)
static ChatTemplateType detect_chat_template(const char *model_name)
#define CK_CLI_HISTORY_FILE
static int g_decode_count
static bool process_repl_command(const char *line, CLIOptions *opt, ModelAPI *api)
static bool is_eos_token(const CLIOptions *opt, int token)
static bool eos_is_potential_prefix(const char *token)
int(* forward_t)(float *logits_out)
float *(* get_logits_t)(void)
static void eos_pattern_init(ChatTemplateType tmpl)
static void output_append(char *buf, size_t *len, const char *text)
static void list_available_models(void)
static volatile sig_atomic_t g_generation_active
#define CK_CLI_MAX_CONTEXT
static int decode_bpe_token(const char *token, char *out, int max)
int(* decode_t)(int32_t token, float *logits_out)
#define EOS_PATTERN_BUF_SIZE
static void print_banner(void)
static bool parse_args(int argc, char **argv, CLIOptions *opt)
static int run_prompt(ModelAPI *api, CKTrueBPE *tokenizer, CLIOptions *opt, const char *input)
static void eos_pattern_reset(void)
static volatile sig_atomic_t g_exit_requested
static EOSPatternState g_eos_state
#define CK_CLI_OUTPUT_BUF_SIZE
#define CK_CLI_DEFAULT_MAX_TOKENS
static void output_flush(char *buf, size_t *len)
static const ChatTemplate g_templates[]
static const char * get_cache_dir(void)
static bool load_model_api(const char *lib_path, ModelAPI *api)
#define SAMPLE_NEXT_TOKEN()
static void output_token(char *buf, size_t *len, const char *token)
static bool eos_pattern_process(const char *token_text, char *out_buf, size_t *out_len, void(*output_fn)(char *, size_t *, const char *), ChatTemplateType tmpl)
int(* sample_argmax_t)(void)
static int g_prompt_tokens
CPU feature detection and dispatch macros.
static ck_capability_t ck_get_capabilities(void)
Get current platform capabilities.
CPU capability information structure.
int32_t int32_t int32_t eos
const int32_t int int * out_len
int ck_true_bpe_encode(CKTrueBPE *bpe, const char *text, int text_len, int32_t *ids, int max_ids)
void ck_true_bpe_free(CKTrueBPE *bpe)
CKTrueBPE * ck_true_bpe_create(void)
int ck_true_bpe_add_special_token(CKTrueBPE *bpe, const char *token, int32_t id)
int ck_true_bpe_load_binary(CKTrueBPE *bpe, int vocab_size, const int32_t *offsets, const char *strings, int num_merges, const int32_t *merges)
int32_t ck_true_bpe_lookup(const CKTrueBPE *bpe, const char *token)
const char * ck_true_bpe_id_to_token(const CKTrueBPE *bpe, int32_t id)
int const int32_t const char int num_merges
int const int32_t const char * strings
int const int32_t const char int const int32_t * merges
int const int32_t * offsets