24 const uint16_t *token_embeddings,
25 const uint16_t *pos_embeddings,
28 int aligned_embed_dim,
32 if (!token_ids || !token_embeddings || !output) {
36 int tokens = token_count;
37 if (tokens < 0) tokens = 0;
38 if (tokens > context_window) tokens = context_window;
40 for (
int t = 0; t < tokens; ++t) {
41 int id = token_ids[t];
46 const uint16_t *tok = token_embeddings + (size_t)
id * (
size_t)aligned_embed_dim;
47 const uint16_t *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (
size_t)aligned_embed_dim) : NULL;
48 uint16_t *out = output + (size_t)t * (
size_t)aligned_embed_dim;
51 for (
int d = 0; d < embed_dim; ++d) {
56 for (
int d = 0; d < embed_dim; ++d) {
61 for (
int d = embed_dim; d < aligned_embed_dim; ++d) {
66 for (
int t = tokens; t < context_window; ++t) {
67 uint16_t *out = output + (size_t)t * (
size_t)aligned_embed_dim;
68 memset(out, 0, (
size_t)aligned_embed_dim *
sizeof(uint16_t));
74 const uint16_t *d_output,
75 uint16_t *d_token_embeddings,
76 uint16_t *d_pos_embeddings,
79 int aligned_embed_dim,
83 if (!token_ids || !d_output || !d_token_embeddings) {
87 int tokens = token_count;
88 if (tokens < 0) tokens = 0;
89 if (tokens > context_window) tokens = context_window;
91 for (
int t = 0; t < tokens; ++t) {
92 int id = token_ids[t];
97 const uint16_t *d_out = d_output + (size_t)t * (
size_t)aligned_embed_dim;
98 uint16_t *d_tok = d_token_embeddings + (size_t)
id * (
size_t)aligned_embed_dim;
99 uint16_t *d_pos = d_pos_embeddings ? (d_pos_embeddings + (size_t)t * (
size_t)aligned_embed_dim) : NULL;
101 for (
int d = 0; d < embed_dim; ++d) {
107 if (add_pos && d_pos) {
static uint16_t float_to_bf16(float f)
static float bf16_to_float(uint16_t v)
void embedding_forward_bf16(const int32_t *token_ids, int token_count, int vocab_size, const uint16_t *token_embeddings, const uint16_t *pos_embeddings, uint16_t *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void embedding_backward_bf16(const int32_t *token_ids, int token_count, const uint16_t *d_output, uint16_t *d_token_embeddings, uint16_t *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)