← Back to C-Kernel-Engine Docs Doxygen Source Documentation
loss_kernels.c
Go to the documentation of this file.
1 /**
2  * @file loss_kernels.c
3  * @brief Loss function kernels (cross-entropy, etc.)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Cross-entropy: L = -log(softmax(logits)[target])
15  */
16 
17 #include "ckernel_engine.h"
18 
19 #include <math.h>
20 
21 void softmax_cross_entropy_loss(const float *logits,
22  const int32_t *targets,
23  int tokens,
24  int vocab_size,
25  float *d_logits,
26  float *loss_out)
27 {
28  if (!logits || !targets || !d_logits || tokens <= 0 || vocab_size <= 0) {
29  if (loss_out) {
30  *loss_out = 0.0f;
31  }
32  return;
33  }
34 
35  double total_loss = 0.0;
36 
37  for (int t = 0; t < tokens; ++t) {
38  const float *row = logits + (size_t)t * (size_t)vocab_size;
39  float *drow = d_logits + (size_t)t * (size_t)vocab_size;
40  int target = targets[t];
41 
42  float max_logit = row[0];
43  for (int v = 1; v < vocab_size; ++v) {
44  if (row[v] > max_logit) {
45  max_logit = row[v];
46  }
47  }
48 
49  double sum_exp = 0.0;
50  for (int v = 0; v < vocab_size; ++v) {
51  float e = expf(row[v] - max_logit);
52  drow[v] = e;
53  sum_exp += e;
54  }
55 
56  float inv_sum = 1.0f / (float)sum_exp;
57  for (int v = 0; v < vocab_size; ++v) {
58  drow[v] *= inv_sum;
59  }
60 
61  if (target >= 0 && target < vocab_size) {
62  total_loss += -logf(drow[target] + 1e-10f);
63  drow[target] -= 1.0f;
64  }
65 
66  float scale = 1.0f / (float)tokens;
67  for (int v = 0; v < vocab_size; ++v) {
68  drow[v] *= scale;
69  }
70  }
71 
72  if (loss_out) {
73  *loss_out = (float)(total_loss / (double)tokens);
74  }
75 }
void softmax_cross_entropy_loss(const float *logits, const int32_t *targets, int tokens, int vocab_size, float *d_logits, float *loss_out)
Definition: loss_kernels.c:21
int vocab_size
Definition: true_bpe.h:185