← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rmsnorm_kernels.c
Go to the documentation of this file.
1 /**
2  * @file rmsnorm_kernels.c
3  * @brief RMSNorm forward/backward kernels with SIMD (SSE/AVX/AVX512)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
15  */
16 
17 #include <math.h>
18 #include <stddef.h>
19 
20 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
21 #include <immintrin.h>
22 #endif
23 
24 /* AVX1 horizontal sum helper (no _mm256_reduce_add_ps in AVX1) */
25 #if defined(__AVX__) && !defined(__AVX512F__)
static inline float hsum256_ps_rmsnorm(__m256 v) {
    /* Fold the upper 128-bit half onto the lower half, then collapse the
     * remaining four lanes with two horizontal adds (SSE3). Kept because
     * AVX1 provides no cross-lane reduce intrinsic. */
    __m128 upper = _mm256_extractf128_ps(v, 1);
    __m128 lower = _mm256_castps256_ps128(v);
    __m128 folded = _mm_add_ps(lower, upper);
    folded = _mm_hadd_ps(folded, folded);
    folded = _mm_hadd_ps(folded, folded);
    return _mm_cvtss_f32(folded);
}
36 #endif
37 
38 /**
39  * RMSNorm forward pass
40  * @test test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens
41  * @test test_rmsnorm.py::TestRMSNormForward::test_fp32_single
42  * @test test_rmsnorm.py::TestRMSNormForward::test_perf_rolled
43  * @test test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat
44  * @test test_parity.py::test_rmsnorm_parity
45  *
46  * RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
47  *
48  * After changes: make test && make llamacpp-parity-full
49  */
50 void rmsnorm_forward(const float *input,
51  const float *gamma,
52  float *output,
53  float *rstd_cache,
54  int tokens,
55  int d_model,
56  int aligned_embed_dim,
57  float eps)
58 {
59  int T = tokens;
60  int D = d_model;
61  int aligned = aligned_embed_dim;
62 
63  for (int t = 0; t < T; ++t) {
64  const float *x = input + (size_t)t * aligned;
65  float *y = output + (size_t)t * aligned;
66 
67 #if defined(__AVX512F__)
68  // AVX-512: Process 16 floats at a time
69  __m512 sum_sq_vec = _mm512_setzero_ps();
70  int d = 0;
71 
72  // Vectorized sum of squares
73  for (; d + 16 <= D; d += 16) {
74  __m512 xv = _mm512_loadu_ps(&x[d]);
75  sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
76  }
77  float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
78 
79  // Handle remaining elements
80  for (; d < D; ++d) {
81  sum_sq += x[d] * x[d];
82  }
83 
84  float mean_sq = sum_sq / (float)D;
85  float rstd = 1.0f / sqrtf(mean_sq + eps);
86  if (rstd_cache) {
87  rstd_cache[t] = rstd;
88  }
89 
90  // Apply normalization and scale (vectorized)
91  __m512 rstd_vec = _mm512_set1_ps(rstd);
92  d = 0;
93  for (; d + 16 <= D; d += 16) {
94  __m512 xv = _mm512_loadu_ps(&x[d]);
95  __m512 gv = _mm512_loadu_ps(&gamma[d]);
96  __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
97  __m512 yv = _mm512_mul_ps(x_hat, gv);
98  _mm512_storeu_ps(&y[d], yv);
99  }
100  // Handle remaining elements
101  for (; d < D; ++d) {
102  y[d] = x[d] * rstd * gamma[d];
103  }
104 
105 #elif defined(__AVX__)
106  // AVX: Process 8 floats at a time
107  __m256 sum_sq_vec = _mm256_setzero_ps();
108  int d = 0;
109 
110  // Vectorized sum of squares (no FMA in AVX1, use mul + add)
111  for (; d + 8 <= D; d += 8) {
112  __m256 xv = _mm256_loadu_ps(&x[d]);
113  __m256 xv_sq = _mm256_mul_ps(xv, xv);
114  sum_sq_vec = _mm256_add_ps(sum_sq_vec, xv_sq);
115  }
116  float sum_sq = hsum256_ps_rmsnorm(sum_sq_vec);
117 
118  // Handle remaining elements
119  for (; d < D; ++d) {
120  sum_sq += x[d] * x[d];
121  }
122 
123  float mean_sq = sum_sq / (float)D;
124  float rstd = 1.0f / sqrtf(mean_sq + eps);
125  if (rstd_cache) {
126  rstd_cache[t] = rstd;
127  }
128 
129  // Apply normalization and scale (vectorized)
130  __m256 rstd_vec = _mm256_set1_ps(rstd);
131  d = 0;
132  for (; d + 8 <= D; d += 8) {
133  __m256 xv = _mm256_loadu_ps(&x[d]);
134  __m256 gv = _mm256_loadu_ps(&gamma[d]);
135  __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
136  __m256 yv = _mm256_mul_ps(x_hat, gv);
137  _mm256_storeu_ps(&y[d], yv);
138  }
139  // Handle remaining elements
140  for (; d < D; ++d) {
141  y[d] = x[d] * rstd * gamma[d];
142  }
143 
144 #else
145  // Scalar fallback
146  double sum_sq = 0.0;
147  for (int d = 0; d < D; ++d) {
148  double v = (double)x[d];
149  sum_sq += v * v;
150  }
151  double mean_sq = sum_sq / (double)D;
152  double r = sqrt(mean_sq + (double)eps);
153  float rstd = (float)(1.0 / r);
154  if (rstd_cache) {
155  rstd_cache[t] = rstd;
156  }
157 
158  // Apply normalization and scale
159  for (int d = 0; d < D; ++d) {
160  float x_hat = x[d] * rstd;
161  y[d] = x_hat * gamma[d];
162  }
163 #endif
164 
165  // Zero padding (if any)
166  for (int d = D; d < aligned; ++d) {
167  y[d] = 0.0f;
168  }
169  }
170 }
171 
/**
 * RMSNorm backward pass.
 *
 * Given upstream gradients dY, the saved inputs X, the scale gamma, and the
 * per-token reciprocal RMS cached by the forward pass, computes:
 *
 *   dX_i     = rstd * (dY_i * gamma_i - x_hat_i * m),
 *              with m = (1/D) * sum_j dY_j * gamma_j * x_hat_j
 *   dGamma_i = sum over tokens of dY_i * x_hat_i
 *
 * d_gamma is zeroed on entry and accumulated across all tokens.
 * NOTE(review): only indices [0, d_model) of d_gamma are written; if callers
 * allocate d_gamma with the aligned stride, its tail is left untouched —
 * confirm against the allocator.
 *
 * @param d_output    [tokens x aligned_embed_dim] upstream gradients dY
 * @param input       [tokens x aligned_embed_dim] forward inputs X
 * @param gamma       [d_model] learned per-feature scale
 * @param rstd_cache  [tokens] 1/rms per token from the forward pass
 * @param d_input     [tokens x aligned_embed_dim] output: dX (tail zeroed)
 * @param d_gamma     [d_model] output: accumulated gamma gradient
 * @param tokens      number of token rows
 * @param d_model     logical feature width
 * @param aligned_embed_dim  padded row stride in floats (>= d_model)
 *
 * @test test_rmsnorm.py::TestRMSNormBackward::test_backward_tokens
 * @test test_rmsnorm.py::TestRMSNormBackward::test_backward_single
 * @test test_parity.py::test_rmsnorm_backward_parity
 *
 * After changes: make test && make llamacpp-parity-full
 */
