← Back to C-Kernel-Engine Docs Doxygen Source Documentation
embedding_kernels.c File Reference

Token/position embedding lookup kernels. More...

#include "ckernel_engine.h"
#include "ckernel_dtype.h"
#include <string.h>

Go to the source code of this file.

Functions

void embedding_backward (const int32_t *token_ids, int token_count, const float *d_output, float *d_token_embeddings, float *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_forward (const int32_t *token_ids, int token_count, int vocab_size, const float *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_forward_q4_k (const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_forward_q6_k (const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_forward_q8_0 (const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 

Detailed Description

Token/position embedding lookup kernels.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Embedding: out[t] = token_embed[token_id[t]] + pos_embed[t]

Definition in file embedding_kernels.c.

Function Documentation

◆ embedding_backward()

/**
 * Accumulate gradients for the embedding lookup.
 *
 * For each valid position t, adds d_output row t into the token-embedding
 * gradient row selected by token_ids[t], and (when add_pos is set and a
 * position-gradient buffer is supplied) into the position-gradient row t.
 *
 * @param token_ids          Token indices, one per position (required).
 * @param token_count        Number of tokens; clamped into [0, context_window].
 * @param d_output           Upstream gradient, context_window x aligned_embed_dim (required).
 * @param d_token_embeddings Token-embedding gradient accumulator (required).
 * @param d_pos_embeddings   Position-embedding gradient accumulator (may be NULL).
 * @param vocab_size         Ids outside [0, vocab_size) are redirected to row 0,
 *                           mirroring the forward pass.
 * @param embed_dim          Logical embedding width; only these lanes carry gradient.
 * @param aligned_embed_dim  Padded row stride in floats.
 * @param context_window     Maximum sequence length.
 * @param add_pos            Non-zero to also accumulate position gradients.
 */
void embedding_backward(const int32_t *token_ids, int token_count,
                        const float *d_output, float *d_token_embeddings,
                        float *d_pos_embeddings, int vocab_size, int embed_dim,
                        int aligned_embed_dim, int context_window, int add_pos)
{
    if (token_ids == NULL || d_output == NULL || d_token_embeddings == NULL) {
        return;
    }

    /* Clamp the effective sequence length into [0, context_window]. */
    int seq_len = (token_count < 0) ? 0 : token_count;
    if (seq_len > context_window) {
        seq_len = context_window;
    }

    const size_t row = (size_t)aligned_embed_dim;
    /* Hoisted: whether position gradients are accumulated at all. */
    const int accumulate_pos = (add_pos && d_pos_embeddings != NULL);

    for (int pos = 0; pos < seq_len; ++pos) {
        int tok_id = token_ids[pos];
        if (tok_id < 0 || tok_id >= vocab_size) {
            tok_id = 0; /* out-of-range ids map to row 0, matching forward */
        }

        const float *grad = d_output + (size_t)pos * row;
        float *g_tok = d_token_embeddings + (size_t)tok_id * row;

        for (int i = 0; i < embed_dim; ++i) {
            g_tok[i] += grad[i];
        }

        if (accumulate_pos) {
            float *g_pos = d_pos_embeddings + (size_t)pos * row;
            for (int i = 0; i < embed_dim; ++i) {
                g_pos[i] += grad[i];
            }
        }
    }
}
int vocab_size
Definition: true_bpe.h:185

References vocab_size.

◆ embedding_forward()

/**
 * Float32 embedding lookup: out[t] = token_embed[token_id[t]] (+ pos_embed[t]).
 *
 * Writes one output row per position up to the clamped token count, zeroing
 * any padding lanes beyond embed_dim, then zero-fills the remaining rows of
 * the context window.
 *
 * @param token_ids         Token indices, one per position (required).
 * @param token_count       Number of tokens; clamped into [0, context_window].
 * @param vocab_size        Ids outside [0, vocab_size) are redirected to row 0.
 * @param token_embeddings  Token table, vocab_size x aligned_embed_dim (required).
 * @param pos_embeddings    Position table (may be NULL; then no positions added).
 * @param output            Destination, context_window x aligned_embed_dim (required).
 * @param embed_dim         Logical embedding width.
 * @param aligned_embed_dim Padded row stride in floats.
 * @param context_window    Maximum sequence length; trailing rows are zeroed.
 * @param add_pos           Non-zero to add position embeddings.
 */
void embedding_forward(const int32_t *token_ids, int token_count, int vocab_size,
                       const float *token_embeddings, const float *pos_embeddings,
                       float *output, int embed_dim, int aligned_embed_dim,
                       int context_window, int add_pos)
{
    if (token_ids == NULL || token_embeddings == NULL || output == NULL) {
        return;
    }

    /* Clamp the effective sequence length into [0, context_window]. */
    int seq_len = (token_count < 0) ? 0 : token_count;
    if (seq_len > context_window) {
        seq_len = context_window;
    }

    const size_t row = (size_t)aligned_embed_dim;
    const int use_pos = (add_pos && pos_embeddings != NULL);

    for (int pos = 0; pos < seq_len; ++pos) {
        int tok_id = token_ids[pos];
        if (tok_id < 0 || tok_id >= vocab_size) {
            tok_id = 0; /* out-of-range ids fall back to row 0 */
        }

        const float *src = token_embeddings + (size_t)tok_id * row;
        float *dst = output + (size_t)pos * row;

        if (use_pos) {
            const float *pvec = pos_embeddings + (size_t)pos * row;
            for (int i = 0; i < embed_dim; ++i) {
                dst[i] = src[i] + pvec[i];
            }
        } else {
            memcpy(dst, src, (size_t)embed_dim * sizeof(float));
        }

        /* Zero padding lanes beyond the logical width. */
        for (int i = embed_dim; i < aligned_embed_dim; ++i) {
            dst[i] = 0.0f;
        }
    }

    /* Zero-fill any context rows past the sequence. */
    for (int pos = seq_len; pos < context_window; ++pos) {
        memset(output + (size_t)pos * row, 0, row * sizeof(float));
    }
}

References vocab_size.

◆ embedding_forward_q4_k()

void embedding_forward_q4_k ( const int32_t *  token_ids,
int  token_count,
int  vocab_size,
const void *  token_embeddings,
const float *  pos_embeddings,
float *  output,
int  embed_dim,
int  aligned_embed_dim,
int  context_window,
int  add_pos 
)

Definition at line 76 of file embedding_kernels.c.

86 {
87  if (!token_ids || !token_embeddings || !output) {
88  return;
89  }
90 
91  int tokens = token_count;
92  if (tokens < 0) {
93  tokens = 0;
94  }
95  if (tokens > context_window) {
96  tokens = context_window;
97  }
98 
99  const size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_embed_dim);
100  const uint8_t *base = (const uint8_t *)token_embeddings;
101 
102  for (int t = 0; t < tokens; ++t) {
103  int id = token_ids[t];
104  if (id < 0 || id >= vocab_size) {
105  id = 0;
106  }
107 
108  const void *tok = base + (size_t)id * row_bytes;
109  const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;
110  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
111 
112  dequant_q4_k_row(tok, out, (size_t)aligned_embed_dim);
113 
114  if (add_pos && pos) {
115  for (int d = 0; d < embed_dim; ++d) {
116  out[d] += pos[d];
117  }
118  }
119 
120  for (int d = embed_dim; d < aligned_embed_dim; ++d) {
121  out[d] = 0.0f;
122  }
123  }
124 
125  for (int t = tokens; t < context_window; ++t) {
126  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
127  memset(out, 0, (size_t)aligned_embed_dim * sizeof(float));
128  }
129 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void dequant_q4_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_K row (multiple blocks)

References CK_DT_Q4_K, ck_dtype_row_bytes(), dequant_q4_k_row(), and vocab_size.

Referenced by model_decode_token(), model_forward_prefill_impl(), qwen2_0_5b_decode_decode_token(), and qwen2_0_5b_decode_forward_prefill_impl().

◆ embedding_forward_q6_k()

void embedding_forward_q6_k ( const int32_t *  token_ids,
int  token_count,
int  vocab_size,
const void *  token_embeddings,
const float *  pos_embeddings,
float *  output,
int  embed_dim,
int  aligned_embed_dim,
int  context_window,
int  add_pos 
)

Definition at line 186 of file embedding_kernels.c.

196 {
197  if (!token_ids || !token_embeddings || !output) {
198  return;
199  }
200 
201  int tokens = token_count;
202  if (tokens < 0) {
203  tokens = 0;
204  }
205  if (tokens > context_window) {
206  tokens = context_window;
207  }
208 
209  const size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_embed_dim);
210  const uint8_t *base = (const uint8_t *)token_embeddings;
211 
212  for (int t = 0; t < tokens; ++t) {
213  int id = token_ids[t];
214  if (id < 0 || id >= vocab_size) {
215  id = 0;
216  }
217 
218  const void *tok = base + (size_t)id * row_bytes;
219  const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;
220  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
221 
222  dequant_q6_k_row(tok, out, (size_t)aligned_embed_dim);
223 
224  if (add_pos && pos) {
225  for (int d = 0; d < embed_dim; ++d) {
226  out[d] += pos[d];
227  }
228  }
229 
230  for (int d = embed_dim; d < aligned_embed_dim; ++d) {
231  out[d] = 0.0f;
232  }
233  }
234 
235  for (int t = tokens; t < context_window; ++t) {
236  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
237  memset(out, 0, (size_t)aligned_embed_dim * sizeof(float));
238  }
239 }
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
void dequant_q6_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q6_K row (multiple blocks)

References CK_DT_Q6_K, ck_dtype_row_bytes(), dequant_q6_k_row(), and vocab_size.

◆ embedding_forward_q8_0()

void embedding_forward_q8_0 ( const int32_t *  token_ids,
int  token_count,
int  vocab_size,
const void *  token_embeddings,
const float *  pos_embeddings,
float *  output,
int  embed_dim,
int  aligned_embed_dim,
int  context_window,
int  add_pos 
)

Definition at line 131 of file embedding_kernels.c.

141 {
142  if (!token_ids || !token_embeddings || !output) {
143  return;
144  }
145 
146  int tokens = token_count;
147  if (tokens < 0) {
148  tokens = 0;
149  }
150  if (tokens > context_window) {
151  tokens = context_window;
152  }
153 
154  const size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
155  const uint8_t *base = (const uint8_t *)token_embeddings;
156 
157  for (int t = 0; t < tokens; ++t) {
158  int id = token_ids[t];
159  if (id < 0 || id >= vocab_size) {
160  id = 0;
161  }
162 
163  const void *tok = base + (size_t)id * row_bytes;
164  const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;
165  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
166 
167  dequant_q8_0_row(tok, out, (size_t)aligned_embed_dim);
168 
169  if (add_pos && pos) {
170  for (int d = 0; d < embed_dim; ++d) {
171  out[d] += pos[d];
172  }
173  }
174 
175  for (int d = embed_dim; d < aligned_embed_dim; ++d) {
176  out[d] = 0.0f;
177  }
178  }
179 
180  for (int t = tokens; t < context_window; ++t) {
181  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
182  memset(out, 0, (size_t)aligned_embed_dim * sizeof(float));
183  }
184 }
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
void dequant_q8_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q8_0 row (multiple blocks)

References CK_DT_Q8_0, ck_dtype_row_bytes(), dequant_q8_0_row(), and vocab_size.

Referenced by qwen2_0_5b_decode_decode_token(), and qwen2_0_5b_decode_forward_prefill_impl().