/*
 * RMSNorm forward pass over bf16 activations.
 *
 * For each of the `tokens` rows of `input` (row stride `aligned_embed_dim`
 * bf16 elements, of which the first `d_model` are valid):
 *
 *     rstd = 1 / sqrt(mean(x^2) + eps)
 *     y[d] = x[d] * rstd * gamma[d]
 *
 * input       : [tokens x aligned_embed_dim] bf16, read-only
 * gamma       : [d_model] fp32 scale, read-only
 * output      : [tokens x aligned_embed_dim] bf16, written; padding lanes
 *               [d_model, aligned_embed_dim) are zeroed so downstream SIMD
 *               reads of the full aligned row are safe
 * rstd_cache  : optional [tokens] fp32; when non-NULL receives per-token rstd
 *               for reuse by the backward pass
 * eps         : numerical-stability epsilon added to mean(x^2)
 *
 * The AVX-512 path accumulates in fp32 (fmadd); the scalar fallback
 * accumulates the sum of squares in double for better precision.
 */
void rmsnorm_forward_bf16(const uint16_t *input, const float *gamma,
                          uint16_t *output, float *rstd_cache,
                          int tokens, int d_model, int aligned_embed_dim,
                          float eps)
{
    /* Guard required pointers (rstd_cache is optional, mirroring the
     * backward pass's argument validation). */
    if (!input || !gamma || !output || tokens <= 0 || d_model <= 0) {
        return;
    }

    const int T = tokens;
    const int D = d_model;
    const int aligned = aligned_embed_dim;

    for (int t = 0; t < T; ++t) {
        const uint16_t *x_bf16 = input + (size_t)t * aligned;
        float *rstd_ptr = rstd_cache ? (rstd_cache + t) : NULL;
        uint16_t *out_bf16 = output + (size_t)t * aligned;

#if defined(__AVX512F__)
        /* --- pass 1: sum of squares --- */
        __m512 sum_sq_vec = _mm512_setzero_ps();
        int d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
        }
        float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
        /* scalar remainder when D is not a multiple of 16 */
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            sum_sq += x * x;
        }

        float mean_sq = sum_sq / (float)D;
        float rstd = 1.0f / sqrtf(mean_sq + eps);
        if (rstd_ptr) {
            *rstd_ptr = rstd;
        }

        /* --- pass 2: normalize and scale --- */
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
            __m512 yv = _mm512_mul_ps(x_hat, gv);
            fp32_cvt_storeu_bf16(&out_bf16[d], yv);
        }
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float y = x * rstd * gamma[d];
            out_bf16[d] = float_to_bf16(y);
        }
#else
        /* Scalar fallback: accumulate in double for accuracy. */
        double sum_sq = 0.0;
        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            sum_sq += (double)x * (double)x;
        }
        double mean_sq = sum_sq / (double)D;
        double r = sqrt(mean_sq + (double)eps);
        float rstd = (float)(1.0 / r);
        if (rstd_ptr) {
            *rstd_ptr = rstd;
        }
        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float y = x_hat * gamma[d];
            out_bf16[d] = float_to_bf16(y);
        }
#endif

        /* Zero the padding tail of the aligned row. */
        for (int d = D; d < aligned; ++d) {
            out_bf16[d] = 0;
        }
    }
}
/*
 * RMSNorm backward pass over bf16 activations.
 *
 * Given upstream gradients dY and the per-token rstd cached by the forward
 * pass, computes (with x_hat = x * rstd, m = mean(dY * gamma * x_hat)):
 *
 *     dX[d]     = rstd * (dY[d] * gamma[d] - x_hat[d] * m)
 *     d_gamma[d] += sum over tokens of dY[d] * x_hat[d]
 *
 * d_output   : [tokens x aligned_embed_dim] bf16 upstream grads, read-only
 * input      : [tokens x aligned_embed_dim] bf16 forward activations
 * gamma      : [d_model] fp32 scale, read-only
 * rstd_cache : [tokens] fp32 per-token rstd from the forward pass (required)
 * d_input    : [tokens x aligned_embed_dim] bf16, written; padding lanes
 *              [d_model, aligned_embed_dim) are zeroed
 * d_gamma    : [d_model] fp32 accumulator, zero-initialized here and then
 *              accumulated across all tokens
 *
 * The AVX-512 reduction accumulates in fp32; the scalar fallback uses a
 * double accumulator for the per-token dot product.
 */
void rmsnorm_backward_bf16(const uint16_t *d_output, const uint16_t *input,
                           const float *gamma, const float *rstd_cache,
                           uint16_t *d_input, float *d_gamma,
                           int tokens, int d_model, int aligned_embed_dim)
{
    const int T = tokens;
    const int D = d_model;
    const int aligned = aligned_embed_dim;

    if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) {
        return;
    }

    /* Zero the d_gamma accumulator before summing over tokens. */
