Loss function kernels for BF16 tensors.
Functions
void softmax_cross_entropy_loss_bf16 (const uint16_t *logits, const int32_t *targets, int tokens, int vocab_size, uint16_t *d_logits, float *loss_out, float *scratch_logits, float *scratch_d_logits)
After changing this file, run: make test && make llamacpp-parity-full
Definition in file loss_kernels_bf16.c.
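BF16 keeps the upper 16 bits of an IEEE-754 binary32 value (1 sign bit, 8 exponent bits, 7 stored mantissa bits), which is why the tensors in this file are passed as uint16_t. The per-element conversions can be sketched as follows. bf16_to_f32 and f32_to_bf16 are hypothetical names, not the file's bf16_tensor_to_float() / float_tensor_to_bf16(), and round-to-nearest-even is an assumed rounding mode.

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: widen one BF16 value (held in a uint16_t) to float.
 * BF16 is the top half of a binary32, so widening is a 16-bit left shift. */
static float bf16_to_f32(uint16_t h)
{
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof f);   /* type-pun without aliasing violations */
    return f;
}

/* Hypothetical helper: narrow a float to BF16 with round-to-nearest-even.
 * NaNs are returned as quiet NaNs instead of being rounded (rounding a NaN
 * payload could otherwise carry into the exponent or sign bits). */
static uint16_t f32_to_bf16(float f)
{
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    if ((bits & 0x7fffffffu) > 0x7f800000u) {
        return (uint16_t)((bits >> 16) | 0x0040u);   /* set the quiet bit */
    }
    /* Add 0x7fff, plus 1 when the kept LSB is odd, so exact ties go to even. */
    uint32_t rounding_bias = 0x7fffu + ((bits >> 16) & 1u);
    return (uint16_t)((bits + rounding_bias) >> 16);
}
```

For example, 1.0f narrows to 0x3f80, and 1/3 narrows to 0x3eab, which widens back to 0.333984375.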
void softmax_cross_entropy_loss_bf16 (const uint16_t *logits,
                                      const int32_t  *targets,
                                      int             tokens,
                                      int             vocab_size,
                                      uint16_t       *d_logits,
                                      float          *loss_out,
                                      float          *scratch_logits,
                                      float          *scratch_d_logits)
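Judging from its name and parameters, the function computes softmax cross-entropy over tokens rows of vocab_size logits. The standard definition is written out below under two assumptions this page does not confirm: the loss is averaged over tokens, and every target index is valid (no ignore index). Here T = tokens, V = vocab_size, z_{t,v} is the logit for token t and vocabulary entry v after widening to float, and y_t = targets[t].

```latex
\begin{aligned}
m_t &= \max_{0 \le u < V} z_{t,u} \\
p_{t,v} &= \frac{\exp(z_{t,v} - m_t)}{\sum_{u=0}^{V-1} \exp(z_{t,u} - m_t)} \\
\mathcal{L} &= -\frac{1}{T} \sum_{t=0}^{T-1} \log p_{t,y_t} \\
\frac{\partial \mathcal{L}}{\partial z_{t,v}} &= \frac{1}{T}\left(p_{t,v} - \mathbb{1}[v = y_t]\right)
\end{aligned}
```

Under these assumptions, loss_out would receive \mathcal{L} and d_logits the gradient rounded to BF16. Subtracting the row maximum m_t leaves p_{t,v} unchanged but keeps exp() from overflowing.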
Definition at line 25 of file loss_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), softmax_cross_entropy_loss(), and vocab_size.