183 if (!embed_weight || !tokens || num_tokens <= 0)
return;
194 float *hidden = malloc(num_tokens * (num_layers + 1) * embed_dim *
sizeof(
float));
196 fprintf(stderr,
"Failed to allocate hidden states\n");
201 float *q = malloc(num_heads * head_dim *
sizeof(
float));
202 float *k = malloc(num_kv_heads * head_dim *
sizeof(
float));
203 float *v = malloc(num_kv_heads * head_dim *
sizeof(
float));
204 float *attn = malloc(num_heads * head_dim *
sizeof(
float));
205 float *mlp = malloc(intermediate *
sizeof(
float));
207 if (!q || !k || !v || !attn || !mlp) {
208 fprintf(stderr,
"Failed to allocate temp buffers\n");
219 const float *ln1_gamma = NULL;
220 const float *ln2_gamma = NULL;
221 const float *wq = NULL, *wk = NULL, *wv = NULL, *wo = NULL;
222 const float *w1 = NULL, *w2 = NULL;
225 #pragma omp parallel for schedule(dynamic, 1)
226 for (
int t = 0; t < num_tokens; t++) {
227 float *h = hidden + t * (num_layers + 1) * embed_dim;
233 for (
int layer = 0; layer < num_layers; layer++) {
235 float *layer_out = h + embed_dim;
238 simple_rmsnorm(layer_in, ln1_gamma, layer_in, 1, embed_dim, 1e-6f);
241 gemm_nt(layer_in, wq, q, 1, num_heads * head_dim, embed_dim);
242 gemm_nt(layer_in, wk, k, 1, num_kv_heads * head_dim, embed_dim);
243 gemm_nt(layer_in, wv, v, 1, num_kv_heads * head_dim, embed_dim);
253 gemm_nt(attn, wo, layer_out, 1, embed_dim, num_heads * head_dim);
259 simple_rmsnorm(layer_in, ln2_gamma, layer_in, 1, embed_dim, 1e-6f);
262 gemm_nt(layer_in, w1, mlp, 1, 2 * intermediate, embed_dim);
263 silu(mlp, 2 * intermediate);
264 gemm_nt(mlp, w2, layer_out, 1, embed_dim, intermediate);
271 memcpy(hidden + t * (num_layers + 1) * embed_dim +
272 num_layers * embed_dim, h, embed_dim *
sizeof(
float));
276 float *final_out = malloc(num_tokens * embed_dim *
sizeof(
float));
278 simple_rmsnorm(hidden + num_layers * embed_dim, ln1_gamma, final_out,
279 num_tokens, embed_dim, 1e-6f);
/* NOTE(review): the lines below are bodiless function signatures and
 * value-less macros with no terminating semicolons — as written they are
 * not valid C. This looks like an extraction artifact (a symbol index for
 * the truncated forward-pass function above) rather than real declarations;
 * confirm against the original file before editing. */
/* presumably: element-wise residual[i] += addend[i] over n floats — TODO confirm */
static void residual_add(float *residual, float *addend, int n)
/* presumably: gathers one embed_dim row of weight per token id into output — TODO confirm */
static void simple_embedding(const int32_t *tokens, int num_tokens, const float *weight, float *output, int vocab_size, int embed_dim)
/* value-less here; the full file presumably defines the MLP intermediate size — TODO confirm */
#define MODEL_INTERMEDIATE
/* presumably: scaled dot-product attention; num_kv_heads != num_heads suggests grouped-query attention — TODO confirm */
static void simple_attention(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int seq_len, int head_dim)
/* presumably: in-place SiLU activation over n elements (called on the MLP buffer at original line 263) — TODO confirm */
static void silu(float *x, int n)
/* presumably: output = input * weight^T, i.e. GEMM with transposed weight (rows x common times cols x common) — TODO confirm */
static void gemm_nt(const float *input, const float *weight, float *output, int rows, int cols, int common)
/* value-less here; presumably the key/value head count for GQA — TODO confirm */
#define MODEL_NUM_KV_HEADS
/* presumably: RMS normalization of (tokens x d_model) with per-channel gamma and epsilon eps — TODO confirm */
static void simple_rmsnorm(const float *input, const float *gamma, float *output, int tokens, int d_model, float eps)
/* presumably: applies rotary position embeddings in place over seq_len positions — TODO confirm */
static void apply_rope(float *x, int seq_len, int head_dim)