63 #define MERGE_HASH_SIZE 65536
64 #define INITIAL_TOKEN_CAPACITY 256
65 #define MAX_TOKEN_LEN 128
80 typedef struct CKMergeEntry {
83 struct CKMergeEntry *next;
88 CKMergeEntry **buckets;
108 #define MAX_SPECIAL_TOKENS 32
125 size_t vocab_capacity;
139 int num_special_tokens;
146 size_t str_buffer_size;
160 key *= 0xff51afd7ed558ccdULL;
162 key *= 0xc4ceb9fe1a85ec53ULL;
164 return key % num_buckets;
168 CKMergeTable *table = (CKMergeTable *)malloc(
sizeof(CKMergeTable));
169 if (!table)
return NULL;
171 table->buckets = (CKMergeEntry **)calloc(num_buckets,
sizeof(CKMergeEntry *));
172 if (!table->buckets) {
177 table->num_buckets = num_buckets;
178 table->num_entries = 0;
185 for (
size_t i = 0; i < table->num_buckets; i++) {
186 CKMergeEntry *entry = table->buckets[i];
188 CKMergeEntry *next = entry->next;
194 free(table->buckets);
199 uint64_t key =
merge_key(merge->left_id, merge->right_id);
200 size_t bucket =
merge_hash(key, table->num_buckets);
203 CKMergeEntry *entry = table->buckets[bucket];
205 if (entry->key == key) {
207 entry->merge = *merge;
214 entry = (CKMergeEntry *)malloc(
sizeof(CKMergeEntry));
215 if (!entry)
return -1;
218 entry->merge = *merge;
219 entry->next = table->buckets[bucket];
220 table->buckets[bucket] = entry;
221 table->num_entries++;
228 size_t bucket =
merge_hash(key, table->num_buckets);
230 CKMergeEntry *entry = table->buckets[bucket];
232 if (entry->key == key) {
233 return &entry->merge;
246 CKBPETokenList *list = (CKBPETokenList *)malloc(
sizeof(CKBPETokenList));
247 if (!list)
return NULL;
249 list->tokens = (CKBPEToken *)calloc(initial_capacity,
sizeof(CKBPEToken));
256 list->capacity = initial_capacity;
263 for (
size_t i = 0; i < list->count; i++) {
264 if (list->tokens[i].str) {
265 free(list->tokens[i].str);
274 for (
size_t i = 0; i < list->count; i++) {
275 if (list->tokens[i].str) {
276 free(list->tokens[i].str);
277 list->tokens[i].str = NULL;
284 if (list->count >= list->capacity) {
285 size_t new_cap = list->capacity * 2;
286 CKBPEToken *new_tokens = (CKBPEToken *)realloc(list->tokens, new_cap *
sizeof(CKBPEToken));
287 if (!new_tokens)
return -1;
288 list->tokens = new_tokens;
289 list->capacity = new_cap;
291 memset(list->tokens + list->count, 0, (new_cap - list->count) *
sizeof(CKBPEToken));
294 CKBPEToken *tok = &list->tokens[list->count];
295 tok->str = (
char *)malloc(len + 1);
296 if (!tok->str)
return -1;
298 memcpy(tok->str, str, len);
299 tok->str[len] =
'\0';
300 tok->len = (uint16_t)len;
302 tok->is_merged =
false;
310 if (pos + 1 >= list->count)
return -1;
313 free(list->tokens[pos].str);
314 free(list->tokens[pos + 1].str);
317 list->tokens[pos].str = (
char *)malloc(merged_len + 1);
318 if (!list->tokens[pos].str)
return -1;
320 memcpy(list->tokens[pos].str, merged_str, merged_len);
321 list->tokens[pos].str[merged_len] =
'\0';
322 list->tokens[pos].len = (uint16_t)merged_len;
324 list->tokens[pos].is_merged =
true;
327 for (
size_t i = pos + 1; i < list->count - 1; i++) {
328 list->tokens[i] = list->tokens[i + 1];
333 list->tokens[list->count].str = NULL;
343 CKTrueBPE *bpe = (CKTrueBPE *)calloc(1,
sizeof(CKTrueBPE));
344 if (!bpe)
return NULL;
362 bpe->vocab_capacity = 4096;
363 bpe->id_to_token = (
char **)calloc(bpe->vocab_capacity,
sizeof(
char *));
364 if (!bpe->id_to_token) {
372 bpe->str_buffer_size = 4096;
373 bpe->str_buffer = (
char *)malloc(bpe->str_buffer_size);
374 if (!bpe->str_buffer) {
375 free(bpe->id_to_token);
389 bpe->num_special_tokens = 0;
391 bpe->special_tokens[i].token = NULL;
392 bpe->special_tokens[i].id = -1;
393 bpe->special_tokens[i].len = 0;
397 bpe->config.add_bos =
false;
398 bpe->config.add_eos =
false;
399 bpe->config.byte_fallback =
true;
416 if (bpe->id_to_token) {
417 for (
size_t i = 0; i < bpe->vocab_size; i++) {
418 if (bpe->id_to_token[i]) {
419 free(bpe->id_to_token[i]);
422 free(bpe->id_to_token);
425 if (bpe->str_buffer) {
426 free(bpe->str_buffer);
430 for (
int i = 0; i < bpe->num_special_tokens; i++) {
431 if (bpe->special_tokens[i].token) {
432 free(bpe->special_tokens[i].token);
450 if (!bpe || !
token)
return -1;
453 if (
id >= (int32_t)bpe->vocab_capacity) {
454 size_t new_cap = bpe->vocab_capacity * 2;
455 while (new_cap <= (
size_t)
id) new_cap *= 2;
457 char **new_array = (
char **)realloc(bpe->id_to_token, new_cap *
sizeof(
char *));
458 if (!new_array)
return -1;
460 memset(new_array + bpe->vocab_capacity, 0, (new_cap - bpe->vocab_capacity) *
sizeof(
char *));
461 bpe->id_to_token = new_array;
462 bpe->vocab_capacity = new_cap;
469 existing->score =
score;
470 if (bpe->id_to_token[
id]) free(bpe->id_to_token[
id]);
471 bpe->id_to_token[
id] = strdup(
token);
476 BPETokenInfo *info = (BPETokenInfo *)malloc(
sizeof(BPETokenInfo));
477 if (!info)
return -1;
487 if (
id >= (int32_t)bpe->vocab_size) {
488 bpe->vocab_size =
id + 1;
491 if (bpe->id_to_token[
id]) free(bpe->id_to_token[
id]);
492 bpe->id_to_token[
id] = strdup(
token);
521 if (!left_info || !right_info) {
526 size_t left_len = strlen(
left);
527 size_t right_len = strlen(
right);
528 size_t merged_len = left_len + right_len;
530 if (merged_len >= bpe->str_buffer_size) {
534 memcpy(bpe->str_buffer,
left, left_len);
535 memcpy(bpe->str_buffer + left_len,
right, right_len);
536 bpe->str_buffer[merged_len] =
'\0';
561 if (!bpe || !
config)
return;
566 if (!bpe || !
token ||
id < 0)
return -1;
569 int token_len = (int)strlen(
token);
570 if (token_len == 0)
return -1;
573 for (
int i = 0; i < bpe->num_special_tokens; i++) {
574 if (bpe->special_tokens[i].token &&
575 strcmp(bpe->special_tokens[i].token,
token) == 0) {
577 bpe->special_tokens[i].id =
id;
583 int insert_idx = bpe->num_special_tokens;
584 for (
int i = 0; i < bpe->num_special_tokens; i++) {
585 if (token_len > bpe->special_tokens[i].len) {
592 for (
int i = bpe->num_special_tokens; i > insert_idx; i--) {
593 bpe->special_tokens[i] = bpe->special_tokens[i - 1];
597 bpe->special_tokens[insert_idx].token = strdup(
token);
598 if (!bpe->special_tokens[insert_idx].token)
return -1;
599 bpe->special_tokens[insert_idx].id =
id;
600 bpe->special_tokens[insert_idx].len = token_len;
601 bpe->num_special_tokens++;
625 int32_t merged =
merges[i * 3 + 2];
626 if (
left < 0 ||
right < 0 || merged < 0) {
639 if (!bpe || !
token)
return -1;
642 return info ? info->id : bpe->unk_id;
646 if (!bpe || id < 0 || id >= (int32_t)bpe->vocab_size)
return NULL;
647 return bpe->id_to_token[
id];
658 return bpe->config.space_prefix_style;
665 for (
size_t i = 0; i < bpe->vocab_size && i < 10000; i++) {
666 const char *
token = bpe->id_to_token[i];
667 if (!
token)
continue;
669 unsigned char c0 = (
unsigned char)
token[0];
670 unsigned char c1 = (
unsigned char)
token[1];
672 if (c0 == 0xC4 && c1 == 0xA0) {
674 }
else if (c0 == 0xE2 && c1 == 0x96 && (
unsigned char)
token[2] == 0x81) {
680 bpe->config.space_prefix_style = detected;
706 if (
byte >= 0x21 &&
byte <= 0x7E &&
byte !=
'!') {
714 unsigned int codepoint;
717 if (
byte ==
'!') codepoint = byte;
718 else if (
byte ==
'"') codepoint = byte;
719 else if (
byte >=
'#' &&
byte <=
'~') codepoint = byte;
720 else if (
byte == 0x21) codepoint =
'!';
724 codepoint = 0x100 + byte;
725 }
else if (
byte >= 0x7F &&
byte <= 0xA0) {
726 codepoint = 0x100 + byte;
733 if (codepoint < 0x80) {
734 out[0] = (char)codepoint;
736 }
else if (codepoint < 0x800) {
737 out[0] = (char)(0xC0 | (codepoint >> 6));
738 out[1] = (char)(0x80 | (codepoint & 0x3F));
741 out[0] = (char)(0xE0 | (codepoint >> 12));
742 out[1] = (char)(0x80 | ((codepoint >> 6) & 0x3F));
743 out[2] = (char)(0x80 | (codepoint & 0x3F));
755 if (
out_len + 3 > out_max)
return -1;
761 for (
int i = 0; i <
text_len; i++) {
762 unsigned char byte = (
unsigned char)
text[i];
767 if (
out_len + 3 > out_max)
return -1;
772 if (
out_len + 1 > out_max)
return -1;
779 if (
out_len + enc_len > out_max)
return -1;
780 for (
int j = 0; j < enc_len; j++) {
791 if ((c & 0x80) == 0)
return 1;
792 if ((c & 0xE0) == 0xC0)
return 2;
793 if ((c & 0xF0) == 0xE0)
return 3;
794 if ((c & 0xF8) == 0xF0)
return 4;
816 return (c >=
'a' && c <=
'z') || (c >=
'A' && c <=
'Z');
820 return c >=
'0' && c <=
'9';
824 return c ==
' ' || c ==
'\t' || c ==
'\n' || c ==
'\r';
829 return len >= 2 && (
unsigned char)s[0] == 0xC4 && (
unsigned char)s[1] == 0xA0;
835 unsigned char c = (
unsigned char)s[0];
844 unsigned char c = (
unsigned char)s[0];
867 return len >= 2 && (
unsigned char)s[0] == 0xC4 && (
unsigned char)s[1] == 0x8A;
874 unsigned char c = (
unsigned char)s[0];
886 unsigned char c = (
unsigned char)s[0];
887 return !
is_letter(c) && !
is_digit(c) && c !=
' ' && c !=
'\t' && c !=
'\n' && c !=
'\r';
891 if (len >= 2 && (
unsigned char)s[0] == 0xC4) {
892 unsigned char c1 = (
unsigned char)s[1];
894 if (c1 == 0x8A || c1 == 0x89 || c1 == 0x8D)
return false;
919 while (pos <
text_len && num_chunks < max_chunks) {
920 int chunk_start = pos;
927 bool is_word =
false;
928 int word_start = pos;
936 int after = pos + char_len;
942 prefix_len = char_len;
949 pos = word_start + prefix_len;
958 chunks[num_chunks].start =
text + chunk_start;
959 chunks[num_chunks].len = pos - chunk_start;
968 chunks[num_chunks].start =
text + chunk_start;
969 chunks[num_chunks].len = pos - chunk_start;
979 int punct_start = has_leading_space ? pos + 2 : pos;
985 if (has_leading_space) {
1005 chunks[num_chunks].start =
text + chunk_start;
1006 chunks[num_chunks].len = pos - chunk_start;
1016 int space_count = 0;
1017 int space_end = pos;
1029 chunks[num_chunks].start =
text + pos;
1030 chunks[num_chunks].len = space_end - pos;
1042 if (space_count > 1) {
1043 chunks[num_chunks].start =
text + pos;
1044 chunks[num_chunks].len = (space_count - 1) * 2;
1047 pos += (space_count - 1) * 2;
1048 if (num_chunks >= max_chunks)
break;
1061 chunks[num_chunks].start =
text + chunk_start;
1062 chunks[num_chunks].len = pos - chunk_start;
1068 if (space_count > 1) {
1069 chunks[num_chunks].start =
text + pos;
1070 chunks[num_chunks].len = (space_count - 1) * 2;
1073 pos += (space_count - 1) * 2;
1074 if (num_chunks >= max_chunks)
break;
1079 chunks[num_chunks].start =
text + chunk_start;
1080 chunks[num_chunks].len = pos - chunk_start;
1089 chunks[num_chunks].start =
text + pos;
1090 chunks[num_chunks].len = space_count * 2;
1102 chunks[num_chunks].start =
text + chunk_start;
1103 chunks[num_chunks].len = pos - chunk_start;
1111 chunks[num_chunks].start =
text + chunk_start;
1112 chunks[num_chunks].len = pos - chunk_start;
1133 memcpy(char_buf,
text + pos, char_len);
1134 char_buf[char_len] =
'\0';
1150 size_t *best_pos,
const CKBPEMerge **best_merge) {
1153 int32_t best_priority = INT32_MAX;
1155 for (
size_t i = 0; i + 1 < list->count; i++) {
1156 int32_t
left_id = list->tokens[i].id;
1157 int32_t
right_id = list->tokens[i + 1].id;
1162 if (merge && merge->priority < best_priority) {
1163 best_priority = merge->priority;
1165 *best_merge = merge;
1169 return (*best_merge != NULL) ? 0 : -1;
1176 while (list->count > 1) {
1178 const CKBPEMerge *best_merge;
1185 const char *merged_str = bpe->id_to_token[best_merge->merged_id];
1188 size_t left_len = list->tokens[best_pos].len;
1189 size_t right_len = list->tokens[best_pos + 1].len;
1191 if (left_len + right_len >=
sizeof(merged_buf)) {
1195 memcpy(merged_buf, list->tokens[best_pos].str, left_len);
1196 memcpy(merged_buf + left_len, list->tokens[best_pos + 1].str, right_len);
1197 merged_buf[left_len + right_len] =
'\0';
1198 merged_str = merged_buf;
1202 if (
token_list_merge_at(list, best_pos, merged_str, strlen(merged_str), best_merge->merged_id) != 0) {
1214 int32_t *
ids,
int max_ids, CKBPETokenList *list) {
1215 if (chunk_len <= 0)
return 0;
1218 char chunk_buf[256];
1219 if (chunk_len < (
int)
sizeof(chunk_buf)) {
1220 memcpy(chunk_buf, chunk, chunk_len);
1221 chunk_buf[chunk_len] =
'\0';
1223 if (chunk_id >= 0) {
1243 for (
size_t i = 0; i < list->count && out_idx <
max_ids; i++) {
1244 int32_t
id = list->tokens[i].id;
1248 if (bpe->config.byte_fallback) {
1250 for (
size_t j = 0; j < list->tokens[i].len && out_idx <
max_ids; j++) {
1252 snprintf(byte_token,
sizeof(byte_token),
"<0x%02X>", (
unsigned char)list->tokens[i].str[j]);
1254 ids[out_idx++] = (byte_id >= 0) ? byte_id : bpe->unk_id;
1257 ids[out_idx++] = bpe->unk_id;
1260 ids[out_idx++] =
id;
1275 char preprocessed[16384];
1280 preprocessed[pp_len] =
'\0';
1289 PretokChunk chunks[1024];
1294 if (!list)
return out_idx;
1297 for (
int c = 0; c < num_chunks && out_idx <
max_ids; c++) {
1300 out_idx += chunk_ids;
1307 if (!list)
return out_idx;
1309 int chunk_ids =
encode_chunk(bpe, preprocessed, pp_len,
1311 out_idx += chunk_ids;
1325 const char *cur =
text + pos;
1328 for (
int i = 0; i < bpe->num_special_tokens; i++) {
1329 int tok_len = bpe->special_tokens[i].len;
1330 if (tok_len <= remaining &&
1331 memcmp(cur, bpe->special_tokens[i].token, tok_len) == 0) {
1351 if (bpe->config.add_bos && bpe->bos_id >= 0 && out_idx <
max_ids) {
1352 ids[out_idx++] = bpe->bos_id;
1356 if (bpe->num_special_tokens == 0) {
1361 int segment_start = 0;
1368 if (pos > segment_start) {
1369 int seg_len = pos - segment_start;
1376 ids[out_idx++] = bpe->special_tokens[match].id;
1380 pos += bpe->special_tokens[match].len;
1381 segment_start = pos;
1396 if (bpe->config.add_eos && bpe->eos_id >= 0 && out_idx <
max_ids) {
1397 ids[out_idx++] = bpe->eos_id;
1418 if (len < 2)
return -1;
1421 if ((s[0] & 0xE0) == 0xC0 && (s[1] & 0xC0) == 0x80) {
1422 unsigned int codepoint = ((s[0] & 0x1F) << 6) | (s[1] & 0x3F);
1425 if (codepoint >= 0x100 && codepoint <= 0x1FF) {
1427 if (codepoint <= 0x120) {
1429 return codepoint - 0x100;
1430 }
else if (codepoint >= 0x17F && codepoint <= 0x1A0) {
1432 return codepoint - 0x100;
1444 int32_t
id =
ids[i];
1445 if (
id < 0)
continue;
1448 if (
id == bpe->bos_id ||
id == bpe->eos_id ||
id == bpe->pad_id) {
1453 if (!
token)
continue;
1455 int token_len = (int)strlen(
token);
1458 if (token_len >= 3 &&
1459 (
unsigned char)
token[0] == 0xE2 &&
1460 (
unsigned char)
token[1] == 0x96 &&
1461 (
unsigned char)
token[2] == 0x81) {
1470 while (pos < token_len && len <
max_len - 1) {
1471 unsigned char c0 = (
unsigned char)
token[pos];
1474 if (pos + 1 < token_len && (c0 & 0xE0) == 0xC0) {
1477 text[len++] = (char)decoded;
1485 if ((c0 & 0x80) == 0) char_len = 1;
1486 else if ((c0 & 0xE0) == 0xC0) char_len = 2;
1487 else if ((c0 & 0xF0) == 0xE0) char_len = 3;
1488 else if ((c0 & 0xF8) == 0xF0) char_len = 4;
1491 for (
int j = 0; j < char_len && pos + j < token_len && len <
max_len - 1; j++) {
1507 return bpe ? bpe->vocab_size : 0;
1511 return bpe ? bpe->num_merges : 0;
#define CK_TOKENIZER_HT_BUCKETS_LARGE
void ck_tokenizer_hash_table_free(CKTokenizerHashTable *table, bool free_values)
CKTokenizerHashTable * ck_tokenizer_hash_table_create(size_t bucket_count)
int ck_tokenizer_hash_table_insert(CKTokenizerHashTable *table, const char *key, void *value)
void * ck_tokenizer_hash_table_lookup(CKTokenizerHashTable *table, const char *key)
const int32_t int num_ids
int32_t int32_t int32_t eos
int32_t int32_t int32_t int32_t pad
const int32_t int int * out_len
static bool is_bpe_punct(const char *s, int len)
int ck_true_bpe_decode(const CKTrueBPE *bpe, const int32_t *ids, int num_ids, char *text, int max_len)
static bool is_gpt2_space(const char *s, int len)
static const CKBPEMerge * merge_table_lookup(const CKMergeTable *table, int32_t left_id, int32_t right_id)
int ck_true_bpe_encode(CKTrueBPE *bpe, const char *text, int text_len, int32_t *ids, int max_ids)
static bool is_letter(unsigned char c)
static CKBPETokenList * token_list_create(size_t initial_capacity)
static int encode_text_segment(CKTrueBPE *bpe, const char *text, int text_len, int32_t *ids, int max_ids)
void ck_true_bpe_set_config(CKTrueBPE *bpe, const CKBPEConfig *config)
static void merge_table_free(CKMergeTable *table)
CKSpacePrefixStyle ck_true_bpe_detect_space_style(CKTrueBPE *bpe)
#define MAX_SPECIAL_TOKENS
void ck_true_bpe_free(CKTrueBPE *bpe)
CKTrueBPE * ck_true_bpe_create(void)
void ck_true_bpe_set_special_ids(CKTrueBPE *bpe, int32_t unk, int32_t bos, int32_t eos, int32_t pad)
static bool is_bpe_letter(const char *s, int len)
static int init_tokens_from_text(CKTrueBPE *bpe, CKBPETokenList *list, const char *text, int text_len)
static bool is_bpe_newline(const char *s, int len)
static bool is_bpe_digit(const char *s, int len)
static int gpt2_decode_byte(const unsigned char *s, int len)
int ck_true_bpe_add_merge(CKTrueBPE *bpe, int32_t left_id, int32_t right_id, int32_t merged_id, int32_t priority)
static void token_list_clear(CKBPETokenList *list)
static int encode_chunk(CKTrueBPE *bpe, const char *chunk, int chunk_len, int32_t *ids, int max_ids, CKBPETokenList *list)
static int gpt2_pretokenize(const char *text, int text_len, PretokChunk *chunks, int max_chunks)
static int utf8_char_len(unsigned char c)
static void token_list_free(CKBPETokenList *list)
static int preprocess_text(const CKTrueBPE *bpe, const char *text, int text_len, char *out, int out_max)
static int token_list_append(CKBPETokenList *list, const char *str, size_t len, int32_t id)
static bool is_whitespace(unsigned char c)
int32_t ck_true_bpe_num_merges(const CKTrueBPE *bpe)
int ck_true_bpe_add_special_token(CKTrueBPE *bpe, const char *token, int32_t id)
static int match_special_token(const CKTrueBPE *bpe, const char *text, int text_len, int pos)
int ck_true_bpe_load_binary(CKTrueBPE *bpe, int vocab_size, const int32_t *offsets, const char *strings, int num_merges, const int32_t *merges)
static size_t merge_hash(uint64_t key, size_t num_buckets)
static int apply_bpe_merges(CKTrueBPE *bpe, CKBPETokenList *list)
static int find_best_merge(const CKTrueBPE *bpe, const CKBPETokenList *list, size_t *best_pos, const CKBPEMerge **best_merge)
static bool is_word_prefix_char(const char *s, int len)
static int merge_table_insert(CKMergeTable *table, const CKBPEMerge *merge)
size_t ck_true_bpe_vocab_size(const CKTrueBPE *bpe)
int ck_true_bpe_add_token(CKTrueBPE *bpe, const char *token, int32_t id, float score)
static CKMergeTable * merge_table_create(size_t num_buckets)
static uint64_t merge_key(int32_t left_id, int32_t right_id)
static bool is_digit(unsigned char c)
static int byte_to_gpt2(unsigned char byte, char *out)
static int token_list_merge_at(CKBPETokenList *list, size_t pos, const char *merged_str, size_t merged_len, int32_t merged_id)
int32_t ck_true_bpe_lookup(const CKTrueBPE *bpe, const char *token)
int ck_true_bpe_add_merge_by_tokens(CKTrueBPE *bpe, const char *left, const char *right, int32_t priority)
const char * ck_true_bpe_id_to_token(const CKTrueBPE *bpe, int32_t id)
#define INITIAL_TOKEN_CAPACITY
const CKBPEConfig * config
int const int32_t const char int num_merges
int const int32_t const char * strings
int const int32_t const char int const int32_t * merges
int32_t int32_t int32_t int32_t priority
const int32_t int char int max_len
int const int32_t * offsets
int32_t int32_t int32_t merged_id
const char int int32_t int max_ids
const char const char * right