25 const float *token_embeddings,
26 const float *pos_embeddings,
29 int aligned_embed_dim,
33 if (!token_ids || !token_embeddings || !output) {
37 int tokens = token_count;
41 if (tokens > context_window) {
42 tokens = context_window;
45 for (
int t = 0; t < tokens; ++t) {
46 int id = token_ids[t];
51 const float *tok = token_embeddings + (size_t)
id * (
size_t)aligned_embed_dim;
52 const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (
size_t)aligned_embed_dim) : NULL;
53 float *out = output + (size_t)t * (
size_t)aligned_embed_dim;
56 for (
int d = 0; d < embed_dim; ++d) {
57 out[d] = tok[d] + pos[d];
60 for (
int d = 0; d < embed_dim; ++d) {
65 for (
int d = embed_dim; d < aligned_embed_dim; ++d) {
70 for (
int t = tokens; t < context_window; ++t) {
71 float *out = output + (size_t)t * (
size_t)aligned_embed_dim;
72 memset(out, 0, (
size_t)aligned_embed_dim *
sizeof(
float));
79 const void *token_embeddings,
80 const float *pos_embeddings,
83 int aligned_embed_dim,
87 if (!token_ids || !token_embeddings || !output) {
91 int tokens = token_count;
95 if (tokens > context_window) {
96 tokens = context_window;
100 const uint8_t *base = (
const uint8_t *)token_embeddings;
102 for (
int t = 0; t < tokens; ++t) {
103 int id = token_ids[t];
108 const void *tok = base + (size_t)
id * row_bytes;
109 const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (
size_t)aligned_embed_dim) : NULL;
110 float *out = output + (size_t)t * (
size_t)aligned_embed_dim;
114 if (add_pos && pos) {
115 for (
int d = 0; d < embed_dim; ++d) {
120 for (
int d = embed_dim; d < aligned_embed_dim; ++d) {
125 for (
int t = tokens; t < context_window; ++t) {
126 float *out = output + (size_t)t * (
size_t)aligned_embed_dim;
127 memset(out, 0, (
size_t)aligned_embed_dim *
sizeof(
float));
134 const void *token_embeddings,
135 const float *pos_embeddings,
138 int aligned_embed_dim,
142 if (!token_ids || !token_embeddings || !output) {
146 int tokens = token_count;
150 if (tokens > context_window) {
151 tokens = context_window;
155 const uint8_t *base = (
const uint8_t *)token_embeddings;
157 for (
int t = 0; t < tokens; ++t) {
158 int id = token_ids[t];
163 const void *tok = base + (size_t)
id * row_bytes;
164 const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (
size_t)aligned_embed_dim) : NULL;
165 float *out = output + (size_t)t * (
size_t)aligned_embed_dim;
169 if (add_pos && pos) {
170 for (
int d = 0; d < embed_dim; ++d) {
175 for (
int d = embed_dim; d < aligned_embed_dim; ++d) {
180 for (
int t = tokens; t < context_window; ++t) {
181 float *out = output + (size_t)t * (
size_t)aligned_embed_dim;
182 memset(out, 0, (
size_t)aligned_embed_dim *
sizeof(
float));
189 const void *token_embeddings,
190 const float *pos_embeddings,
193 int aligned_embed_dim,
197 if (!token_ids || !token_embeddings || !output) {
201 int tokens = token_count;
205 if (tokens > context_window) {
206 tokens = context_window;
210 const uint8_t *base = (
const uint8_t *)token_embeddings;
212 for (
int t = 0; t < tokens; ++t) {
213 int id = token_ids[t];
218 const void *tok = base + (size_t)
id * row_bytes;
219 const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (
size_t)aligned_embed_dim) : NULL;
220 float *out = output + (size_t)t * (
size_t)aligned_embed_dim;
224 if (add_pos && pos) {
225 for (
int d = 0; d < embed_dim; ++d) {
230 for (
int d = embed_dim; d < aligned_embed_dim; ++d) {
235 for (
int t = tokens; t < context_window; ++t) {
236 float *out = output + (size_t)t * (
size_t)aligned_embed_dim;
237 memset(out, 0, (
size_t)aligned_embed_dim *
sizeof(
float));
243 const float *d_output,
244 float *d_token_embeddings,
245 float *d_pos_embeddings,
248 int aligned_embed_dim,
252 if (!token_ids || !d_output || !d_token_embeddings) {
256 int tokens = token_count;
260 if (tokens > context_window) {
261 tokens = context_window;
264 for (
int t = 0; t < tokens; ++t) {
265 int id = token_ids[t];
270 const float *d_out = d_output + (size_t)t * (
size_t)aligned_embed_dim;
271 float *d_tok = d_token_embeddings + (size_t)
id * (
size_t)aligned_embed_dim;
272 float *d_pos = d_pos_embeddings ? (d_pos_embeddings + (size_t)t * (
size_t)aligned_embed_dim) : NULL;
274 for (
int d = 0; d < embed_dim; ++d) {
275 float grad = d_out[d];
277 if (add_pos && d_pos) {
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void dequant_q8_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q8_0 row (multiple blocks)
void dequant_q6_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q6_K row (multiple blocks)
void dequant_q4_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_K row (multiple blocks)
void embedding_forward_q6_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void embedding_forward_q4_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void embedding_backward(const int32_t *token_ids, int token_count, const float *d_output, float *d_token_embeddings, float *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void embedding_forward_q8_0(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void embedding_forward(const int32_t *token_ids, int token_count, int vocab_size, const float *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)