void rmsnorm_backward(const float *d_output,
                      const float *input,
                      const float *gamma,
                      const float *rstd_cache,
                      float *d_input,
                      float *d_gamma,
                      int tokens,
                      int d_model,
                      int aligned_embed_dim)
{
    const int n_tokens = tokens;
    const int dim = d_model;
    const int stride = aligned_embed_dim;

    /* d_gamma accumulates across tokens, so clear it first. */
#if defined(__AVX512F__)
    {
        int i = 0;
        for (; i + 16 <= dim; i += 16) {
            _mm512_storeu_ps(&d_gamma[i], _mm512_setzero_ps());
        }
        for (; i < dim; ++i) {
            d_gamma[i] = 0.0f;
        }
    }
#elif defined(__AVX__)
    {
        int i = 0;
        for (; i + 8 <= dim; i += 8) {
            _mm256_storeu_ps(&d_gamma[i], _mm256_setzero_ps());
        }
        for (; i < dim; ++i) {
            d_gamma[i] = 0.0f;
        }
    }
#else
    for (int i = 0; i < dim; ++i) {
        d_gamma[i] = 0.0f;
    }
#endif

    for (int row = 0; row < n_tokens; ++row) {
        const float *x = input + (size_t)row * stride;
        const float *dY = d_output + (size_t)row * stride;
        float *dX = d_input + (size_t)row * stride;

        const float rstd = rstd_cache[row];

#if defined(__AVX512F__)
        /* First pass: m = (1/D) * sum_j dY_j * gamma_j * x_hat_j */
        __m512 rstd_v = _mm512_set1_ps(rstd);
        __m512 acc = _mm512_setzero_ps();
        int i = 0;

        for (; i + 16 <= dim; i += 16) {
            __m512 xv = _mm512_loadu_ps(&x[i]);
            __m512 dyv = _mm512_loadu_ps(&dY[i]);
            __m512 gv = _mm512_loadu_ps(&gamma[i]);
            __m512 xh = _mm512_mul_ps(xv, rstd_v);
            __m512 dy_g = _mm512_mul_ps(dyv, gv);
            acc = _mm512_fmadd_ps(dy_g, xh, acc);
        }
        float dot = _mm512_reduce_add_ps(acc);

        /* Scalar tail of the reduction */
        for (; i < dim; ++i) {
            float xh = x[i] * rstd;
            dot += dY[i] * gamma[i] * xh;
        }
        float m = dot / (float)dim;

        /* Second pass: write dX and accumulate d_gamma */
        __m512 m_v = _mm512_set1_ps(m);
        i = 0;
        for (; i + 16 <= dim; i += 16) {
            __m512 xv = _mm512_loadu_ps(&x[i]);
            __m512 dyv = _mm512_loadu_ps(&dY[i]);
            __m512 gv = _mm512_loadu_ps(&gamma[i]);
            __m512 dgv = _mm512_loadu_ps(&d_gamma[i]);

            __m512 xh = _mm512_mul_ps(xv, rstd_v);

            /* dX = rstd * (dY * gamma - x_hat * m) */
            __m512 dy_g = _mm512_mul_ps(dyv, gv);
            __m512 xh_m = _mm512_mul_ps(xh, m_v);
            __m512 delta = _mm512_sub_ps(dy_g, xh_m);
            _mm512_storeu_ps(&dX[i], _mm512_mul_ps(rstd_v, delta));

            /* d_gamma += dY * x_hat */
            _mm512_storeu_ps(&d_gamma[i], _mm512_fmadd_ps(dyv, xh, dgv));
        }
        for (; i < dim; ++i) {
            float xh = x[i] * rstd;
            float dy = dY[i];
            dX[i] = rstd * (dy * gamma[i] - xh * m);
            d_gamma[i] += dy * xh;
        }

#elif defined(__AVX__)
        /* First pass: m = (1/D) * sum_j dY_j * gamma_j * x_hat_j */
        __m256 rstd_v = _mm256_set1_ps(rstd);
        __m256 acc = _mm256_setzero_ps();
        int i = 0;

        for (; i + 8 <= dim; i += 8) {
            __m256 xv = _mm256_loadu_ps(&x[i]);
            __m256 dyv = _mm256_loadu_ps(&dY[i]);
            __m256 gv = _mm256_loadu_ps(&gamma[i]);
            __m256 xh = _mm256_mul_ps(xv, rstd_v);
            /* AVX1 has no FMA: two multiplies, then an add */
            __m256 dy_g = _mm256_mul_ps(dyv, gv);
            __m256 term = _mm256_mul_ps(dy_g, xh);
            acc = _mm256_add_ps(acc, term);
        }
        float dot = hsum256_ps_rmsnorm(acc);

        /* Scalar tail of the reduction */
        for (; i < dim; ++i) {
            float xh = x[i] * rstd;
            dot += dY[i] * gamma[i] * xh;
        }
        float m = dot / (float)dim;

        /* Second pass: write dX and accumulate d_gamma */
        __m256 m_v = _mm256_set1_ps(m);
        i = 0;
        for (; i + 8 <= dim; i += 8) {
            __m256 xv = _mm256_loadu_ps(&x[i]);
            __m256 dyv = _mm256_loadu_ps(&dY[i]);
            __m256 gv = _mm256_loadu_ps(&gamma[i]);
            __m256 dgv = _mm256_loadu_ps(&d_gamma[i]);

            __m256 xh = _mm256_mul_ps(xv, rstd_v);

            /* dX = rstd * (dY * gamma - x_hat * m) */
            __m256 dy_g = _mm256_mul_ps(dyv, gv);
            __m256 xh_m = _mm256_mul_ps(xh, m_v);
            __m256 delta = _mm256_sub_ps(dy_g, xh_m);
            _mm256_storeu_ps(&dX[i], _mm256_mul_ps(rstd_v, delta));

            /* d_gamma += dY * x_hat */
            __m256 dy_xh = _mm256_mul_ps(dyv, xh);
            _mm256_storeu_ps(&d_gamma[i], _mm256_add_ps(dgv, dy_xh));
        }
        for (; i < dim; ++i) {
            float xh = x[i] * rstd;
            float dy = dY[i];
            dX[i] = rstd * (dy * gamma[i] - xh * m);
            d_gamma[i] += dy * xh;
        }

#else
        /* Scalar fallback: reduce in double for accuracy */
        double dot = 0.0;
        for (int i = 0; i < dim; ++i) {
            float xh = x[i] * rstd;
            dot += (double)dY[i] * (double)gamma[i] * (double)xh;
        }
        float m = (float)(dot / (double)dim);

        for (int i = 0; i < dim; ++i) {
            float xh = x[i] * rstd;
            float dy = dY[i];
            dX[i] = rstd * (dy * gamma[i] - xh * m);
            d_gamma[i] += dy * xh;
        }
#endif

        /* Zero gradients in padded tail columns */
        for (int i = dim; i < stride; ++i) {
            dX[i] = 0.0f;
        }
    }
}
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)