26 const int32_t *targets,
31 float *scratch_logits,
32 float *scratch_d_logits)
34 if (!logits || !targets || !d_logits || tokens <= 0 ||
vocab_size <= 0) {
35 if (loss_out) *loss_out = 0.0f;
38 if (!scratch_logits || !scratch_d_logits) {
39 if (loss_out) *loss_out = 0.0f;
43 const size_t count = (size_t)tokens * (
size_t)
vocab_size;
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
void softmax_cross_entropy_loss(const float *logits, const int32_t *targets, int tokens, int vocab_size, float *d_logits, float *loss_out)
void softmax_cross_entropy_loss_bf16(const uint16_t *logits, const int32_t *targets, int tokens, int vocab_size, uint16_t *d_logits, float *loss_out, float *scratch_logits, float *scratch_d_logits)