RMSNorm kernels for BF16 tensors.
Functions

| Return type | Function |
| --- | --- |
| void | rmsnorm_backward_bf16 (const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *rstd_cache, uint16_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim) |
| void | rmsnorm_forward_bf16 (const uint16_t *input, const float *gamma, uint16_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps) |
After changing these kernels, run: make test && make llamacpp-parity-full
RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
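Written out per token row, the same formula and its gradients may help when reading the backward kernel. This is the standard RMSNorm derivation rather than something taken from the source file; it assumes the mean runs over the D = d_model valid features of a row and that rstd_cache stores r for each token. With upstream gradient \delta_i = \partial L / \partial y_i:

$$
r = \left(\frac{1}{D}\sum_{j=1}^{D} x_j^2 + \varepsilon\right)^{-1/2},
\qquad
y_i = \gamma_i \, x_i \, r
$$

$$
\frac{\partial L}{\partial x_i} = r\left(\gamma_i \delta_i - x_i \, r^2 \, \frac{1}{D}\sum_{j=1}^{D} \gamma_j \delta_j x_j\right),
\qquad
\frac{\partial L}{\partial \gamma_i} = \sum_{\text{tokens}} \delta_i \, x_i \, r
$$

The second term of \partial L / \partial x_i comes from differentiating r itself: \partial r / \partial x_j = -r^3 x_j / D.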
Definition in file rmsnorm_kernels_bf16.c.
void rmsnorm_backward_bf16(const uint16_t *d_output,
                           const uint16_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           uint16_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim)
Definition at line 113 of file rmsnorm_kernels_bf16.c.
References bf16_to_float() and float_to_bf16().
void rmsnorm_forward_bf16(const uint16_t *input,
                          const float *gamma,
                          uint16_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps)
Definition at line 24 of file rmsnorm_kernels_bf16.c.
References bf16_to_float() and float_to_bf16().