← Back to C-Kernel-Engine Docs Doxygen Source Documentation
loss_kernels_bf16.c
Go to the documentation of this file.
1 /**
2  * @file loss_kernels_bf16.c
3  * @brief Loss function kernels for BF16 tensors
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  */
14 
15 #include <stddef.h>
16 #include <stdint.h>
17 
18 #include "bf16_utils.h"
19 #include "ckernel_engine.h"
20 
/*
 * BF16 softmax cross-entropy loss with caller-provided scratch buffers.
 *
 * Widens BF16 logits to FP32, delegates to the FP32 kernel, then narrows
 * the resulting gradients back to BF16. No allocation occurs here; all
 * working memory is supplied by the caller (bump-allocator discipline).
 *
 * logits:           [tokens * vocab_size] BF16 input logits
 * targets:          [tokens] target class indices
 * d_logits:         [tokens * vocab_size] BF16 output gradients
 * loss_out:         optional scalar output; set to 0.0f on invalid input
 * scratch_logits, scratch_d_logits: each [tokens * vocab_size] floats
 */
void softmax_cross_entropy_loss_bf16(const uint16_t *logits,
                                     const int32_t *targets,
                                     int tokens,
                                     int vocab_size,
                                     uint16_t *d_logits,
                                     float *loss_out,
                                     float *scratch_logits,
                                     float *scratch_d_logits)
{
    /* Reject null buffers and degenerate shapes up front. */
    const int args_ok = logits && targets && d_logits &&
                        scratch_logits && scratch_d_logits &&
                        tokens > 0 && vocab_size > 0;
    if (!args_ok) {
        if (loss_out) {
            *loss_out = 0.0f;
        }
        return;
    }

    /* Cast each factor before multiplying to avoid int overflow. */
    const size_t n_elems = (size_t)tokens * (size_t)vocab_size;

    /* Widen -> compute in FP32 -> narrow gradients back to BF16. */
    bf16_tensor_to_float(logits, scratch_logits, n_elems);
    softmax_cross_entropy_loss(scratch_logits, targets, tokens, vocab_size,
                               scratch_d_logits, loss_out);
    float_tensor_to_bf16(scratch_d_logits, d_logits, n_elems);
}
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
Definition: bf16_utils.h:271
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
Definition: bf16_utils.h:250
void softmax_cross_entropy_loss(const float *logits, const int32_t *targets, int tokens, int vocab_size, float *d_logits, float *loss_out)
Definition: loss_kernels.c:21
void softmax_cross_entropy_loss_bf16(const uint16_t *logits, const int32_t *targets, int tokens, int vocab_size, uint16_t *d_logits, float *loss_out, float *scratch_logits, float *scratch_d_logits)
int vocab_size
Definition: true_bpe.h:185