embedding_kernels.c
/**
 * @file embedding_kernels.c
 * @brief Token/position embedding lookup kernels
 *
 * CK-ENGINE KERNEL RULES:
 * =======================
 * 1. NO malloc/free - memory via bump allocator, pointers passed in
 * 2. NO OpenMP - parallelization at orchestrator/codegen layer
 * 3. API must define: inputs, outputs, workspace, and memory layouts
 * 4. Pure computation - deterministic, no side effects
 *
 * After changes: make test && make llamacpp-parity-full
 *
 * Embedding: out[t] = token_embed[token_id[t]] + pos_embed[t]
 */

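/*
 * Layout notes (as implemented below): output is a row-major
 * [context_window x aligned_embed_dim] matrix of floats, so row t
 * starts at output + t * aligned_embed_dim. The padding tail
 * [embed_dim, aligned_embed_dim) of every row is zeroed, out-of-range
 * token ids clamp to row 0, and rows past token_count are zero-filled.
 * Example: with embed_dim = 4094 padded to aligned_embed_dim = 4096,
 * each row's last two floats are always 0.0f.
 */
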
#include "ckernel_engine.h"
#include "ckernel_dtype.h"

#include <string.h>

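/**
 * @brief Dense fp32 lookup: output[t] = token_embeddings[token_ids[t]],
 *        plus pos_embeddings[t] when add_pos is nonzero.
 *
 * Usage sketch (illustrative only; ck_arena_alloc, ctx, and the buffer
 * names are stand-ins, not part of this API):
 *
 * @code
 * int32_t ids[4] = {11, 42, 7, 3};
 * // output must hold context_window * aligned_embed_dim floats
 * float *out = ck_arena_alloc(ctx, 8 * 4096 * sizeof(float));
 * embedding_forward(ids, 4, vocab_size,
 *                   tok_table, pos_table, out,
 *                   4096, 4096, 8, 1);
 * @endcode
 */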
void embedding_forward(const int32_t *token_ids,
                       int token_count,
                       int vocab_size,
                       const float *token_embeddings,
                       const float *pos_embeddings,
                       float *output,
                       int embed_dim,
                       int aligned_embed_dim,
                       int context_window,
                       int add_pos)
{
    if (!token_ids || !token_embeddings || !output) {
        return;
    }

    /* Clamp token_count into [0, context_window]. */
    int tokens = token_count;
    if (tokens < 0) {
        tokens = 0;
    }
    if (tokens > context_window) {
        tokens = context_window;
    }

    for (int t = 0; t < tokens; ++t) {
        /* Out-of-range ids fall back to row 0. */
        int id = token_ids[t];
        if (id < 0 || id >= vocab_size) {
            id = 0;
        }

        const float *tok = token_embeddings + (size_t)id * (size_t)aligned_embed_dim;
        const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;
        float *out = output + (size_t)t * (size_t)aligned_embed_dim;

        if (add_pos && pos) {
            for (int d = 0; d < embed_dim; ++d) {
                out[d] = tok[d] + pos[d];
            }
        } else {
            for (int d = 0; d < embed_dim; ++d) {
                out[d] = tok[d];
            }
        }

        /* Zero the alignment padding at the end of the row. */
        for (int d = embed_dim; d < aligned_embed_dim; ++d) {
            out[d] = 0.0f;
        }
    }

    /* Zero any rows beyond the last real token. */
    for (int t = tokens; t < context_window; ++t) {
        float *out = output + (size_t)t * (size_t)aligned_embed_dim;
        memset(out, 0, (size_t)aligned_embed_dim * sizeof(float));
    }
}

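/**
 * @brief Q4_K variant: each embedding row is dequantized to fp32 into
 *        the output row before the optional positional add.
 *
 * The row stride in bytes comes from ck_dtype_row_bytes(). If the Q4_K
 * layout matches llama.cpp's (an assumption: 256-element super-blocks
 * of 144 bytes each), a 4096-wide row occupies 16 blocks:
 *
 * @code
 * // Hypothetical arithmetic under the llama.cpp-compatible assumption:
 * // (4096 / 256) * 144 = 2304 bytes per row
 * size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, 4096);
 * @endcode
 */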
void embedding_forward_q4_k(const int32_t *token_ids,
                            int token_count,
                            int vocab_size,
                            const void *token_embeddings,
                            const float *pos_embeddings,
                            float *output,
                            int embed_dim,
                            int aligned_embed_dim,
                            int context_window,
                            int add_pos)
{
    if (!token_ids || !token_embeddings || !output) {
        return;
    }

    int tokens = token_count;
    if (tokens < 0) {
        tokens = 0;
    }
    if (tokens > context_window) {
        tokens = context_window;
    }

    const size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_embed_dim);
    const uint8_t *base = (const uint8_t *)token_embeddings;

    for (int t = 0; t < tokens; ++t) {
        int id = token_ids[t];
        if (id < 0 || id >= vocab_size) {
            id = 0;
        }

        const void *tok = base + (size_t)id * row_bytes;
        const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;
        float *out = output + (size_t)t * (size_t)aligned_embed_dim;

        dequant_q4_k_row(tok, out, (size_t)aligned_embed_dim);

        if (add_pos && pos) {
            for (int d = 0; d < embed_dim; ++d) {
                out[d] += pos[d];
            }
        }

        for (int d = embed_dim; d < aligned_embed_dim; ++d) {
            out[d] = 0.0f;
        }
    }

    for (int t = tokens; t < context_window; ++t) {
        float *out = output + (size_t)t * (size_t)aligned_embed_dim;
        memset(out, 0, (size_t)aligned_embed_dim * sizeof(float));
    }
}

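/**
 * @brief Q8_0 variant of embedding_forward; mirrors embedding_forward_q4_k.
 *
 * Stride sketch, assuming llama.cpp-compatible Q8_0 blocks (32 elements,
 * 2-byte fp16 scale + 32 int8 quants = 34 bytes; the block geometry is
 * an assumption, not taken from this file):
 *
 * @code
 * // (4096 / 32) * 34 = 4352 bytes per row under that assumption
 * size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, 4096);
 * @endcode
 */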
void embedding_forward_q8_0(const int32_t *token_ids,
                            int token_count,
                            int vocab_size,
                            const void *token_embeddings,
                            const float *pos_embeddings,
                            float *output,
                            int embed_dim,
                            int aligned_embed_dim,
                            int context_window,
                            int add_pos)
{
    if (!token_ids || !token_embeddings || !output) {
        return;
    }

    int tokens = token_count;
    if (tokens < 0) {
        tokens = 0;
    }
    if (tokens > context_window) {
        tokens = context_window;
    }

    const size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
    const uint8_t *base = (const uint8_t *)token_embeddings;

    for (int t = 0; t < tokens; ++t) {
        int id = token_ids[t];
        if (id < 0 || id >= vocab_size) {
            id = 0;
        }

        const void *tok = base + (size_t)id * row_bytes;
        const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;
        float *out = output + (size_t)t * (size_t)aligned_embed_dim;

        dequant_q8_0_row(tok, out, (size_t)aligned_embed_dim);

        if (add_pos && pos) {
            for (int d = 0; d < embed_dim; ++d) {
                out[d] += pos[d];
            }
        }

        for (int d = embed_dim; d < aligned_embed_dim; ++d) {
            out[d] = 0.0f;
        }
    }

    for (int t = tokens; t < context_window; ++t) {
        float *out = output + (size_t)t * (size_t)aligned_embed_dim;
        memset(out, 0, (size_t)aligned_embed_dim * sizeof(float));
    }
}

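/**
 * @brief Q6_K variant of embedding_forward; mirrors embedding_forward_q4_k.
 *
 * Under the same llama.cpp-compatibility assumption, Q6_K packs
 * 256-element super-blocks of 210 bytes, so a 4096-wide row would be
 * 16 * 210 = 3360 bytes; ck_dtype_row_bytes() is the authoritative
 * source either way.
 */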
void embedding_forward_q6_k(const int32_t *token_ids,
                            int token_count,
                            int vocab_size,
                            const void *token_embeddings,
                            const float *pos_embeddings,
                            float *output,
                            int embed_dim,
                            int aligned_embed_dim,
                            int context_window,
                            int add_pos)
{
    if (!token_ids || !token_embeddings || !output) {
        return;
    }

    int tokens = token_count;
    if (tokens < 0) {
        tokens = 0;
    }
    if (tokens > context_window) {
        tokens = context_window;
    }

    const size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_embed_dim);
    const uint8_t *base = (const uint8_t *)token_embeddings;

    for (int t = 0; t < tokens; ++t) {
        int id = token_ids[t];
        if (id < 0 || id >= vocab_size) {
            id = 0;
        }

        const void *tok = base + (size_t)id * row_bytes;
        const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;
        float *out = output + (size_t)t * (size_t)aligned_embed_dim;

        dequant_q6_k_row(tok, out, (size_t)aligned_embed_dim);

        if (add_pos && pos) {
            for (int d = 0; d < embed_dim; ++d) {
                out[d] += pos[d];
            }
        }

        for (int d = embed_dim; d < aligned_embed_dim; ++d) {
            out[d] = 0.0f;
        }
    }

    for (int t = tokens; t < context_window; ++t) {
        float *out = output + (size_t)t * (size_t)aligned_embed_dim;
        memset(out, 0, (size_t)aligned_embed_dim * sizeof(float));
    }
}

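/**
 * @brief Gradient accumulation: d_token_embeddings[id] += d_output[t]
 *        for each token t, and d_pos_embeddings[t] += d_output[t] when
 *        add_pos is nonzero.
 *
 * The kernel only adds, so the caller must zero the gradient buffers
 * before the first call (or leave them intact to accumulate across
 * micro-batches). Minimal sketch, with buffer sizes assumed from the
 * forward layout above (aligned_embed_dim = 4096, context_window = 8):
 *
 * @code
 * memset(d_tok, 0, (size_t)vocab_size * 4096 * sizeof(float));
 * memset(d_pos, 0, (size_t)8 * 4096 * sizeof(float));
 * embedding_backward(ids, 4, d_out, d_tok, d_pos,
 *                    vocab_size, 4096, 4096, 8, 1);
 * @endcode
 */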
void embedding_backward(const int32_t *token_ids,
                        int token_count,
                        const float *d_output,
                        float *d_token_embeddings,
                        float *d_pos_embeddings,
                        int vocab_size,
                        int embed_dim,
                        int aligned_embed_dim,
                        int context_window,
                        int add_pos)
{
    if (!token_ids || !d_output || !d_token_embeddings) {
        return;
    }

    int tokens = token_count;
    if (tokens < 0) {
        tokens = 0;
    }
    if (tokens > context_window) {
        tokens = context_window;
    }

    for (int t = 0; t < tokens; ++t) {
        /* Same clamping as the forward pass, so gradients for
         * out-of-range ids land on row 0. */
        int id = token_ids[t];
        if (id < 0 || id >= vocab_size) {
            id = 0;
        }

        const float *d_out = d_output + (size_t)t * (size_t)aligned_embed_dim;
        float *d_tok = d_token_embeddings + (size_t)id * (size_t)aligned_embed_dim;
        float *d_pos = d_pos_embeddings ? (d_pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;

        /* Accumulate: repeated ids sum their gradients into one row. */
        for (int d = 0; d < embed_dim; ++d) {
            float grad = d_out[d];
            d_tok[d] += grad;
            if (add_pos && d_pos) {
                d_pos[d] += grad;
            }
        }
    }
}