← Back to C-Kernel-Engine Docs Doxygen Source Documentation
embedding_kernels_bf16.c File Reference

Token/position embedding lookup kernels for BF16. More...

#include <stdint.h>
#include <string.h>
#include "bf16_utils.h"
#include "ckernel_engine.h"

Go to the source code of this file.

Functions

void embedding_backward_bf16 (const int32_t *token_ids, int token_count, const uint16_t *d_output, uint16_t *d_token_embeddings, uint16_t *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_forward_bf16 (const int32_t *token_ids, int token_count, int vocab_size, const uint16_t *token_embeddings, const uint16_t *pos_embeddings, uint16_t *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 

Detailed Description

Token/position embedding lookup kernels for BF16.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Definition in file embedding_kernels_bf16.c.

Function Documentation

◆ embedding_backward_bf16()

/**
 * Backward pass for the BF16 embedding lookup: accumulates output
 * gradients into the token-embedding (and optionally position-embedding)
 * gradient tables. Reverse of embedding_forward_bf16.
 *
 * token_ids          [token_count]                    input token ids
 * d_output           [context_window x aligned_embed_dim] upstream gradients
 * d_token_embeddings [vocab_size x aligned_embed_dim] accumulated in place
 * d_pos_embeddings   [context_window x aligned_embed_dim] or NULL
 * add_pos            nonzero => also accumulate into position gradients
 *
 * token_count is clamped to [0, context_window]; ids outside
 * [0, vocab_size) are redirected to row 0, mirroring the forward pass.
 * No allocation, no side effects beyond the gradient tables.
 */
void embedding_backward_bf16(const int32_t *token_ids,
                             int token_count,
                             const uint16_t *d_output,
                             uint16_t *d_token_embeddings,
                             uint16_t *d_pos_embeddings,
                             int vocab_size,
                             int embed_dim,
                             int aligned_embed_dim,
                             int context_window,
                             int add_pos)
{
    if (token_ids == NULL || d_output == NULL || d_token_embeddings == NULL) {
        return;
    }

    /* Clamp the effective token count into [0, context_window]. */
    int n = (token_count < 0) ? 0 : token_count;
    if (n > context_window) {
        n = context_window;
    }

    const size_t stride = (size_t)aligned_embed_dim;

    for (int t = 0; t < n; ++t) {
        /* Invalid ids fall back to row 0, consistent with the forward pass. */
        int row = token_ids[t];
        if (row < 0 || row >= vocab_size) {
            row = 0;
        }

        const uint16_t *grad_in = d_output + (size_t)t * stride;
        uint16_t *tok_grad = d_token_embeddings + (size_t)row * stride;
        uint16_t *pos_grad =
            (d_pos_embeddings != NULL) ? d_pos_embeddings + (size_t)t * stride
                                       : NULL;
        const int accumulate_pos = add_pos && (pos_grad != NULL);

        for (int d = 0; d < embed_dim; ++d) {
            const float g = bf16_to_float(grad_in[d]);

            /* BF16 accumulate: decode, add, re-encode. */
            tok_grad[d] = float_to_bf16(bf16_to_float(tok_grad[d]) + g);

            if (accumulate_pos) {
                pos_grad[d] = float_to_bf16(bf16_to_float(pos_grad[d]) + g);
            }
        }
    }
}
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38
int vocab_size
Definition: true_bpe.h:185

References bf16_to_float(), float_to_bf16(), and vocab_size.

◆ embedding_forward_bf16()

/**
 * Forward pass for the BF16 embedding lookup: gathers a token-embedding
 * row per input id, optionally adds the matching position-embedding row,
 * and writes the result into a [context_window x aligned_embed_dim] output.
 *
 * token_ids        [token_count]                        input token ids
 * token_embeddings [vocab_size x aligned_embed_dim]     lookup table
 * pos_embeddings   [context_window x aligned_embed_dim] or NULL
 * output           [context_window x aligned_embed_dim] written fully
 * add_pos          nonzero => add position rows (when table is non-NULL)
 *
 * token_count is clamped to [0, context_window]; ids outside
 * [0, vocab_size) read row 0. Alignment padding columns and all rows
 * past the clamped token count are zeroed. No allocation.
 */
void embedding_forward_bf16(const int32_t *token_ids,
                            int token_count,
                            int vocab_size,
                            const uint16_t *token_embeddings,
                            const uint16_t *pos_embeddings,
                            uint16_t *output,
                            int embed_dim,
                            int aligned_embed_dim,
                            int context_window,
                            int add_pos)
{
    if (token_ids == NULL || token_embeddings == NULL || output == NULL) {
        return;
    }

    /* Clamp the effective token count into [0, context_window]. */
    int n = (token_count < 0) ? 0 : token_count;
    if (n > context_window) {
        n = context_window;
    }

    const size_t stride = (size_t)aligned_embed_dim;

    for (int t = 0; t < n; ++t) {
        /* Invalid ids read row 0 instead of indexing out of bounds. */
        int row = token_ids[t];
        if (row < 0 || row >= vocab_size) {
            row = 0;
        }

        const uint16_t *tok_row = token_embeddings + (size_t)row * stride;
        uint16_t *dst = output + (size_t)t * stride;

        if (add_pos && pos_embeddings != NULL) {
            /* Fused gather + position add, round-tripped through float. */
            const uint16_t *pos_row = pos_embeddings + (size_t)t * stride;
            for (int d = 0; d < embed_dim; ++d) {
                const float sum =
                    bf16_to_float(tok_row[d]) + bf16_to_float(pos_row[d]);
                dst[d] = float_to_bf16(sum);
            }
        } else {
            /* Plain gather: raw BF16 copy, no conversion needed. */
            for (int d = 0; d < embed_dim; ++d) {
                dst[d] = tok_row[d];
            }
        }

        /* Zero the alignment padding columns. */
        for (int d = embed_dim; d < aligned_embed_dim; ++d) {
            dst[d] = 0;
        }
    }

    /* Zero every row past the clamped token count. */
    for (int t = n; t < context_window; ++t) {
        memset(output + (size_t)t * stride, 0, stride * sizeof(uint16_t));
    }
}

References bf16_to_float(), float_to_bf16(), and vocab_size.