22 const int32_t *targets,
28 if (!logits || !targets || !d_logits || tokens <= 0 ||
vocab_size <= 0) {
35 double total_loss = 0.0;
37 for (
int t = 0; t < tokens; ++t) {
38 const float *row = logits + (size_t)t * (
size_t)
vocab_size;
39 float *drow = d_logits + (size_t)t * (
size_t)
vocab_size;
40 int target = targets[t];
42 float max_logit = row[0];
44 if (row[v] > max_logit) {
51 float e = expf(row[v] - max_logit);
56 float inv_sum = 1.0f / (float)sum_exp;
62 total_loss += -logf(drow[target] + 1e-10f);
66 float scale = 1.0f / (float)tokens;
73 *loss_out = (float)(total_loss / (
double)tokens);
void softmax_cross_entropy_loss(const float *logits, const int32_t *targets, int tokens, int vocab_size, float *d_logits, float *loss_out)