rmsnorm_kernels.c File Reference

RMSNorm forward/backward kernels with SIMD (SSE/AVX/AVX512)

#include <math.h>
#include <stddef.h>


Functions

void rmsnorm_backward (const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
 
void rmsnorm_forward (const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
 

Detailed Description

RMSNorm forward/backward kernels with SIMD (SSE/AVX/AVX512)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
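
The sketch below illustrates one way a caller might satisfy this contract: buffers carved out of a preallocated arena (rule 1), each row aligned_embed_dim floats wide with only the first d_model entries meaningful. The arena layout, the round-up-to-16 rule, and the name example_forward_call are illustrative assumptions, not part of this file's API.

    /* Hypothetical caller sketch: no malloc/free inside kernels, all pointers
     * come from a caller-owned arena. Rows are aligned_embed_dim floats wide;
     * only the first d_model floats of each row carry data. */
    void example_forward_call(float *arena, const float *gamma,
                              int tokens, int d_model, float eps)
    {
        /* Assumed padding rule: round d_model up to a multiple of 16 floats
         * so 16-wide AVX-512 stores never cross into the next row. */
        int aligned = (d_model + 15) / 16 * 16;

        float *input      = arena;                              /* [tokens][aligned] */
        float *output     = input  + (size_t)tokens * aligned;  /* [tokens][aligned] */
        float *rstd_cache = output + (size_t)tokens * aligned;  /* [tokens] */

        /* ... fill input with activations ... */

        rmsnorm_forward(input, gamma, output, rstd_cache,
                        tokens, d_model, aligned, eps);
    }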

Definition in file rmsnorm_kernels.c.

Function Documentation

◆ rmsnorm_backward()

void rmsnorm_backward ( const float *  d_output,
const float *  input,
const float *  gamma,
const float *  rstd_cache,
float *  d_input,
float *  d_gamma,
int  tokens,
int  d_model,
int  aligned_embed_dim 
)

RMSNorm backward pass

Test:

test_rmsnorm.py::TestRMSNormBackward::test_backward_tokens

test_rmsnorm.py::TestRMSNormBackward::test_backward_single

test_parity.py::test_rmsnorm_backward_parity

Computes dX and dGamma given dY, X, gamma, and the cached rstd. With x_hat_i = x[i] * rstd and m = (1/D) * sum_j (dY_j * gamma_j * x_hat_j):

dX_i = rstd * (dY_i * gamma_i - x_hat_i * m)

dGamma_i = sum over tokens t of (dY_i * x_hat_i)
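
For reference, a minimal derivation sketch of these expressions, using only the forward definition above (not taken from the source):

\[
\hat{x}_i = x_i r, \qquad r = \Big(\tfrac{1}{D}\sum_j x_j^2 + \epsilon\Big)^{-1/2}, \qquad
\frac{\partial r}{\partial x_i} = -\frac{r^3 x_i}{D}
\]
\[
\frac{\partial L}{\partial x_i}
 = r\, dY_i \gamma_i + \Big(\sum_j dY_j \gamma_j x_j\Big)\frac{\partial r}{\partial x_i}
 = r\Big(dY_i \gamma_i - \hat{x}_i m\Big), \qquad
 m = \frac{1}{D}\sum_j dY_j \gamma_j \hat{x}_j
\]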

After changes: make test && make llamacpp-parity-full

Definition at line 184 of file rmsnorm_kernels.c.

193 {
194  int T = tokens;
195  int D = d_model;
196  int aligned = aligned_embed_dim;
197 
198  // Zero parameter gradients
199 #if defined(__AVX512F__)
200  {
201  int d = 0;
202  for (; d + 16 <= D; d += 16) {
203  _mm512_storeu_ps(&d_gamma[d], _mm512_setzero_ps());
204  }
205  for (; d < D; ++d) {
206  d_gamma[d] = 0.0f;
207  }
208  }
209 #elif defined(__AVX__)
210  {
211  int d = 0;
212  for (; d + 8 <= D; d += 8) {
213  _mm256_storeu_ps(&d_gamma[d], _mm256_setzero_ps());
214  }
215  for (; d < D; ++d) {
216  d_gamma[d] = 0.0f;
217  }
218  }
219 #else
220  for (int d = 0; d < D; ++d) {
221  d_gamma[d] = 0.0f;
222  }
223 #endif
224 
225  for (int t = 0; t < T; ++t) {
226  const float *x = input + (size_t)t * aligned;
227  const float *dY = d_output + (size_t)t * aligned;
228  float *dX = d_input + (size_t)t * aligned;
229 
230  float rstd = rstd_cache[t];
231 
232 #if defined(__AVX512F__)
233  // Compute m = (1/D) * sum_j (dY_j * gamma_j * x_hat_j)
234  __m512 rstd_vec = _mm512_set1_ps(rstd);
235  __m512 sum_vec = _mm512_setzero_ps();
236  int d = 0;
237 
238  for (; d + 16 <= D; d += 16) {
239  __m512 xv = _mm512_loadu_ps(&x[d]);
240  __m512 dyv = _mm512_loadu_ps(&dY[d]);
241  __m512 gv = _mm512_loadu_ps(&gamma[d]);
242  __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
243  // sum += dY * gamma * x_hat
244  __m512 prod = _mm512_mul_ps(dyv, gv);
245  sum_vec = _mm512_fmadd_ps(prod, x_hat, sum_vec);
246  }
247  float sum_dY_g_xhat = _mm512_reduce_add_ps(sum_vec);
248 
249  // Handle remaining elements
250  for (; d < D; ++d) {
251  float x_hat = x[d] * rstd;
252  sum_dY_g_xhat += dY[d] * gamma[d] * x_hat;
253  }
254  float m = sum_dY_g_xhat / (float)D;
255 
256  // Compute dX and accumulate dGamma (vectorized)
257  __m512 m_vec = _mm512_set1_ps(m);
258  d = 0;
259  for (; d + 16 <= D; d += 16) {
260  __m512 xv = _mm512_loadu_ps(&x[d]);
261  __m512 dyv = _mm512_loadu_ps(&dY[d]);
262  __m512 gv = _mm512_loadu_ps(&gamma[d]);
263  __m512 dgv = _mm512_loadu_ps(&d_gamma[d]);
264 
265  __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
266 
267  // dX = rstd * (dY * gamma - x_hat * m)
268  __m512 dy_g = _mm512_mul_ps(dyv, gv);
269  __m512 xhat_m = _mm512_mul_ps(x_hat, m_vec);
270  __m512 diff = _mm512_sub_ps(dy_g, xhat_m);
271  __m512 dxv = _mm512_mul_ps(rstd_vec, diff);
272  _mm512_storeu_ps(&dX[d], dxv);
273 
274  // d_gamma += dY * x_hat
275  dgv = _mm512_fmadd_ps(dyv, x_hat, dgv);
276  _mm512_storeu_ps(&d_gamma[d], dgv);
277  }
278  // Handle remaining elements
279  for (; d < D; ++d) {
280  float x_hat = x[d] * rstd;
281  float dy = dY[d];
282  dX[d] = rstd * (dy * gamma[d] - x_hat * m);
283  d_gamma[d] += dy * x_hat;
284  }
285 
286 #elif defined(__AVX__)
287  // Compute m = (1/D) * sum_j (dY_j * gamma_j * x_hat_j)
288  __m256 rstd_vec = _mm256_set1_ps(rstd);
289  __m256 sum_vec = _mm256_setzero_ps();
290  int d = 0;
291 
292  for (; d + 8 <= D; d += 8) {
293  __m256 xv = _mm256_loadu_ps(&x[d]);
294  __m256 dyv = _mm256_loadu_ps(&dY[d]);
295  __m256 gv = _mm256_loadu_ps(&gamma[d]);
296  __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
297  // sum += dY * gamma * x_hat (no FMA, use mul + mul + add)
298  __m256 prod = _mm256_mul_ps(dyv, gv);
299  __m256 prod2 = _mm256_mul_ps(prod, x_hat);
300  sum_vec = _mm256_add_ps(sum_vec, prod2);
301  }
302  float sum_dY_g_xhat = hsum256_ps_rmsnorm(sum_vec);
303 
304  // Handle remaining elements
305  for (; d < D; ++d) {
306  float x_hat = x[d] * rstd;
307  sum_dY_g_xhat += dY[d] * gamma[d] * x_hat;
308  }
309  float m = sum_dY_g_xhat / (float)D;
310 
311  // Compute dX and accumulate dGamma (vectorized)
312  __m256 m_vec = _mm256_set1_ps(m);
313  d = 0;
314  for (; d + 8 <= D; d += 8) {
315  __m256 xv = _mm256_loadu_ps(&x[d]);
316  __m256 dyv = _mm256_loadu_ps(&dY[d]);
317  __m256 gv = _mm256_loadu_ps(&gamma[d]);
318  __m256 dgv = _mm256_loadu_ps(&d_gamma[d]);
319 
320  __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
321 
322  // dX = rstd * (dY * gamma - x_hat * m)
323  __m256 dy_g = _mm256_mul_ps(dyv, gv);
324  __m256 xhat_m = _mm256_mul_ps(x_hat, m_vec);
325  __m256 diff = _mm256_sub_ps(dy_g, xhat_m);
326  __m256 dxv = _mm256_mul_ps(rstd_vec, diff);
327  _mm256_storeu_ps(&dX[d], dxv);
328 
329  // d_gamma += dY * x_hat
330  __m256 dy_xhat = _mm256_mul_ps(dyv, x_hat);
331  dgv = _mm256_add_ps(dgv, dy_xhat);
332  _mm256_storeu_ps(&d_gamma[d], dgv);
333  }
334  // Handle remaining elements
335  for (; d < D; ++d) {
336  float x_hat = x[d] * rstd;
337  float dy = dY[d];
338  dX[d] = rstd * (dy * gamma[d] - x_hat * m);
339  d_gamma[d] += dy * x_hat;
340  }
341 
342 #else
343  // Scalar fallback
344  // Compute m = (1/D) * sum_j (dY_j * gamma_j * x_hat_j)
345  double sum_dY_g_xhat = 0.0;
346  for (int d = 0; d < D; ++d) {
347  float x_hat = x[d] * rstd;
348  sum_dY_g_xhat += (double)dY[d] * (double)gamma[d] * (double)x_hat;
349  }
350  float m = (float)(sum_dY_g_xhat / (double)D);
351 
352  // Compute dX and accumulate dGamma
353  for (int d = 0; d < D; ++d) {
354  float x_hat = x[d] * rstd;
355  float dy = dY[d];
356  dX[d] = rstd * (dy * gamma[d] - x_hat * m);
357  d_gamma[d] += dy * x_hat;
358  }
359 #endif
360 
361  // Zero padding gradients (if any)
362  for (int d = D; d < aligned; ++d) {
363  dX[d] = 0.0f;
364  }
365  }
366 }
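
A minimal usage sketch, pairing the forward pass (which fills rstd_cache) with this backward pass. Buffer names and the surrounding training-step structure are illustrative assumptions; only the two kernel calls and their argument order come from this file.

    /* Hypothetical training-step fragment: all buffers caller-provided,
     * each row aligned_embed_dim floats wide (d_model data + zero padding). */
    void example_backward_call(const float *input, const float *gamma,
                               const float *d_output, float *output,
                               float *rstd_cache, float *d_input, float *d_gamma,
                               int tokens, int d_model, int aligned, float eps)
    {
        /* Forward pass caches rstd[t] = 1/sqrt(mean(x_t^2) + eps) per token. */
        rmsnorm_forward(input, gamma, output, rstd_cache,
                        tokens, d_model, aligned, eps);

        /* ... rest of the network produces d_output = dL/dY ... */

        /* Backward pass. Note: d_gamma is zeroed at the top of the kernel,
         * so gamma gradients are not accumulated across calls. */
        rmsnorm_backward(d_output, input, gamma, rstd_cache,
                         d_input, d_gamma, tokens, d_model, aligned);
    }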

Referenced by ck_layer_backward_rmsnorm_swiglu(), rmsnorm_backward_int4(), and rmsnorm_backward_int8().

◆ rmsnorm_forward()

void rmsnorm_forward ( const float *  input,
const float *  gamma,
float *  output,
float *  rstd_cache,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float  eps 
)

RMSNorm forward pass

Test:

test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens

test_rmsnorm.py::TestRMSNormForward::test_fp32_single

test_rmsnorm.py::TestRMSNormForward::test_perf_rolled

test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat

test_parity.py::test_rmsnorm_parity

RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)

After changes: make test && make llamacpp-parity-full

Definition at line 50 of file rmsnorm_kernels.c.

58 {
59  int T = tokens;
60  int D = d_model;
61  int aligned = aligned_embed_dim;
62 
63  for (int t = 0; t < T; ++t) {
64  const float *x = input + (size_t)t * aligned;
65  float *y = output + (size_t)t * aligned;
66 
67 #if defined(__AVX512F__)
68  // AVX-512: Process 16 floats at a time
69  __m512 sum_sq_vec = _mm512_setzero_ps();
70  int d = 0;
71 
72  // Vectorized sum of squares
73  for (; d + 16 <= D; d += 16) {
74  __m512 xv = _mm512_loadu_ps(&x[d]);
75  sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
76  }
77  float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
78 
79  // Handle remaining elements
80  for (; d < D; ++d) {
81  sum_sq += x[d] * x[d];
82  }
83 
84  float mean_sq = sum_sq / (float)D;
85  float rstd = 1.0f / sqrtf(mean_sq + eps);
86  if (rstd_cache) {
87  rstd_cache[t] = rstd;
88  }
89 
90  // Apply normalization and scale (vectorized)
91  __m512 rstd_vec = _mm512_set1_ps(rstd);
92  d = 0;
93  for (; d + 16 <= D; d += 16) {
94  __m512 xv = _mm512_loadu_ps(&x[d]);
95  __m512 gv = _mm512_loadu_ps(&gamma[d]);
96  __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
97  __m512 yv = _mm512_mul_ps(x_hat, gv);
98  _mm512_storeu_ps(&y[d], yv);
99  }
100  // Handle remaining elements
101  for (; d < D; ++d) {
102  y[d] = x[d] * rstd * gamma[d];
103  }
104 
105 #elif defined(__AVX__)
106  // AVX: Process 8 floats at a time
107  __m256 sum_sq_vec = _mm256_setzero_ps();
108  int d = 0;
109 
110  // Vectorized sum of squares (no FMA in AVX1, use mul + add)
111  for (; d + 8 <= D; d += 8) {
112  __m256 xv = _mm256_loadu_ps(&x[d]);
113  __m256 xv_sq = _mm256_mul_ps(xv, xv);
114  sum_sq_vec = _mm256_add_ps(sum_sq_vec, xv_sq);
115  }
116  float sum_sq = hsum256_ps_rmsnorm(sum_sq_vec);
117 
118  // Handle remaining elements
119  for (; d < D; ++d) {
120  sum_sq += x[d] * x[d];
121  }
122 
123  float mean_sq = sum_sq / (float)D;
124  float rstd = 1.0f / sqrtf(mean_sq + eps);
125  if (rstd_cache) {
126  rstd_cache[t] = rstd;
127  }
128 
129  // Apply normalization and scale (vectorized)
130  __m256 rstd_vec = _mm256_set1_ps(rstd);
131  d = 0;
132  for (; d + 8 <= D; d += 8) {
133  __m256 xv = _mm256_loadu_ps(&x[d]);
134  __m256 gv = _mm256_loadu_ps(&gamma[d]);
135  __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
136  __m256 yv = _mm256_mul_ps(x_hat, gv);
137  _mm256_storeu_ps(&y[d], yv);
138  }
139  // Handle remaining elements
140  for (; d < D; ++d) {
141  y[d] = x[d] * rstd * gamma[d];
142  }
143 
144 #else
145  // Scalar fallback
146  double sum_sq = 0.0;
147  for (int d = 0; d < D; ++d) {
148  double v = (double)x[d];
149  sum_sq += v * v;
150  }
151  double mean_sq = sum_sq / (double)D;
152  double r = sqrt(mean_sq + (double)eps);
153  float rstd = (float)(1.0 / r);
154  if (rstd_cache) {
155  rstd_cache[t] = rstd;
156  }
157 
158  // Apply normalization and scale
159  for (int d = 0; d < D; ++d) {
160  float x_hat = x[d] * rstd;
161  y[d] = x_hat * gamma[d];
162  }
163 #endif
164 
165  // Zero padding (if any)
166  for (int d = D; d < aligned; ++d) {
167  y[d] = 0.0f;
168  }
169  }
170 }
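
As the if (rstd_cache) guard in the listing suggests, the rstd cache appears to be optional. An inference-style call might pass NULL (sketch; the eps value and buffer names are assumptions, the real eps comes from the model configuration):

    /* Hypothetical inference-style call: no backward pass planned, so no rstd
     * cache is kept. Padding columns [d_model, aligned_embed_dim) of each
     * output row are zeroed by the kernel itself. */
    rmsnorm_forward(hidden, gamma, normed, /*rstd_cache=*/NULL,
                    tokens, d_model, aligned_embed_dim, 1e-6f);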

Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), ck_test_rmsnorm(), mega_fused_attention_decode_q5_0(), mega_fused_attention_decode_q5_0_parallel_simd(), mega_fused_outproj_mlp_prefill(), model_decode_token(), model_forward_prefill_impl(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qk_norm_forward(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_forward_prefill_impl(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), 
qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), qwen2_0_5b_decode_layer_9_prefill(), rmsnorm_forward_int4(), and rmsnorm_forward_int8().