22 #pragma GCC diagnostic push
23 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
29 for (
size_t i = 0; i < count; ++i) {
30 dst[i] = (float)src[i];
36 int32_t q = (int32_t)lrintf(value);
39 }
else if (q < INT8_MIN) {
49 for (
size_t i = 0; i < count; ++i) {
64 int aligned_embed_dim,
67 float *scratch_output)
69 if (!input || !gamma || !output)
return;
70 if (!scratch_input || !scratch_output)
return;
72 size_t total = (size_t)tokens * (
size_t)aligned_embed_dim;
76 tokens, d_model, aligned_embed_dim, eps);
87 const float *rstd_cache,
92 int aligned_embed_dim,
93 float *scratch_d_output,
95 float *scratch_d_input)
97 if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma)
return;
98 if (!scratch_d_output || !scratch_input || !scratch_d_input)
return;
100 size_t total = (size_t)tokens * (
size_t)aligned_embed_dim;
106 for (
int d = 0; d < d_model; ++d) {
123 #pragma GCC diagnostic pop
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
static int8_t clamp_int8(float value)
void rmsnorm_forward_int8(const int8_t *input, const float *gamma, int8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
static void convert_int8_to_float(const int8_t *src, float *dst, size_t count)
void rmsnorm_backward_int8(const int8_t *d_output, const int8_t *input, const float *gamma, const float *rstd_cache, int8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
static void convert_float_to_int8(const float *src, int8_t *dst, size_t count)