20 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
21 #include <immintrin.h>
#if defined(__AVX__) && !defined(__AVX512F__)
/* Horizontal sum of the 8 float lanes of an AVX register.
 * Used by the AVX (non-AVX-512) code paths below to collapse a
 * vector accumulator into a scalar. */
static inline float hsum256_ps_rmsnorm(__m256 v) {
    /* Fold the upper 128-bit half onto the lower half... */
    __m128 hi = _mm256_extractf128_ps(v, 1);
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 sum128 = _mm_add_ps(lo, hi);
    /* ...then two horizontal adds reduce 4 lanes to 1. */
    sum128 = _mm_hadd_ps(sum128, sum128);
    sum128 = _mm_hadd_ps(sum128, sum128);
    return _mm_cvtss_f32(sum128);
}
#endif
56 int aligned_embed_dim,
61 int aligned = aligned_embed_dim;
63 for (
int t = 0; t < T; ++t) {
64 const float *x = input + (size_t)t * aligned;
65 float *y = output + (size_t)t * aligned;
67 #if defined(__AVX512F__)
69 __m512 sum_sq_vec = _mm512_setzero_ps();
73 for (; d + 16 <= D; d += 16) {
74 __m512 xv = _mm512_loadu_ps(&x[d]);
75 sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
77 float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
81 sum_sq += x[d] * x[d];
84 float mean_sq = sum_sq / (float)D;
85 float rstd = 1.0f / sqrtf(mean_sq + eps);
91 __m512 rstd_vec = _mm512_set1_ps(rstd);
93 for (; d + 16 <= D; d += 16) {
94 __m512 xv = _mm512_loadu_ps(&x[d]);
95 __m512 gv = _mm512_loadu_ps(&gamma[d]);
96 __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
97 __m512 yv = _mm512_mul_ps(x_hat, gv);
98 _mm512_storeu_ps(&y[d], yv);
102 y[d] = x[d] * rstd * gamma[d];
105 #elif defined(__AVX__)
107 __m256 sum_sq_vec = _mm256_setzero_ps();
111 for (; d + 8 <= D; d += 8) {
112 __m256 xv = _mm256_loadu_ps(&x[d]);
113 __m256 xv_sq = _mm256_mul_ps(xv, xv);
114 sum_sq_vec = _mm256_add_ps(sum_sq_vec, xv_sq);
116 float sum_sq = hsum256_ps_rmsnorm(sum_sq_vec);
120 sum_sq += x[d] * x[d];
123 float mean_sq = sum_sq / (float)D;
124 float rstd = 1.0f / sqrtf(mean_sq + eps);
126 rstd_cache[t] = rstd;
130 __m256 rstd_vec = _mm256_set1_ps(rstd);
132 for (; d + 8 <= D; d += 8) {
133 __m256 xv = _mm256_loadu_ps(&x[d]);
134 __m256 gv = _mm256_loadu_ps(&gamma[d]);
135 __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
136 __m256 yv = _mm256_mul_ps(x_hat, gv);
137 _mm256_storeu_ps(&y[d], yv);
141 y[d] = x[d] * rstd * gamma[d];
147 for (
int d = 0; d < D; ++d) {
148 double v = (double)x[d];
151 double mean_sq = sum_sq / (double)D;
152 double r = sqrt(mean_sq + (
double)eps);
153 float rstd = (float)(1.0 / r);
155 rstd_cache[t] = rstd;
159 for (
int d = 0; d < D; ++d) {
160 float x_hat = x[d] * rstd;
161 y[d] = x_hat * gamma[d];
166 for (
int d = D; d < aligned; ++d) {
/* RMSNorm backward pass.
 *
 * Given upstream gradients `d_output` and the cached per-token rstd
 * from the forward pass, computes (with x_hat = x * rstd and
 * m = mean over d of dY[d] * gamma[d] * x_hat[d]):
 *
 *     dX[d]      = rstd * (dY[d] * gamma[d] - x_hat[d] * m)
 *     d_gamma[d] += dY[d] * x_hat[d]   (accumulated over all tokens)
 *
 * `d_gamma` is zeroed at entry, then accumulated across tokens.
 * Row stride is `aligned_embed_dim`; only the first `d_model` lanes
 * are live, and the padding lanes of `d_input` are zeroed.
 */
void rmsnorm_backward(const float *d_output, const float *input,
                      const float *gamma, const float *rstd_cache,
                      float *d_input, float *d_gamma, int tokens,
                      int d_model, int aligned_embed_dim)
{
    int T = tokens;
    int D = d_model;
    int aligned = aligned_embed_dim;

    /* d_gamma accumulates across all tokens; clear it first. */
#if defined(__AVX512F__)
    int dz = 0;
    for (; dz + 16 <= D; dz += 16) {
        _mm512_storeu_ps(&d_gamma[dz], _mm512_setzero_ps());
    }
    for (; dz < D; ++dz) {
        d_gamma[dz] = 0.0f;
    }
#elif defined(__AVX__)
    int dz = 0;
    for (; dz + 8 <= D; dz += 8) {
        _mm256_storeu_ps(&d_gamma[dz], _mm256_setzero_ps());
    }
    for (; dz < D; ++dz) {
        d_gamma[dz] = 0.0f;
    }
#else
    for (int d = 0; d < D; ++d) {
        d_gamma[d] = 0.0f;
    }
#endif

    for (int t = 0; t < T; ++t) {
        const float *x = input + (size_t)t * aligned;
        const float *dY = d_output + (size_t)t * aligned;
        float *dX = d_input + (size_t)t * aligned;
        float rstd = rstd_cache[t];

#if defined(__AVX512F__)
        /* --- pass 1: reduce sum(dY * gamma * x_hat) --- */
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        __m512 sum_vec = _mm512_setzero_ps();
        int d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = _mm512_loadu_ps(&x[d]);
            __m512 dyv = _mm512_loadu_ps(&dY[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
            __m512 prod = _mm512_mul_ps(dyv, gv);
            sum_vec = _mm512_fmadd_ps(prod, x_hat, sum_vec);
        }
        float sum_dY_g_xhat = _mm512_reduce_add_ps(sum_vec);
        for (; d < D; ++d) { /* scalar tail */
            float x_hat = x[d] * rstd;
            sum_dY_g_xhat += dY[d] * gamma[d] * x_hat;
        }
        float m = sum_dY_g_xhat / (float)D;

        /* --- pass 2: input gradient and gamma-gradient accumulation --- */
        __m512 m_vec = _mm512_set1_ps(m);
        d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = _mm512_loadu_ps(&x[d]);
            __m512 dyv = _mm512_loadu_ps(&dY[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 dgv = _mm512_loadu_ps(&d_gamma[d]);

            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
            __m512 dy_g = _mm512_mul_ps(dyv, gv);
            __m512 xhat_m = _mm512_mul_ps(x_hat, m_vec);
            __m512 diff = _mm512_sub_ps(dy_g, xhat_m);
            __m512 dxv = _mm512_mul_ps(rstd_vec, diff);
            _mm512_storeu_ps(&dX[d], dxv);

            dgv = _mm512_fmadd_ps(dyv, x_hat, dgv);
            _mm512_storeu_ps(&d_gamma[d], dgv);
        }
        for (; d < D; ++d) { /* scalar tail */
            float x_hat = x[d] * rstd;
            float dy = dY[d];
            dX[d] = rstd * (dy * gamma[d] - x_hat * m);
            d_gamma[d] += dy * x_hat;
        }
#elif defined(__AVX__)
        /* --- pass 1: reduce sum(dY * gamma * x_hat) --- */
        __m256 rstd_vec = _mm256_set1_ps(rstd);
        __m256 sum_vec = _mm256_setzero_ps();
        int d = 0;
        for (; d + 8 <= D; d += 8) {
            __m256 xv = _mm256_loadu_ps(&x[d]);
            __m256 dyv = _mm256_loadu_ps(&dY[d]);
            __m256 gv = _mm256_loadu_ps(&gamma[d]);
            __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
            __m256 prod = _mm256_mul_ps(dyv, gv);
            __m256 prod2 = _mm256_mul_ps(prod, x_hat);
            sum_vec = _mm256_add_ps(sum_vec, prod2);
        }
        float sum_dY_g_xhat = hsum256_ps_rmsnorm(sum_vec);
        for (; d < D; ++d) { /* scalar tail */
            float x_hat = x[d] * rstd;
            sum_dY_g_xhat += dY[d] * gamma[d] * x_hat;
        }
        float m = sum_dY_g_xhat / (float)D;

        /* --- pass 2: input gradient and gamma-gradient accumulation --- */
        __m256 m_vec = _mm256_set1_ps(m);
        d = 0;
        for (; d + 8 <= D; d += 8) {
            __m256 xv = _mm256_loadu_ps(&x[d]);
            __m256 dyv = _mm256_loadu_ps(&dY[d]);
            __m256 gv = _mm256_loadu_ps(&gamma[d]);
            __m256 dgv = _mm256_loadu_ps(&d_gamma[d]);

            __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
            __m256 dy_g = _mm256_mul_ps(dyv, gv);
            __m256 xhat_m = _mm256_mul_ps(x_hat, m_vec);
            __m256 diff = _mm256_sub_ps(dy_g, xhat_m);
            __m256 dxv = _mm256_mul_ps(rstd_vec, diff);
            _mm256_storeu_ps(&dX[d], dxv);

            __m256 dy_xhat = _mm256_mul_ps(dyv, x_hat);
            dgv = _mm256_add_ps(dgv, dy_xhat);
            _mm256_storeu_ps(&d_gamma[d], dgv);
        }
        for (; d < D; ++d) { /* scalar tail */
            float x_hat = x[d] * rstd;
            float dy = dY[d];
            dX[d] = rstd * (dy * gamma[d] - x_hat * m);
            d_gamma[d] += dy * x_hat;
        }
#else
        /* Scalar fallback: accumulate the reduction in double for accuracy. */
        double sum_dY_g_xhat = 0.0;
        for (int d = 0; d < D; ++d) {
            float x_hat = x[d] * rstd;
            sum_dY_g_xhat += (double)dY[d] * (double)gamma[d] * (double)x_hat;
        }
        float m = (float)(sum_dY_g_xhat / (double)D);

        for (int d = 0; d < D; ++d) {
            float x_hat = x[d] * rstd;
            float dy = dY[d];
            dX[d] = rstd * (dy * gamma[d] - x_hat * m);
            d_gamma[d] += dy * x_hat;
        }
#endif

        /* Zero the padding lanes of the input-gradient row. */
        for (int d = D; d < aligned; ++d) {
            dX[d] = 0.0f;
        }
    }
}
/* Public prototypes for the RMSNorm kernels (declarations only;
 * NOTE(review): these signature lines appeared without semicolons or
 * bodies — presumably dislocated by extraction; restored here as
 * valid prototypes). */
void rmsnorm_forward(const float *input, const float *gamma, float *output,
                     float *rstd_cache, int tokens, int d_model,
                     int aligned_embed_dim, float eps);
void rmsnorm_backward(const float *d_output, const float *input,
                      const float *gamma, const float *rstd_cache,
                      float *d_input, float *d_gamma, int tokens,
                      int d_model, int aligned_embed_dim);