52 memset(pool, 0,
sizeof(*pool));
57 if (!block)
return NULL;
58 block->
data = (uint8_t *)malloc(capacity);
71 size = (size + 7) & ~7;
83 if (size > block_size) block_size = size;
86 if (!block)
return NULL;
92 void *ptr = block->
data;
99 if (len < 0) len = (int)strlen(s);
101 if (!copy)
return NULL;
102 memcpy(copy, s, len);
115 memset(pool, 0,
sizeof(*pool));
124 uint32_t hash = 2166136261u;
125 for (
int i = 0; i < len; i++) {
126 hash ^= (uint8_t)s[i];
134 uint64_t combined = ((uint64_t)
left << 32) | (uint32_t)
right;
136 combined ^= combined >> 33;
137 combined *= 0xff51afd7ed558ccdULL;
138 combined ^= combined >> 33;
139 combined *= 0xc4ceb9fe1a85ec53ULL;
140 combined ^= combined >> 33;
141 return (uint32_t)combined;
149 memset(tok, 0,
sizeof(*tok));
189 memset(tok, 0,
sizeof(*tok));
197 if (len < 0) len = (int)strlen(
token);
202 if (existing != tok->
unk_id || (len == 0)) {
208 if (!entry)
return -1;
211 if (!entry->
token)
return -1;
228 if (len < 0) len = (int)strlen(
token);
232 if (e->token_len == len && memcmp(e->token,
token, len) == 0) {
240 if (id < 0 || id >= tok->
vocab_size)
return NULL;
252 if (idx % 4096 == 0) {
253 size_t new_cap = (idx + 4096) *
sizeof(
CKMergeRule);
255 if (!new_merges)
return -1;
281 while (tok->
merge_hash[bucket] >= 0 && probes < tok->merge_hash_size) {
297 while (p->pos < p->end && isspace((
unsigned char)*p->pos)) {
304 if (p->pos < p->end && *p->pos == c) {
313 if (p->pos >= p->end || *p->pos !=
'"')
return -1;
317 while (p->pos < p->end && *p->pos !=
'"') {
319 if (c ==
'\\' && p->pos < p->end) {
322 case 'n': c =
'\n';
break;
323 case 'r': c =
'\r';
break;
324 case 't': c =
'\t';
break;
325 case '\\': c =
'\\';
break;
326 case '"': c =
'"';
break;
329 if (p->pos + 4 <= p->end) {
330 char hex[5] = {p->pos[0], p->pos[1], p->pos[2], p->pos[3], 0};
331 unsigned int codepoint = (
unsigned int)strtol(hex, NULL, 16);
334 if (codepoint < 0x80) {
335 if (len <
max_len - 1) buf[len++] = (char)codepoint;
336 }
else if (codepoint < 0x800) {
338 buf[len++] = (char)(0xC0 | (codepoint >> 6));
339 buf[len++] = (char)(0x80 | (codepoint & 0x3F));
343 buf[len++] = (char)(0xE0 | (codepoint >> 12));
344 buf[len++] = (char)(0x80 | ((codepoint >> 6) & 0x3F));
345 buf[len++] = (char)(0x80 | (codepoint & 0x3F));
355 if (len <
max_len - 1) buf[len++] = c;
359 if (p->pos < p->end && *p->pos ==
'"') p->pos++;
365 if (p->pos >= p->end)
return -1;
368 if (*p->pos ==
'-') {
373 if (p->pos >= p->end || !isdigit((
unsigned char)*p->pos))
return -1;
376 while (p->pos < p->end && isdigit((
unsigned char)*p->pos)) {
377 val = val * 10 + (*p->pos -
'0');
381 *out = neg ? -val : val;
387 if (p->pos >= p->end)
return;
393 }
else if (c ==
'{') {
396 while (p->pos < p->end && depth > 0) {
397 if (*p->pos ==
'{') depth++;
398 else if (*p->pos ==
'}') depth--;
399 else if (*p->pos ==
'"') {
406 }
else if (c ==
'[') {
409 while (p->pos < p->end && depth > 0) {
410 if (*p->pos ==
'[') depth++;
411 else if (*p->pos ==
']') depth--;
412 else if (*p->pos ==
'"') {
421 while (p->pos < p->end && !isspace((
unsigned char)*p->pos) &&
422 *p->pos !=
',' && *p->pos !=
'}' && *p->pos !=
']') {
433 FILE *f = fopen(path,
"rb");
435 fprintf(stderr,
"Failed to open tokenizer: %s\n", path);
439 fseek(f, 0, SEEK_END);
440 long size = ftell(f);
441 fseek(f, 0, SEEK_SET);
443 char *data = (
char *)malloc(size + 1);
448 fread(data, 1, size, f);
452 JSONParser parser = {data, data, data + size};
453 JSONParser *p = &parser;
462 while (p->pos < p->end && *p->pos !=
'}') {
466 if (strcmp(key,
"model") == 0) {
474 while (p->pos < p->end && *p->pos !=
'}') {
478 if (strcmp(key,
"vocab") == 0) {
487 while (p->pos < p->end && *p->pos !=
'}') {
489 if (token_len < 0)
break;
515 }
else if (strcmp(key,
"merges") == 0) {
524 while (p->pos < p->end && *p->pos !=
']') {
526 if (merge_len < 0)
break;
529 char *space = strchr(merge_str,
' ');
532 char *tok1 = merge_str;
533 char *tok2 = space + 1;
540 snprintf(merged,
sizeof(merged),
"%s%s", tok1, tok2);
562 }
else if (strcmp(key,
"added_tokens") == 0) {
570 while (p->pos < p->end && *p->pos !=
']') {
577 char content[256] =
"";
579 bool special =
false;
581 while (p->pos < p->end && *p->pos !=
'}') {
585 if (strcmp(key,
"content") == 0) {
587 }
else if (strcmp(key,
"id") == 0) {
589 }
else if (strcmp(key,
"special") == 0) {
591 special = (p->pos < p->end && *p->pos ==
't');
600 if (
id >= 0 && content[0]) {
602 if (strcmp(content,
"<unk>") == 0 || strcmp(content,
"[UNK]") == 0) {
604 }
else if (strcmp(content,
"<s>") == 0 || strcmp(content,
"<bos>") == 0 ||
605 strcmp(content,
"[BOS]") == 0) {
607 }
else if (strcmp(content,
"</s>") == 0 || strcmp(content,
"<eos>") == 0 ||
608 strcmp(content,
"[EOS]") == 0 || strcmp(content,
"<|endoftext|>") == 0) {
610 }
else if (strcmp(content,
"<pad>") == 0 || strcmp(content,
"[PAD]") == 0) {
649 int32_t *tokens = (int32_t *)malloc(
text_len *
sizeof(int32_t));
652 for (
int i = 0; i <
text_len; i++) {
654 char c[2] = {
text[i],
'\0'};
660 snprintf(byte_token,
sizeof(byte_token),
"<0x%02X>", (
unsigned char)
text[i]);
665 if (
id == tok->
unk_id && (
unsigned char)
text[i] >= 0x80) {
674 tokens[num_tokens++] =
id;
681 tokens[num_tokens++] =
id;
686 while (changed && num_tokens > 1) {
693 for (
int i = 0; i < num_tokens - 1; i++) {
695 if (merge_idx >= 0 && tok->
merges[merge_idx].
priority < best_priority) {
706 for (
int i = best_pos + 1; i < num_tokens - 1; i++) {
707 tokens[i] = tokens[i + 1];
744 for (
int i = 0; i <
num_ids; i++) {
751 if (!
token)
continue;
753 int token_len = (int)strlen(
token);
756 if (token_len == 6 &&
token[0] ==
'<' &&
token[1] ==
'0' &&
token[2] ==
'x') {
758 unsigned int byte = (
unsigned int)strtol(hex, NULL, 16);
760 text[len++] = (char)
byte;
766 const char *src =
token;
767 if ((
unsigned char)
token[0] == 0xC4 && (
unsigned char)
token[1] == 0xA0) {
776 for (
int j = 0; j < token_len && len <
max_len - 1; j++) {
777 text[len++] = src[j];
int ck_tokenizer_add_merge(CKTokenizer *tok, int32_t left, int32_t right, int32_t merged)
void ck_pool_init(CKMemPool *pool)
int32_t ck_tokenizer_lookup(const CKTokenizer *tok, const char *token, int len)
int ck_tokenizer_decode(const CKTokenizer *tok, const int32_t *ids, int num_ids, char *text, int max_len)
int ck_tokenizer_init(CKTokenizer *tok)
static uint32_t hash_string(const char *s, int len)
static void json_skip_whitespace(JSONParser *p)
const char * ck_tokenizer_id_to_token(const CKTokenizer *tok, int32_t id)
static CKPoolBlock * pool_new_block(size_t capacity)
int ck_tokenizer_load(CKTokenizer *tok, const char *path)
void * ck_pool_alloc(CKMemPool *pool, size_t size)
static int json_match_char(JSONParser *p, char c)
char * ck_pool_strdup(CKMemPool *pool, const char *s, int len)
int ck_tokenizer_encode(const CKTokenizer *tok, const char *text, int text_len, int32_t *ids, int max_ids)
int32_t ck_tokenizer_add_token(CKTokenizer *tok, const char *token, int len)
void ck_pool_free(CKMemPool *pool)
static int json_parse_string(JSONParser *p, char *buf, int max_len)
void ck_tokenizer_free(CKTokenizer *tok)
int ck_tokenizer_lookup_merge(const CKTokenizer *tok, int32_t left, int32_t right)
static uint32_t hash_pair(int32_t left, int32_t right)
static void json_skip_value(JSONParser *p)
static int json_parse_int(JSONParser *p, int *out)
#define CK_MAX_VOCAB_SIZE
#define CK_POOL_BLOCK_SIZE
struct CKPoolBlock * next
CKVocabEntry ** vocab_hash
struct CKVocabEntry * next
static int utf8_len(unsigned char c)
const int32_t int num_ids
const int32_t int int * out_len
const int32_t int char int max_len
int32_t int32_t int32_t merged_id
const char int int32_t int max_ids
const char const char * right