#if defined(__AVX512F__)
    {
        int d = 0;
        for (; d + 16 <= D; d += 16) {
            _mm512_storeu_ps(&d_gamma[d], _mm512_setzero_ps());
        }
        for (; d < D; ++d) {
            d_gamma[d] = 0.0f;
        }
    }
#else
    for (int d = 0; d < D; ++d) {
        d_gamma[d] = 0.0f;
    }
#endif

    for (int t = 0; t < T; ++t) {
        const uint16_t *x_bf16 = input + (size_t)t * aligned;
        const uint16_t *dY_bf16 = d_output + (size_t)t * aligned;
        uint16_t *dX_bf16 = d_input + (size_t)t * aligned;
        float rstd = rstd_cache[t];

#if defined(__AVX512F__)
        __m512 rstd_vec = _mm512_set1_ps(rstd);

        /* --- pass 1: m = mean(dY * gamma * x_hat) --- */
        __m512 sum_vec = _mm512_setzero_ps();
        int d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            __m512 dyv = bf16_loadu_cvt_fp32(&dY_bf16[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
            __m512 prod = _mm512_mul_ps(dyv, gv);
            sum_vec = _mm512_fmadd_ps(prod, x_hat, sum_vec);
        }
        float sum_dY_g_xhat = _mm512_reduce_add_ps(sum_vec);
        /* scalar remainder when D is not a multiple of 16 */
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            sum_dY_g_xhat += dy * gamma[d] * x_hat;
        }
        float m = sum_dY_g_xhat / (float)D;

        /* --- pass 2: dX and d_gamma --- */
        __m512 m_vec = _mm512_set1_ps(m);
        d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            __m512 dyv = bf16_loadu_cvt_fp32(&dY_bf16[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 dgv = _mm512_loadu_ps(&d_gamma[d]);

            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);

            /* dX = rstd * (dY * gamma - x_hat * m) */
            __m512 dy_g = _mm512_mul_ps(dyv, gv);
            __m512 xhat_m = _mm512_mul_ps(x_hat, m_vec);
            __m512 diff = _mm512_sub_ps(dy_g, xhat_m);
            __m512 dxv = _mm512_mul_ps(rstd_vec, diff);
            fp32_cvt_storeu_bf16(&dX_bf16[d], dxv);

            /* d_gamma += dY * x_hat */
            dgv = _mm512_fmadd_ps(dyv, x_hat, dgv);
            _mm512_storeu_ps(&d_gamma[d], dgv);
        }
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            float dx = rstd * (dy * gamma[d] - x_hat * m);
            dX_bf16[d] = float_to_bf16(dx);
            d_gamma[d] += dy * x_hat;
        }
#else
        /* Scalar fallback: double accumulator for the reduction. */
        double sum_dY_g_xhat = 0.0;
        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            sum_dY_g_xhat += (double)dy * (double)gamma[d] * (double)x_hat;
        }
        float m = (float)(sum_dY_g_xhat / (double)D);

        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            float dx = rstd * (dy * gamma[d] - x_hat * m);
            dX_bf16[d] = float_to_bf16(dx);
            d_gamma[d] += dy * x_hat;
        }
#endif

        /* Zero the padding tail of the aligned row. */
        for (int d = D; d < aligned; ++d) {
            dX_bf16[d] = 0;
        }
    }
}
/*
 * Convert an fp32 value to bf16 (truncate the low 16 mantissa bits) using
 * round-to-nearest-even, preserving NaN.
 *
 * NOTE(review): the original body is not visible in this chunk; this is the
 * standard RNE conversion — confirm against the previous implementation if
 * it intentionally used plain truncation.
 */
static uint16_t float_to_bf16(float f)
{
    uint32_t u;
    memcpy(&u, &f, sizeof u); /* type-pun via memcpy: no strict-aliasing UB */

    /* NaN: truncation could turn a NaN payload into Inf; force a quiet NaN. */
    if ((u & 0x7FFFFFFFu) > 0x7F800000u) {
        return (uint16_t)((u >> 16) | 0x0040u);
    }

    /* Round-to-nearest-even on the 16 bits being discarded. */
    u += 0x7FFFu + ((u >> 16) & 1u);
    return (uint16_t)(u >> 16);
}
/*
 * Expand a bf16 value to fp32: bf16 is the top 16 bits of an IEEE-754
 * binary32, so shifting left by 16 (low mantissa bits zero) is exact.
 *
 * NOTE(review): the original body is not visible in this chunk; this is the
 * only standard expansion, but confirm against the previous implementation.
 */
static float bf16_to_float(uint16_t v)
{
    uint32_t u = (uint32_t)v << 16;
    float f;
    memcpy(&f, &u, sizeof f); /* type-pun via memcpy: no strict-aliasing UB */
    return f;
}
void rmsnorm_backward_bf16(const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *rstd_cache, uint16_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
void rmsnorm_forward_bf16(const uint16_t *input, const float *gamma, uint16_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)