40 memset(tok, 0,
sizeof(*tok));
107 for (
size_t i = 0; i < tok->
vocab_size; i++) {
134 for (
size_t i = 0; i < tok->
vocab_size; i++) {
158 if (!tok || !
token) {
165 while (new_cap <= (
size_t)
id) {
168 char **new_array = (
char **)realloc(tok->
id_to_token, new_cap *
sizeof(
char *));
181 existing->score =
score;
189 TokenInfo *info = (TokenInfo *)malloc(
sizeof(TokenInfo));
190 if (!info)
return -1;
193 info->is_special =
false;
214 if (!tok || !name)
return -1;
218 if (info) info->is_special =
true;
225 if (strcmp(name,
"<unk>") == 0 || strcmp(name,
"[UNK]") == 0) tok->
unk_id =
id;
226 else if (strcmp(name,
"<s>") == 0 || strcmp(name,
"<bos>") == 0 || strcmp(name,
"[BOS]") == 0) tok->
bos_id =
id;
227 else if (strcmp(name,
"</s>") == 0 || strcmp(name,
"<eos>") == 0 || strcmp(name,
"[EOS]") == 0) tok->
eos_id =
id;
228 else if (strcmp(name,
"<pad>") == 0 || strcmp(name,
"[PAD]") == 0) tok->
pad_id =
id;
291 for (
size_t i = 0; i < tok->
vocab_size && i < 10000; i++) {
293 if (!
token)
continue;
295 unsigned char c0 = (
unsigned char)
token[0];
296 unsigned char c1 = (
unsigned char)
token[1];
299 if (c0 == 0xC4 && c1 == 0xA0) {
303 else if (c0 == 0xE2 && c1 == 0x96 && (
unsigned char)
token[2] == 0x81) {
310 if (spm_count > gpt2_count * 2) {
324 if (!tok || !
token)
return -1;
326 return info ? info->id : tok->
unk_id;
331 if (!tok || !
token)
return -1;
333 return info ? info->id : -1;
340 char *tmp = stack_buf;
341 if (
text_len >= (
int)
sizeof(stack_buf)) {
342 tmp = (
char *)malloc((
size_t)
text_len + 1);
348 if (tmp != stack_buf) free(tmp);
354 if (!tok || id < 0 || id >= (int32_t)tok->
vocab_size)
return NULL;
362 return tok ? tok->
unk_id : -1;
366 return token_id >= 0 ? token_id : tok->
unk_id;
373 return tok ? tok->
unk_id : -1;
379 int32_t best_id = tok->
unk_id;
382 for (
size_t len =
max_len; len >= 1; len--) {
384 memcpy(tmp,
text + pos, len);
395 *match_len = best_len;
417 if (
out_len + 3 > out_max)
return -1;
423 for (
int i = 0; i <
text_len; i++) {
424 if (
text[i] ==
' ') {
427 if (
out_len + 3 > out_max)
return -1;
433 if (
out_len + 2 > out_max)
return -1;
438 if (
out_len + 1 > out_max)
return -1;
460 #define GGUF_TOKEN_NORMAL 1
461 #define GGUF_TOKEN_UNKNOWN 2
462 #define GGUF_TOKEN_CONTROL 3
463 #define GGUF_TOKEN_USER_DEFINED 4
464 #define GGUF_TOKEN_UNUSED 5
465 #define GGUF_TOKEN_BYTE 6
470 if (!tok->
types || token_id < 0 || token_id >= (int32_t)tok->
vocab_size) {
473 uint8_t t = tok->
types[token_id];
480 if (!tok->
types || token_id < 0 || token_id >= (int32_t)tok->
vocab_size) {
494 int len = snprintf(byte_token,
sizeof(byte_token),
"<0x%02X>", byte_val);
495 if (len <= 0)
return tok->
unk_id;
510 tok->
byte_token_id = (int32_t *)malloc(256 *
sizeof(int32_t));
515 for (
int i = 0; i < 256; i++) {
524 size_t len = strlen(
token);
528 unsigned char byte_val = (
unsigned char)
token[0];
532 unsigned int byte_val;
533 if (sscanf(
token,
"<0x%02X>", &byte_val) == 1 && byte_val < 256) {
542 if ((c & 0x80) == 0x00)
return 1;
543 if ((c & 0xE0) == 0xC0)
return 2;
544 if ((c & 0xF0) == 0xE0)
return 3;
545 if ((c & 0xF8) == 0xF0)
return 4;
559 if (
out_len + 3 > out_max)
return -1;
566 if (
text[i] ==
' ') {
576 if (
out_len + 3 > out_max)
return -1;
581 if (
out_len + run > out_max)
return -1;
582 for (
int k = 0; k < run; k++) {
588 if (
out_len + 1 > out_max)
return -1;
612 const SpmLlamaNode *nodes,
617 if (!tok || !nodes || node_id < 0 || !ids || out_idx >=
max_ids) {
621 const SpmLlamaNode *node = &nodes[node_id];
624 ids[out_idx++] = token_id;
628 if (node->left >= 0 && node->right >= 0) {
634 for (
int i = 0; i < node->n && out_idx <
max_ids; i++) {
636 ids[out_idx++] = (byte_token >= 0) ? byte_token : tok->
unk_id;
651 char preprocessed[8192];
654 if (pp_len < 0)
return 0;
655 preprocessed[pp_len] =
'\0';
658 for (
int offs = 0; offs < pp_len;) {
659 int char_len =
utf8_len((
unsigned char)preprocessed[offs]);
660 if (char_len <= 0) char_len = 1;
661 if (offs + char_len > pp_len) char_len = pp_len - offs;
665 if (num_symbols <= 0)
return 0;
667 SpmLlamaSymbol *symbols = (SpmLlamaSymbol *)calloc((
size_t)num_symbols,
sizeof(SpmLlamaSymbol));
668 int node_cap = 2 * num_symbols + 1;
669 SpmLlamaNode *nodes = (SpmLlamaNode *)calloc((
size_t)node_cap,
sizeof(SpmLlamaNode));
670 if (!symbols || !nodes) {
671 if (symbols) free(symbols);
672 if (nodes) free(nodes);
677 for (
int offs = 0; offs < pp_len && index < num_symbols;) {
678 int char_len =
utf8_len((
unsigned char)preprocessed[offs]);
679 if (char_len <= 0) char_len = 1;
680 if (offs + char_len > pp_len) char_len = pp_len - offs;
682 symbols[index].text = preprocessed + offs;
683 symbols[index].n = char_len;
684 symbols[index].prev = index - 1;
685 symbols[index].next = (index + 1 < num_symbols) ? (index + 1) : -1;
686 symbols[index].node_id = index;
688 nodes[index].text = preprocessed + offs;
689 nodes[index].n = char_len;
690 nodes[index].left = -1;
691 nodes[index].right = -1;
697 int node_count = num_symbols;
701 float best_score = -1e30f;
705 if (
right < 0)
continue;
707 int pair_len = symbols[
left].n + symbols[
right].n;
709 if (token_id < 0 || token_id >= (int32_t)tok->
vocab_size)
continue;
716 if (best_left < 0 || score > best_score || (
score == best_score &&
left < best_left)) {
723 if (best_left < 0 || best_right < 0)
break;
724 if (node_count >= node_cap)
break;
726 SpmLlamaSymbol *
left = &symbols[best_left];
727 SpmLlamaSymbol *
right = &symbols[best_right];
729 int new_node_id = node_count++;
730 nodes[new_node_id].text =
left->text;
731 nodes[new_node_id].n =
left->n +
right->n;
732 nodes[new_node_id].left =
left->node_id;
733 nodes[new_node_id].right =
right->node_id;
736 left->node_id = new_node_id;
738 if (
right->next >= 0) {
739 symbols[
right->next].prev = best_left;
748 for (
int i = 0; i != -1 && out_idx <
max_ids; i = symbols[i].next) {
768 while (lead_spaces <
text_len &&
text[lead_spaces] ==
' ') {
773 int trail_spaces = 0;
774 while (trail_spaces <
text_len - lead_spaces &&
780 int content_len =
text_len - lead_spaces - trail_spaces;
781 int starts_with_prefix = (
text_len >= 3 &&
782 (
unsigned char)
text[0] == 0xE2 &&
783 (
unsigned char)
text[1] == 0x96 &&
784 (
unsigned char)
text[2] == 0x81);
785 int inserted_prefix = 0;
787 if (
out_len + 3 > out_max)
return -1;
796 int last_was_space = (starts_with_prefix || inserted_prefix) ? 1 : 0;
797 while (i <
text_len - trail_spaces) {
798 if (
text[i] ==
' ') {
799 if (!last_was_space) {
801 if (
out_len + 3 > out_max)
return -1;
809 if (
out_len + 1 > out_max)
return -1;
821 size_t pos, int32_t *candidates,
int max_candidates);
840 const int dbg = getenv(
"CK_DEBUG_SPM_ENCODE") ? 1 : 0;
842 fprintf(stderr,
"[SPM] encode start: text_len=%d max_ids=%d\n",
text_len,
max_ids);
846 char preprocessed[8192];
849 if (pp_len < 0)
return 0;
850 preprocessed[pp_len] =
'\0';
852 fprintf(stderr,
"[SPM] preprocessed len=%d: \"%.*s\"\n", pp_len, pp_len, preprocessed);
856 size_t n = (size_t)pp_len + 1;
857 float *best_score = (
float *)malloc(n *
sizeof(
float));
858 int32_t *best_prev = (int32_t *)malloc(n *
sizeof(int32_t));
859 int32_t *best_token = (int32_t *)malloc(n *
sizeof(int32_t));
861 fprintf(stderr,
"[SPM] DP alloc n=%zu\n", n);
864 if (!best_score || !best_prev || !best_token) {
865 if (best_score) free(best_score);
866 if (best_prev) free(best_prev);
867 if (best_token) free(best_token);
872 const float neg_inf = -1e30f;
873 const float unknown_penalty = -10.0f;
874 for (
size_t i = 0; i < n; i++) {
875 best_score[i] = neg_inf;
879 best_score[0] = 0.0f;
882 for (
size_t pos = 0; pos < n; pos++) {
883 if (best_score[pos] == neg_inf)
continue;
886 int32_t candidates[64];
888 if (dbg && pos < 8) {
889 fprintf(stderr,
"[SPM] pos=%zu cand=%d\n", pos, num_cand);
892 for (
int c = 0; c < num_cand; c++) {
893 int32_t token_id = candidates[c];
902 if (!
token)
continue;
905 int token_len = (int)strlen(
token);
908 if (token_id == tok->
unk_id) {
910 if (token_len == 0) token_len = 1;
913 size_t next_pos = pos + token_len;
915 if (next_pos >= n)
continue;
918 float token_score = 0.0f;
920 token_score = tok->
scores[token_id];
924 if (tok->
types && token_id >= 0 && token_id < (int32_t)tok->
types_size) {
930 if (token_id == tok->
unk_id) {
931 token_score += unknown_penalty;
935 float new_score = best_score[pos] + token_score;
937 if (new_score > best_score[next_pos]) {
938 best_score[next_pos] = new_score;
939 best_prev[next_pos] = (int32_t)pos;
940 best_token[next_pos] = token_id;
946 int32_t *reverse_ids = (int32_t *)malloc(
max_ids *
sizeof(int32_t));
955 int32_t curr = (int32_t)(n - 1);
958 while (curr > 0 && best_token[curr] < 0) {
959 curr = best_prev[curr];
965 while (curr > 0 && num_tokens <
max_ids) {
966 int32_t token_id = best_token[curr];
969 int token_start = best_prev[curr];
972 if (token_start != last_start) {
973 reverse_ids[num_tokens++] = token_id;
974 last_start = token_start;
977 curr = best_prev[curr];
980 fprintf(stderr,
"[SPM] backtrack tokens=%d curr=%d\n", num_tokens, curr);
991 for (
int i = 0; i < num_tokens / 2; i++) {
992 int32_t tmp = reverse_ids[i];
993 reverse_ids[i] = reverse_ids[num_tokens - 1 - i];
994 reverse_ids[num_tokens - 1 - i] = tmp;
999 for (
int i = 0; i < num_tokens && out_idx <
max_ids; i++) {
1000 int32_t token_id = reverse_ids[i];
1003 if (token_id == tok->
unk_id && out_idx > 0 &&
ids[out_idx - 1] == tok->
unk_id) {
1006 ids[out_idx++] = token_id;
1009 fprintf(stderr,
"[SPM] encode done: out=%d\n", out_idx);
1015 if (num_tokens == 0) {
1033 unsigned char byte_val = (
unsigned char)
text[i];
1037 if (byte_token >= 0 && byte_token != tok->
unk_id) {
1038 ids[count++] = byte_token;
1048 size_t pos, int32_t *candidates,
int max_candidates) {
1049 if (!tok || !
text || pos >= (
size_t)
text_len)
return 0;
1057 for (
int len =
max_len; len >= 1 && num_found < max_candidates; len--) {
1058 memcpy(tmp,
text + pos, len);
1062 if (info && info->id >= 0 && info->id != tok->
unk_id) {
1070 for (
int j = 0; j < num_found; j++) {
1071 if (candidates[j] == info->id) {
1077 candidates[num_found++] = info->id;
1086 if (num_found == 0 && tok->
unk_id >= 0 && max_candidates > 0) {
1089 candidates[num_found++] = tok->
unk_id;
1100 while (pos + run < (
size_t)
text_len) {
1102 if (pos + run + 3 <= (
size_t)
text_len &&
1103 (
unsigned char)
text[pos + run] == 0xE2 &&
1104 (
unsigned char)
text[pos + run + 1] == 0x96 &&
1105 (
unsigned char)
text[pos + run + 2] == 0x81) {
1115 for (
int len =
max_len; len >= 1; len--) {
1117 memcpy(tmp,
text + pos + run, len);
1154 if (n <= 0)
return n;
1164 char preprocessed[8192];
1165 const char *input =
text;
1177 preprocessed[pp_len] =
'\0';
1178 input = preprocessed;
1189 while (pos < (
size_t)input_len && out_idx <
max_ids) {
1190 size_t match_len = 0;
1193 if (match_len == 0) {
1198 ids[out_idx++] =
id;
1214 for (
int i = 0; i <
num_ids; i++) {
1215 int32_t
id =
ids[i];
1216 if (
id < 0)
continue;
1218 if (!
token)
continue;
1219 int token_len = (int)strlen(
token);
1222 unsigned char c0 = (
unsigned char)
token[0];
1223 unsigned char c1 = (
unsigned char)
token[1];
1225 if (c0 == 0xC4 && c1 == 0xA0) {
1228 token += 2; token_len -= 2;
1229 }
else if (c0 == 0xE2 && c1 == 0x96 && (
unsigned char)
token[2] == 0x81) {
1232 token += 3; token_len -= 3;
1235 for (
int j = 0; j < token_len && len <
max_len - 1; j++)
text[len++] =
token[j];
1256 const float *scores,
1257 const uint8_t *types,
1278 if (!tok->
scores)
return -1;
1297 float score = scores ? scores[i] : 0.0f;
1306 int count_normal = 0, count_unknown = 0, count_control = 0, count_byte = 0, count_other = 0;
1309 uint8_t t = tok->
types[i];
1310 if (t > max_type) max_type = t;
1316 default: count_other++;
break;
1319 fprintf(stderr,
"[TOKENIZER] Loaded %d tokens: normal=%d, unknown=%d, control=%d, byte=%d, other=%d\n",
1320 vocab_size, count_normal, count_unknown, count_control, count_byte, count_other);
1322 fprintf(stderr,
"[TOKENIZER] Warning: Unexpected token type %d\n", max_type);
#define CK_TOKENIZER_HT_BUCKETS_LARGE
void ck_tokenizer_hash_table_free(CKTokenizerHashTable *table, bool free_values)
CKTokenizerHashTable * ck_tokenizer_hash_table_create(size_t bucket_count)
int ck_tokenizer_hash_table_insert(CKTokenizerHashTable *table, const char *key, void *value)
void * ck_tokenizer_hash_table_lookup(CKTokenizerHashTable *table, const char *key)
void ck_tokenizer_hash_table_clear(CKTokenizerHashTable *table, bool free_values)
void ck_trie_clear(CKTrie *trie)
int32_t ck_trie_find_longest(const CKTrie *trie, const char *text, size_t text_len, size_t start_pos, size_t *match_len)
int ck_trie_insert(CKTrie *trie, const char *token, int32_t token_id, bool is_special, int32_t priority)
void ck_trie_free(CKTrie *trie)
CKTrie * ck_trie_create(size_t max_nodes)
int ck_tokenizer_mempool_init(CKTokenizerMemPool *pool, size_t size)
void ck_tokenizer_mempool_free(CKTokenizerMemPool *pool)
bool space_prefix_detected
CKSpacePrefixStyle space_prefix_style
CKTokenizerHashTable * vocab
static int spm_find_candidates_at_pos(const CKTokenizer *tok, const char *text, int text_len, size_t pos, int32_t *candidates, int max_candidates)
static int32_t ck_tokenizer_lookup_exact(const CKTokenizer *tok, const char *token)
void ck_tokenizer_set_add_bos_eos(CKTokenizer *tok, bool add_bos, bool add_eos)
static int preprocess_spm_llama_text(const char *text, int text_len, char *out, int out_max, bool add_space_prefix)
static bool spm_token_is_byte_format(const char *token)
int32_t ck_tokenizer_lookup(const CKTokenizer *tok, const char *token)
int ck_tokenizer_load_binary_with_scores(CKTokenizer *tok, int vocab_size, const int32_t *offsets, const char *strings, const float *scores, const uint8_t *types, int num_merges, const int32_t *merges)
int ck_tokenizer_decode(const CKTokenizer *tok, const int32_t *ids, int num_ids, char *text, int max_len)
int ck_tokenizer_add_token(CKTokenizer *tok, const char *token, int32_t id, float score)
static int ck_tokenizer_encode_spm_llama_impl(const CKTokenizer *tok, const char *text, int text_len, int32_t *ids, int max_ids)
CKSpacePrefixStyle ck_tokenizer_detect_space_prefix_style(CKTokenizer *tok)
static bool spm_token_allowed_in_dp(const CKTokenizer *tok, int32_t token_id)
#define GGUF_TOKEN_CONTROL
int ck_tokenizer_load_text(CKTokenizer *tok, const char *path)
int ck_tokenizer_load_gguf(CKTokenizer *tok, const char *path)
int ck_tokenizer_load_json(CKTokenizer *tok, const char *path)
void ck_tokenizer_set_spm_mode(CKTokenizer *tok, CKSpmMode spm_mode)
static void spm_build_byte_lookup(CKTokenizer *tok, const char *strings, const int32_t *offsets, int vocab_size)
CKTokenizer * ck_tokenizer_create(CKTokenizerType type)
static int32_t find_longest_match(const CKTokenizer *tok, const char *text, size_t text_len, size_t pos, size_t *match_len)
int ck_tokenizer_load_binary(CKTokenizer *tok, int vocab_size, const int32_t *offsets, const char *strings, int num_merges, const int32_t *merges)
static int preprocess_spm_text(const char *text, int text_len, char *out, int out_max, bool add_space_prefix)
static int spm_encode_byte_fallback(const CKTokenizer *tok, const char *text, int text_len, int32_t *ids, int max_ids)
static int spm_llama_resegment_node(const CKTokenizer *tok, const SpmLlamaNode *nodes, int node_id, int32_t *ids, int max_ids, int out_idx)
static int32_t ck_tokenizer_lookup_exact_n(const CKTokenizer *tok, const char *text, int text_len)
#define GGUF_TOKEN_USER_DEFINED
const char * ck_tokenizer_id_to_token(const CKTokenizer *tok, int32_t id)
int ck_tokenizer_add_merge(CKTokenizer *tok, int32_t left, int32_t right, int32_t merged, int32_t priority)
static int utf8_len(unsigned char c)
int ck_tokenizer_encode(const CKTokenizer *tok, const char *text, int text_len, int32_t *ids, int max_ids)
void ck_tokenizer_set_special_ids(CKTokenizer *tok, int32_t unk, int32_t bos, int32_t eos, int32_t pad, int32_t mask)
#define GGUF_TOKEN_UNUSED
static int32_t find_longest_match_trie(const CKTokenizer *tok, const char *text, size_t text_len, size_t pos, size_t *match_len)
void ck_tokenizer_reset(CKTokenizer *tok)
static bool spm_is_byte_token(const CKTokenizer *tok, int32_t token_id)
static int32_t spm_get_byte_token(const CKTokenizer *tok, unsigned char byte_val)
int ck_tokenizer_add_special_token(CKTokenizer *tok, const char *name, int32_t id)
void ck_tokenizer_free(CKTokenizer *tok)
int ck_tokenizer_load_merges(CKTokenizer *tok, const char *path)
static int spm_count_unknown_run(const CKTokenizer *tok, const char *text, int text_len, size_t pos)
static int preprocess_bpe_spaces(const char *text, int text_len, char *out, int out_max, CKSpacePrefixStyle style)
#define GGUF_TOKEN_UNKNOWN
void ck_tokenizer_set_use_trie(CKTokenizer *tok, bool use_trie)
void ck_tokenizer_set_add_space_prefix(CKTokenizer *tok, bool add_space_prefix)
static int32_t find_longest_match_hash(const CKTokenizer *tok, const char *text, size_t text_len, size_t pos, size_t *match_len)
void ck_tokenizer_set_space_prefix_style(CKTokenizer *tok, CKSpacePrefixStyle style)
static int ck_tokenizer_encode_spm_impl(const CKTokenizer *tok, const char *text, int text_len, int32_t *ids, int max_ids)
#define GGUF_TOKEN_NORMAL
int32_t int32_t int32_t int32_t int32_t mask
const int32_t int num_ids
int32_t int32_t int32_t eos
int32_t int32_t int32_t int32_t pad
const int32_t int int * out_len
const CKBPEConfig * config
int const int32_t const char int num_merges
int const int32_t const char * strings
int const int32_t const char int const int32_t * merges
int32_t int32_t int32_t int32_t priority
const int32_t int char int max_len
int const int32_t * offsets
const char int int32_t int max_ids
const char const char * right