33 #include <readline/readline.h>
34 #include <readline/history.h>
40 #define CK_CLI_VERSION "6.5.0"
41 #define CK_CLI_DEFAULT_MAX_TOKENS 256
42 #define CK_CLI_EOS_MAX 8
43 #define CK_CLI_OUTPUT_BUF_SIZE 4096
44 #define CK_CLI_MAX_CONTEXT 32768
45 #define CK_CLI_HISTORY_FILE ".ck_cli_history"
69 typedef int (*
init_t)(
const char *weights_path);
70 typedef int (*
embed_t)(
const int32_t *tokens,
int num_tokens);
76 typedef float *(*get_logits_t)(void);
78 typedef void *(*get_ptr_t)(void);
116 const char *system_prefix;
117 const char *system_suffix;
118 const char *user_prefix;
119 const char *user_suffix;
120 const char *assistant_prefix;
121 const char *assistant_suffix;
127 .system_prefix =
"", .system_suffix =
"\n",
128 .user_prefix =
"", .user_suffix =
"\n",
129 .assistant_prefix =
"", .assistant_suffix =
"",
133 .system_prefix =
"<|im_start|>system\n",
134 .system_suffix =
"<|im_end|>\n",
135 .user_prefix =
"<|im_start|>user\n",
136 .user_suffix =
"<|im_end|>\n",
137 .assistant_prefix =
"<|im_start|>assistant\n",
138 .assistant_suffix =
"<|im_end|>",
142 .system_prefix =
"[INST] <<SYS>>\n",
143 .system_suffix =
"\n<</SYS>>\n\n",
145 .user_suffix =
" [/INST]",
146 .assistant_prefix =
" ",
147 .assistant_suffix =
" </s><s>[INST] ",
151 .system_prefix =
"<|im_start|>system\n",
152 .system_suffix =
"<|im_end|>\n",
153 .user_prefix =
"<|im_start|>user\n",
154 .user_suffix =
"<|im_end|>\n",
155 .assistant_prefix =
"<|im_start|>assistant\n",
156 .assistant_suffix =
"<|im_end|>",
161 .system_suffix =
"\n\n",
162 .user_prefix =
"[INST] ",
163 .user_suffix =
" [/INST]",
164 .assistant_prefix =
"",
165 .assistant_suffix =
"</s> ",
174 const char *model_name;
175 const char *lib_path;
176 const char *weights_path;
177 const char *prompt_once;
178 const char *system_prompt;
180 int context_override;
187 bool no_chat_template;
198 static char cache_path[4096];
199 const char *home = getenv(
"HOME");
200 if (!home) home =
"/tmp";
201 snprintf(cache_path,
sizeof(cache_path),
"%s/.cache/ck-engine-v6.5/models", home);
205 static bool find_model_in_cache(
const char *model_name,
char *lib_out,
char *weights_out,
size_t out_size) {
207 DIR *dir = opendir(cache_dir);
208 if (!dir)
return false;
210 struct dirent *entry;
211 while ((entry = readdir(dir)) != NULL) {
212 if (entry->d_name[0] ==
'.')
continue;
215 if (strstr(entry->d_name, model_name) != NULL) {
216 char model_dir[4096];
217 snprintf(model_dir,
sizeof(model_dir),
"%s/%s", cache_dir, entry->d_name);
220 char so_path[4096], bump_path[4096];
221 snprintf(so_path,
sizeof(so_path),
"%s/ck-kernel-inference.so", model_dir);
222 snprintf(bump_path,
sizeof(bump_path),
"%s/weights.bump", model_dir);
225 if (stat(so_path, &st) == 0 && stat(bump_path, &st) == 0) {
226 strncpy(lib_out, so_path, out_size - 1);
227 strncpy(weights_out, bump_path, out_size - 1);
242 if (!weights_path || !opt)
return false;
245 char vocab_path[4096];
246 const char *slash = strrchr(weights_path,
'/');
247 if (!slash)
return false;
249 size_t dir_len = (size_t)(slash - weights_path);
250 if (dir_len + 12 >=
sizeof(vocab_path))
return false;
252 memcpy(vocab_path, weights_path, dir_len);
253 vocab_path[dir_len] =
'\0';
254 strcat(vocab_path,
"/vocab.json");
256 FILE *f = fopen(vocab_path,
"r");
257 if (!f)
return false;
261 size_t n = fread(buf, 1,
sizeof(buf) - 1, f);
266 const char *st = strstr(buf,
"\"special_tokens\"");
267 if (!st)
return false;
270 const char *
eos = strstr(st,
"\"eos\"");
272 const char *colon = strchr(
eos,
':');
274 int eos_id = atoi(colon + 1);
276 opt->eos_ids[0] = eos_id;
283 const char *
bos = strstr(st,
"\"bos\"");
285 const char *colon = strchr(
bos,
':');
287 int bos_id = atoi(colon + 1);
288 if (bos_id > 0 && bos_id != opt->eos_ids[0]) {
289 opt->eos_ids[opt->eos_count++] = bos_id;
294 return opt->eos_count > 0;
299 DIR *dir = opendir(cache_dir);
301 fprintf(stderr,
"No models found in %s\n", cache_dir);
305 printf(
"Available models in %s:\n", cache_dir);
306 struct dirent *entry;
308 while ((entry = readdir(dir)) != NULL) {
309 if (entry->d_name[0] ==
'.')
continue;
311 char model_dir[4096];
312 snprintf(model_dir,
sizeof(model_dir),
"%s/%s", cache_dir, entry->d_name);
315 snprintf(so_path,
sizeof(so_path),
"%s/ck-kernel-inference.so", model_dir);
318 if (stat(so_path, &st) == 0) {
319 printf(
" - %s\n", entry->d_name);
326 printf(
" (none found)\n");
335 if (temperature <= 0.0f || top_p <= 0.0f) {
338 float best_val = logits[0];
340 if (logits[i] > best_val) {
341 best_val = logits[i];
349 float max_logit = logits[0];
351 if (logits[i] > max_logit) max_logit = logits[i];
356 logits[i] = expf((logits[i] - max_logit) / temperature);
368 float threshold = (float)rand() / (float)RAND_MAX * top_p;
371 int *indices = (
int *)malloc(
vocab_size *
sizeof(
int));
372 float *probs = (
float *)malloc(
vocab_size *
sizeof(
float));
375 probs[i] = logits[i];
381 if (probs[j] > probs[i]) {
382 float tmp_p = probs[i]; probs[i] = probs[j]; probs[j] = tmp_p;
383 int tmp_i = indices[i]; indices[i] = indices[j]; indices[j] = tmp_i;
387 if (cumsum >= top_p)
break;
391 float r = (float)rand() / (float)RAND_MAX * cumsum;
393 int result = indices[0];
394 for (
int i = 0; cumsum > 0 && i <
vocab_size; i++) {
400 if (acc >= cumsum)
break;
428 if (!
token || max <= 0)
return 0;
430 const unsigned char *src = (
const unsigned char *)
token;
433 while (*src &&
out_len < max - 1) {
434 unsigned int codepoint;
438 if ((src[0] & 0x80) == 0) {
442 }
else if ((src[0] & 0xE0) == 0xC0 && (src[1] & 0xC0) == 0x80) {
444 codepoint = ((src[0] & 0x1F) << 6) | (src[1] & 0x3F);
446 }
else if ((src[0] & 0xF0) == 0xE0 && (src[1] & 0xC0) == 0x80 && (src[2] & 0xC0) == 0x80) {
448 codepoint = ((src[0] & 0x0F) << 12) | ((src[1] & 0x3F) << 6) | (src[2] & 0x3F);
450 }
else if ((src[0] & 0xF8) == 0xF0 && (src[1] & 0xC0) == 0x80 &&
451 (src[2] & 0xC0) == 0x80 && (src[3] & 0xC0) == 0x80) {
453 codepoint = ((src[0] & 0x07) << 18) | ((src[1] & 0x3F) << 12) |
454 ((src[2] & 0x3F) << 6) | (src[3] & 0x3F);
464 if (codepoint >= 0x100 && codepoint <= 0x120) {
466 out[
out_len++] = (char)(codepoint - 0x100);
467 }
else if (codepoint >= 0x17F && codepoint <= 0x1A0) {
469 out[
out_len++] = (char)(codepoint - 0x100);
470 }
else if (codepoint < 0x80) {
472 out[
out_len++] = (char)codepoint;
473 }
else if (codepoint == 0x2581) {
478 for (
int i = 0; i < bytes &&
out_len < max - 1; i++) {
491 if (*len == 0)
return;
492 fwrite(buf, 1, *len, stdout);
498 size_t n = strlen(
text);
503 fwrite(
text, 1, n, stdout);
506 memcpy(buf + *len,
text, n);
525 static bool resolve_symbol(
void *handle,
const char *name,
void **out_ptr,
bool required) {
526 void *sym = dlsym(handle, name);
527 if (!sym && required) {
528 fprintf(stderr,
"Error: missing symbol %s\n", name);
531 if (out_ptr) *out_ptr = sym;
536 if (!lib_path || !api)
return false;
537 memset(api, 0,
sizeof(*api));
538 api->handle = dlopen(lib_path, RTLD_NOW);
540 fprintf(stderr,
"Error: dlopen failed: %s\n", dlerror());
544 if (!
resolve_symbol(api->handle,
"ck_model_init", (
void **)&api->init,
true))
return false;
545 if (!
resolve_symbol(api->handle,
"ck_model_embed_tokens", (
void **)&api->embed,
true))
return false;
546 if (!
resolve_symbol(api->handle,
"ck_model_forward", (
void **)&api->forward,
true))
return false;
547 if (!
resolve_symbol(api->handle,
"ck_model_decode", (
void **)&api->decode,
true))
return false;
548 if (!
resolve_symbol(api->handle,
"ck_model_sample_argmax", (
void **)&api->sample,
true))
return false;
549 resolve_symbol(api->handle,
"ck_model_get_logits", (
void **)&api->get_logits,
false);
550 resolve_symbol(api->handle,
"ck_model_kv_cache_enable", (
void **)&api->kv_enable,
false);
551 resolve_symbol(api->handle,
"ck_model_kv_cache_reset", (
void **)&api->kv_reset,
false);
552 resolve_symbol(api->handle,
"ck_model_get_context_window", (
void **)&api->get_context,
false);
553 resolve_symbol(api->handle,
"ck_model_get_vocab_size", (
void **)&api->get_vocab_size,
false);
554 resolve_symbol(api->handle,
"ck_model_get_num_merges", (
void **)&api->get_num_merges,
false);
555 resolve_symbol(api->handle,
"ck_model_get_vocab_strings_size", (
void **)&api->get_vocab_bytes,
false);
556 resolve_symbol(api->handle,
"ck_model_get_active_tokens", (
void **)&api->get_active_tokens,
false);
557 resolve_symbol(api->handle,
"ck_model_get_vocab_offsets", (
void **)&api->get_offsets,
false);
558 resolve_symbol(api->handle,
"ck_model_get_vocab_strings", (
void **)&api->get_strings,
false);
559 resolve_symbol(api->handle,
"ck_model_get_vocab_merges", (
void **)&api->get_merges,
false);
560 resolve_symbol(api->handle,
"ck_model_free", (
void **)&api->free_fn,
false);
562 if (!api->get_vocab_size || !api->get_vocab_bytes || !api->get_offsets || !api->get_strings) {
563 fprintf(stderr,
"Error: vocab accessors missing from model\n");
578 strncpy(lower, model_name,
sizeof(lower) - 1);
579 for (
char *p = lower; *p; p++) *p = (*p >=
'A' && *p <=
'Z') ? *p + 32 : *p;
590 if (system && *system) {
591 needed += strlen(tmpl->system_prefix) + strlen(system) + strlen(tmpl->system_suffix);
593 needed += strlen(tmpl->user_prefix) + strlen(user) + strlen(tmpl->user_suffix);
594 needed += strlen(tmpl->assistant_prefix);
597 char *result = (
char *)malloc(needed);
598 if (!result)
return NULL;
601 if (system && *system) {
602 strcat(result, tmpl->system_prefix);
603 strcat(result, system);
604 strcat(result, tmpl->system_suffix);
606 strcat(result, tmpl->user_prefix);
607 strcat(result, user);
608 strcat(result, tmpl->user_suffix);
609 strcat(result, tmpl->assistant_prefix);
619 if (!opt || opt->ignore_eos)
return false;
620 for (
int i = 0; i < opt->eos_count; i++) {
621 if (opt->eos_ids[i] ==
token)
return true;
635 #define EOS_PATTERN_BUF_SIZE 64
636 #define EOS_PENDING_MAX 8
643 const char *target_pattern;
644 const char *partial_prefix;
652 for (
int i = 0; i <
g_eos_state.pending_count; i++) {
686 size_t tlen = strlen(
token);
691 if (target_len == 0)
return false;
697 memcpy(temp + plen,
token, tlen);
698 temp[plen + tlen] =
'\0';
702 size_t temp_len = plen + tlen;
705 for (
size_t i = 0; i < temp_len; i++) {
706 size_t remaining = temp_len - i;
707 if (remaining > target_len) remaining = target_len;
708 if (strncmp(temp + i, target, remaining) == 0) {
726 void (*output_fn)(
char*,
size_t*,
const char*),
730 if (token_text && output_fn) output_fn(out_buf,
out_len, token_text);
735 size_t tlen = strlen(token_text);
760 for (
int i = 0; i <
g_eos_state.pending_count; i++) {
769 if (output_fn) output_fn(out_buf,
out_len, token_text);
774 if (!arg || !opt)
return false;
779 long v = strtol(p, &
end, 10);
781 opt->eos_ids[opt->eos_count++] = (int)v;
785 return opt->eos_count > 0;
792 static int run_prompt(ModelAPI *api, CKTrueBPE *tokenizer, CLIOptions *opt,
const char *input) {
793 if (!api || !tokenizer || !opt || !input)
return -1;
796 int ctx = opt->context_override;
797 if (ctx <= 0 && api->get_context) ctx = api->get_context();
798 if (ctx <= 0) ctx = 4096;
807 fprintf(stderr,
"Error: failed to format prompt\n");
812 printf(
"[DEBUG] Formatted prompt:\n%s\n", formatted);
815 int32_t *
ids = (int32_t *)malloc((
size_t)ctx *
sizeof(int32_t));
817 fprintf(stderr,
"Error: failed to allocate token buffer\n");
826 fprintf(stderr,
"[Tokenizer] failed to encode prompt\n");
830 if (n > ctx - max_tokens) {
831 n = ctx - max_tokens;
833 printf(
"[DEBUG] Truncated prompt to %d tokens\n", n);
842 if (api->kv_reset) api->kv_reset();
844 if (api->embed(
ids, n) != 0) {
845 fprintf(stderr,
"[Model] embed failed\n");
850 struct timespec t0, t1;
851 clock_gettime(CLOCK_MONOTONIC, &t0);
852 if (api->forward(NULL) != 0) {
853 fprintf(stderr,
"[Model] forward failed\n");
857 clock_gettime(CLOCK_MONOTONIC, &t1);
859 (t1.tv_nsec - t0.tv_nsec) / 1000000.0;
862 int vocab_size = api->get_vocab_size ? api->get_vocab_size() : 0;
866 if (opt->temperature > 0.0f && api->get_logits &&
vocab_size > 0) {
867 float *logits = api->get_logits();
870 int active = api->get_active_tokens ? api->get_active_tokens() : 1;
871 float *last_logits = logits + (size_t)(active - 1) *
vocab_size;
873 float *logits_copy = (
float *)malloc(
vocab_size *
sizeof(
float));
874 memcpy(logits_copy, last_logits,
vocab_size *
sizeof(
float));
878 next_token = api->sample();
881 next_token = api->sample();
893 if (next_token < 0)
break;
897 fprintf(stderr,
"[DEBUG] Token %d: %d (%s)\n", generated, next_token, tok_str ? tok_str :
"NULL");
902 fprintf(stderr,
"[DEBUG] EOS detected (token ID), stopping\n");
910 if (!opt->ignore_eos &&
913 fprintf(stderr,
"[DEBUG] EOS detected (text pattern), stopping\n");
926 if (generated + 1 >= max_tokens)
break;
928 clock_gettime(CLOCK_MONOTONIC, &t0);
929 if (api->decode(next_token, NULL) != 0) {
930 fprintf(stderr,
"\n[Model] decode failed\n");
933 clock_gettime(CLOCK_MONOTONIC, &t1);
935 (t1.tv_nsec - t0.tv_nsec) / 1000000.0;
939 if (opt->temperature > 0.0f && api->get_logits &&
vocab_size > 0) {
940 float *logits = api->get_logits();
942 int active = api->get_active_tokens ? api->get_active_tokens() : 1;
943 float *last_logits = logits + (size_t)(active - 1) *
vocab_size;
944 float *logits_copy = (
float *)malloc(
vocab_size *
sizeof(
float));
945 memcpy(logits_copy, last_logits,
vocab_size *
sizeof(
float));
949 next_token = api->sample();
952 next_token = api->sample();
968 printf(
"decode: %3d tok / %7.1f ms (%5.1f tok/s, %5.1f ms/tok)\033[0m\n",
983 printf(
" \033[1;36mC-Kernel-Engine v%s\033[0m\n",
CK_CLI_VERSION);
984 printf(
" Native inference CLI with true-BPE tokenization\n");
990 fprintf(stderr,
"Usage:\n");
991 fprintf(stderr,
" %s --model <name> Auto-discover model from cache\n", prog);
992 fprintf(stderr,
" %s <libmodel.so> <weights.bump> Direct paths\n", prog);
993 fprintf(stderr,
" %s --lib <.so> --weights <.bump> Named arguments\n", prog);
994 fprintf(stderr,
"\nOptions:\n");
995 fprintf(stderr,
" --model, -m NAME Model name (searches in cache)\n");
996 fprintf(stderr,
" --lib PATH Path to compiled model .so\n");
997 fprintf(stderr,
" --weights PATH Path to weights .bump file\n");
998 fprintf(stderr,
" --prompt, -p TEXT Run single prompt (non-interactive)\n");
999 fprintf(stderr,
" --system, -S TEXT System prompt\n");
1001 fprintf(stderr,
" --context, -c N Override context/KV cache size\n");
1002 fprintf(stderr,
" --temperature, -T F Sampling temperature (default: 0.0 = greedy)\n");
1003 fprintf(stderr,
" --top-p F Nucleus sampling top-p (default: 0.9)\n");
1004 fprintf(stderr,
" --stream, -s Stream tokens as generated\n");
1005 fprintf(stderr,
" --timing, -t Show timing breakdown\n");
1006 fprintf(stderr,
" --no-chat-template Disable chat template formatting\n");
1007 fprintf(stderr,
" --eos IDS Comma-separated EOS token IDs\n");
1008 fprintf(stderr,
" --ignore-eos Do not stop on EOS tokens\n");
1009 fprintf(stderr,
" --list List available models\n");
1010 fprintf(stderr,
" --verbose, -v Verbose output\n");
1011 fprintf(stderr,
" --help, -h Show this help\n");
1012 fprintf(stderr,
"\nREPL Commands:\n");
1013 fprintf(stderr,
" /exit, /quit Exit the REPL\n");
1014 fprintf(stderr,
" /reset Reset KV cache\n");
1015 fprintf(stderr,
" /timing Toggle timing display\n");
1016 fprintf(stderr,
" /temp <value> Set temperature\n");
1017 fprintf(stderr,
" /system <text> Set system prompt\n");
1018 fprintf(stderr,
" /help Show help\n");
1022 if (!opt)
return false;
1023 memset(opt, 0,
sizeof(*opt));
1025 opt->temperature = 0.0f;
1030 opt->eos_ids[0] = 151643;
1031 opt->eos_ids[1] = 151645;
1032 opt->eos_ids[2] = 151644;
1035 for (
int i = 1; i < argc; i++) {
1036 const char *arg = argv[i];
1038 if (!strcmp(arg,
"--help") || !strcmp(arg,
"-h")) {
1041 }
else if (!strcmp(arg,
"--list")) {
1044 }
else if ((!strcmp(arg,
"--model") || !strcmp(arg,
"-m")) && i + 1 < argc) {
1045 opt->model_name = argv[++i];
1046 }
else if (!strcmp(arg,
"--lib") && i + 1 < argc) {
1047 opt->lib_path = argv[++i];
1048 }
else if (!strcmp(arg,
"--weights") && i + 1 < argc) {
1049 opt->weights_path = argv[++i];
1050 }
else if ((!strcmp(arg,
"--prompt") || !strcmp(arg,
"-p")) && i + 1 < argc) {
1051 opt->prompt_once = argv[++i];
1052 }
else if ((!strcmp(arg,
"--system") || !strcmp(arg,
"-S")) && i + 1 < argc) {
1053 opt->system_prompt = argv[++i];
1054 }
else if ((!strcmp(arg,
"--max-tokens") || !strcmp(arg,
"-n")) && i + 1 < argc) {
1055 opt->max_tokens = atoi(argv[++i]);
1056 }
else if ((!strcmp(arg,
"--context") || !strcmp(arg,
"-c")) && i + 1 < argc) {
1057 opt->context_override = atoi(argv[++i]);
1058 }
else if ((!strcmp(arg,
"--temperature") || !strcmp(arg,
"-T")) && i + 1 < argc) {
1059 opt->temperature = (float)atof(argv[++i]);
1060 }
else if (!strcmp(arg,
"--top-p") && i + 1 < argc) {
1061 opt->top_p = (float)atof(argv[++i]);
1062 }
else if (!strcmp(arg,
"--stream") || !strcmp(arg,
"-s")) {
1064 }
else if (!strcmp(arg,
"--no-stream")) {
1065 opt->stream =
false;
1066 }
else if (!strcmp(arg,
"--timing") || !strcmp(arg,
"-t")) {
1068 }
else if (!strcmp(arg,
"--no-timing")) {
1069 opt->timing =
false;
1070 }
else if (!strcmp(arg,
"--no-chat-template")) {
1071 opt->no_chat_template =
true;
1072 }
else if (!strcmp(arg,
"--eos") && i + 1 < argc) {
1074 }
else if (!strcmp(arg,
"--ignore-eos")) {
1075 opt->ignore_eos =
true;
1076 }
else if (!strcmp(arg,
"--verbose") || !strcmp(arg,
"-v")) {
1077 opt->verbose =
true;
1078 }
else if (arg[0] !=
'-') {
1079 if (!opt->lib_path) opt->lib_path = arg;
1080 else if (!opt->weights_path) opt->weights_path = arg;
1082 fprintf(stderr,
"Unknown argument: %s\n", arg);
1086 fprintf(stderr,
"Unknown option: %s\n", arg);
1092 if (opt->model_name && (!opt->lib_path || !opt->weights_path)) {
1093 static char lib_buf[4096], weights_buf[4096];
1095 opt->lib_path = lib_buf;
1096 opt->weights_path = weights_buf;
1098 fprintf(stderr,
"Error: model '%s' not found in cache\n", opt->model_name);
1099 fprintf(stderr,
"Run with --list to see available models\n");
1104 if (!opt->lib_path || !opt->weights_path) {
1110 const char *name_for_template = opt->model_name ? opt->model_name : opt->lib_path;
1116 printf(
"[DEBUG] Loaded %d EOS tokens: ", opt->eos_count);
1117 for (
int i = 0; i < opt->eos_count; i++) {
1118 printf(
"%d ", opt->eos_ids[i]);
1132 if (!line || line[0] !=
'/')
return false;
1134 if (!strncmp(line,
"/exit", 5) || !strncmp(line,
"/quit", 5)) {
1138 if (!strncmp(line,
"/help", 5)) {
1139 printf(
"REPL Commands:\n");
1140 printf(
" /exit, /quit Exit\n");
1141 printf(
" /reset Reset KV cache\n");
1142 printf(
" /timing Toggle timing display\n");
1143 printf(
" /temp <value> Set temperature (0 = greedy)\n");
1144 printf(
" /top-p <value> Set top-p\n");
1145 printf(
" /system <text> Set system prompt\n");
1146 printf(
" /clear Clear system prompt\n");
1147 printf(
" /verbose Toggle verbose mode\n");
1150 if (!strncmp(line,
"/reset", 6)) {
1151 if (api->kv_reset) {
1153 printf(
"[KV cache reset]\n");
1157 if (!strncmp(line,
"/timing", 7)) {
1158 opt->timing = !opt->timing;
1159 printf(
"[Timing %s]\n", opt->timing ?
"enabled" :
"disabled");
1162 if (!strncmp(line,
"/verbose", 8)) {
1163 opt->verbose = !opt->verbose;
1164 printf(
"[Verbose %s]\n", opt->verbose ?
"enabled" :
"disabled");
1167 if (!strncmp(line,
"/temp ", 6)) {
1168 opt->temperature = (float)atof(line + 6);
1169 printf(
"[Temperature set to %.2f]\n", opt->temperature);
1172 if (!strncmp(line,
"/top-p ", 7)) {
1173 opt->top_p = (float)atof(line + 7);
1174 printf(
"[Top-p set to %.2f]\n", opt->top_p);
1177 if (!strncmp(line,
"/system ", 8)) {
1178 opt->system_prompt = strdup(line + 8);
1179 printf(
"[System prompt set]\n");
1182 if (!strncmp(line,
"/clear", 6)) {
1183 opt->system_prompt = NULL;
1184 printf(
"[System prompt cleared]\n");
1188 printf(
"Unknown command: %s\n", line);
1198 srand((
unsigned int)time(NULL));
1206 printf(
"Loading: %s\n", opt.lib_path);
1213 printf(
"Initializing model...\n");
1214 if (api.init(opt.weights_path) != 0) {
1215 fprintf(stderr,
"Error: ck_model_init failed\n");
1219 int ctx = opt.context_override;
1220 if (ctx <= 0 && api.get_context) ctx = api.get_context();
1221 if (api.kv_enable && ctx > 0) {
1227 fprintf(stderr,
"[Tokenizer] failed to create\n");
1231 int vocab_size = api.get_vocab_size ? api.get_vocab_size() : 0;
1232 int vocab_bytes = api.get_vocab_bytes ? api.get_vocab_bytes() : 0;
1233 int num_merges = api.get_num_merges ? api.get_num_merges() : 0;
1234 const int32_t *
offsets = (
const int32_t *)api.get_offsets();
1235 const char *
strings = (
const char *)api.get_strings();
1236 const int32_t *
merges = api.get_merges ? (
const int32_t *)api.get_merges() : NULL;
1239 fprintf(stderr,
"[Tokenizer] missing vocab data in model\n");
1245 fprintf(stderr,
"[Tokenizer] failed to load vocab\n");
1250 printf(
"Ready! Vocab: %d, Context: %d, Template: %s\n",
1252 opt.no_chat_template ?
"none" :
1259 printf(
"[Hardware] %s | Vector: %d-bit | FMA: %s | AI Accel: %s | Kernel: %s\n",
1263 printf(
"Type /help for commands, Ctrl+C to stop generation\n\n");
1265 setvbuf(stdout, NULL, _IOFBF, 1 << 20);
1267 if (opt.prompt_once) {
1268 run_prompt(&api, tokenizer, &opt, opt.prompt_once);
1271 #ifdef HAVE_READLINE
1272 char *home = getenv(
"HOME");
1273 char history_path[4096];
1276 read_history(history_path);
1281 #ifdef HAVE_READLINE
1282 char *line = readline(
"\033[1;32mYou:\033[0m ");
1284 if (*line) add_history(line);
1286 printf(
"\033[1;32mYou:\033[0m ");
1288 char line_buf[4096];
1289 if (!fgets(line_buf,
sizeof(line_buf), stdin)) {
1291 if (errno == EINTR)
break;
1295 size_t len = strlen(line_buf);
1296 if (len > 0 && line_buf[len-1] ==
'\n') line_buf[len-1] =
'\0';
1297 char *line = line_buf;
1300 if (line[0] ==
'\0') {
1301 #ifdef HAVE_READLINE
1307 if (line[0] ==
'/') {
1309 #ifdef HAVE_READLINE
1315 printf(
"\033[1;34mAssistant:\033[0m ");
1319 #ifdef HAVE_READLINE
1324 #ifdef HAVE_READLINE
1326 write_history(history_path);
1332 if (api.free_fn) api.free_fn();
1333 if (api.handle) dlclose(api.handle);
1335 printf(
"\nGoodbye!\n");
void *(* get_ptr_t)(void)
static bool parse_eos_ids(const char *arg, CLIOptions *opt)
int(* init_t)(const char *weights_path)
static bool resolve_symbol(void *handle, const char *name, void **out_ptr, bool required)
static double g_decode_time_ms
static void handle_sigint(int sig)
int(* embed_t)(const int32_t *tokens, int num_tokens)
static char * apply_chat_template(const ChatTemplate *tmpl, const char *system, const char *user)
int(* kv_enable_t)(int capacity)
static int sample_top_p(float *logits, int vocab_size, float temperature, float top_p)
static bool load_eos_from_vocab_json(const char *weights_path, CLIOptions *opt)
static double g_prefill_time_ms
static bool find_model_in_cache(const char *model_name, char *lib_out, char *weights_out, size_t out_size)
int main(int argc, char **argv)
static void print_help(const char *prog)
static ChatTemplateType detect_chat_template(const char *model_name)
#define CK_CLI_HISTORY_FILE
static int g_decode_count
static bool process_repl_command(const char *line, CLIOptions *opt, ModelAPI *api)
static bool is_eos_token(const CLIOptions *opt, int token)
static bool eos_is_potential_prefix(const char *token)
int(* forward_t)(float *logits_out)
float *(* get_logits_t)(void)
static void eos_pattern_init(ChatTemplateType tmpl)
static void output_append(char *buf, size_t *len, const char *text)
static void list_available_models(void)
static volatile sig_atomic_t g_generation_active
#define CK_CLI_MAX_CONTEXT
static int decode_bpe_token(const char *token, char *out, int max)
int(* decode_t)(int32_t token, float *logits_out)
#define EOS_PATTERN_BUF_SIZE
static void print_banner(void)
static bool parse_args(int argc, char **argv, CLIOptions *opt)
static int run_prompt(ModelAPI *api, CKTrueBPE *tokenizer, CLIOptions *opt, const char *input)
static void eos_pattern_reset(void)
static volatile sig_atomic_t g_exit_requested
static EOSPatternState g_eos_state
#define CK_CLI_OUTPUT_BUF_SIZE
#define CK_CLI_DEFAULT_MAX_TOKENS
static void output_flush(char *buf, size_t *len)
static const ChatTemplate g_templates[]
static const char * get_cache_dir(void)
static bool load_model_api(const char *lib_path, ModelAPI *api)
static void output_token(char *buf, size_t *len, const char *token)
static bool eos_pattern_process(const char *token_text, char *out_buf, size_t *out_len, void(*output_fn)(char *, size_t *, const char *), ChatTemplateType tmpl)
int(* sample_argmax_t)(void)
static int g_prompt_tokens
CPU feature detection and dispatch macros.
static ck_capability_t ck_get_capabilities(void)
Get current platform capabilities.
CPU capability information structure.
int32_t int32_t int32_t eos
const int32_t int int * out_len
int ck_true_bpe_encode(CKTrueBPE *bpe, const char *text, int text_len, int32_t *ids, int max_ids)
void ck_true_bpe_free(CKTrueBPE *bpe)
CKTrueBPE * ck_true_bpe_create(void)
int ck_true_bpe_load_binary(CKTrueBPE *bpe, int vocab_size, const int32_t *offsets, const char *strings, int num_merges, const int32_t *merges)
const char * ck_true_bpe_id_to_token(const CKTrueBPE *bpe, int32_t id)
int const int32_t const char int num_merges
int const int32_t const char * strings
int const int32_t const char int const int32_t * merges
int const int32_t * offsets