layernorm_kernels.c
/**
 * @file layernorm_kernels.c
 * @brief LayerNorm forward/backward kernels with SIMD (SSE/AVX/AVX512)
 *
 * CK-ENGINE KERNEL RULES:
 * =======================
 * 1. NO malloc/free - memory via bump allocator, pointers passed in
 * 2. NO OpenMP - parallelization at orchestrator/codegen layer
 * 3. API must define: inputs, outputs, workspace, and memory layouts
 * 4. Pure computation - deterministic, no side effects
 *
 * After changes: make test && make llamacpp-parity-full
 *
 * LayerNorm: y = gamma * (x - mean) / sqrt(var + eps) + beta
 */
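
/*
 * Layout sketch (descriptive comment; the picture below is illustrative,
 * not part of the kernel API). Under rule 1 the caller owns all memory,
 * and activations are laid out one row per token with stride
 * aligned_embed_dim:
 *
 *   token t: input[t * aligned_embed_dim + 0 .. d_model-1] | zero padding
 *
 * The AVX paths below use aligned loads/stores (_mm512_load_ps and
 * _mm256_load_ps), so rows are assumed to start on 64-/32-byte boundaries,
 * i.e. aligned_embed_dim is presumably a multiple of the vector width.
 */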

#if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
#include <immintrin.h>
#endif
#include <math.h>

static inline void zero_layernorm_padding(float *out_ptr,
                                          int d_model,
                                          int aligned_embed_dim)
{
    for (int idx = d_model; idx < aligned_embed_dim; ++idx) {
        out_ptr[idx] = 0.0f;
    }
}

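/* Forward declaration: the scalar fallback for the unrolled path near the
 * bottom of this file calls this precision-matched reference before its
 * definition appears. */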
void layernorm_naive_serial_matched_precision(const float *input,
                                              const float *gamma,
                                              const float *beta,
                                              float *output,
                                              float *mean_cache,
                                              float *rstd_cache,
                                              int tokens, int d_model, float eps);

#if defined(__AVX2__) || defined(__AVX__)
static inline float hsum256_ps(__m256 v)
{
    __m128 low = _mm256_castps256_ps128(v);
    __m128 high = _mm256_extractf128_ps(v, 1);
    __m128 sum = _mm_add_ps(low, high);
    sum = _mm_hadd_ps(sum, sum);
    sum = _mm_hadd_ps(sum, sum);
    return _mm_cvtss_f32(sum);
}
#endif
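
/*
 * Reference semantics for hsum256_ps, as a hedged sketch: the hadd tree
 * above sums all eight lanes; a plain scalar loop yields the same value up
 * to floating-point reassociation. CK_LAYERNORM_SELFTEST is a hypothetical
 * test-harness macro, not a flag the build actually defines.
 */
#ifdef CK_LAYERNORM_SELFTEST
static float hsum8_scalar_ref(const float lanes[8])
{
    float s = 0.0f;
    for (int i = 0; i < 8; ++i) {
        s += lanes[i]; /* matches hsum256_ps up to rounding/reassociation */
    }
    return s;
}
#endif
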
// Naive serial LayerNorm implementation (forward only), copied from C-Transformer.
void layernorm_naive_serial(const float *input,
                            const float *gamma,
                            const float *beta,
                            float *output,
                            float *mean_cache,
                            float *rstd_cache,
                            int tokens, int d_model, int aligned_embed_dim,
                            float eps)
{
    for (int t = 0; t < tokens; ++t) {
        const float *in_ptr = input + t * aligned_embed_dim;
        float *out_ptr = output + t * aligned_embed_dim;

        float sum_val = 0.0f;
        for (int i = 0; i < d_model; ++i) {
            sum_val += in_ptr[i];
        }
        float mean = sum_val / (float)d_model;

        float sum_sq_diff = 0.0f;
        for (int i = 0; i < d_model; ++i) {
            float diff = in_ptr[i] - mean;
            sum_sq_diff += diff * diff;
        }
        float variance = sum_sq_diff / (float)d_model + eps;

        double var_double = (double)variance;
        float inv_std = (float)(1.0 / sqrt(var_double));

        for (int i = 0; i < d_model; ++i) {
            float normalized_val = (in_ptr[i] - mean) * inv_std;
            out_ptr[i] = normalized_val * gamma[i] + beta[i];
        }

        if (mean_cache) {
            mean_cache[t] = mean;
        }
        if (rstd_cache) {
            rstd_cache[t] = inv_std;
        }
        /* Keep the padded lanes zeroed so subsequent GEMMs never read stale data. */
        if (aligned_embed_dim > d_model) {
            zero_layernorm_padding(out_ptr, d_model, aligned_embed_dim);
        }
    }
}
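
/*
 * Usage sketch for the naive kernel (hypothetical caller; the buffer names
 * are illustrative and assumed to come from the engine's bump allocator):
 *
 *   // tokens=4, d_model=768, aligned_embed_dim=768 (already a multiple of 16)
 *   layernorm_naive_serial(acts_in, ln_gamma, ln_beta, acts_out,
 *                          mean_buf, rstd_buf,  // [tokens] floats each
 *                          4, 768, 768, 1e-5f);
 *
 * mean_buf/rstd_buf feed layernorm_backward_kernel below; inference-only
 * callers can pass NULL for both.
 */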

#if defined(__AVX512F__)
// AVX-512 rolled slice kernel, copied from C-Transformer (model-agnostic).
static void layernorm_forward_rolled_slice_avx512(const float *__restrict input_slice_base,
                                                  const float *__restrict gamma,
                                                  const float *__restrict beta,
                                                  float *__restrict output_slice_base,
                                                  float *__restrict mean_cache_slice,
                                                  float *__restrict rstd_cache_slice,
                                                  int num_tokens_in_slice,
                                                  int d_model,
                                                  int aligned_embed_dim,
                                                  float eps)
{
    for (int t = 0; t < num_tokens_in_slice; ++t) {
        const float *in_ptr_token = input_slice_base + t * aligned_embed_dim;
        float *out_ptr_token = output_slice_base + t * aligned_embed_dim;

        __m512 acc_sum_vec = _mm512_setzero_ps();
        int j = 0;
        for (; j <= d_model - 16; j += 16) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            __m512 v = _mm512_load_ps(in_ptr_token + j);
            acc_sum_vec = _mm512_add_ps(acc_sum_vec, v);
        }
        float mean = _mm512_reduce_add_ps(acc_sum_vec);
        for (; j < d_model; ++j) {
            mean += in_ptr_token[j];
        }
        mean /= (float)d_model;
        __m512 mean_vec = _mm512_set1_ps(mean);

        __m512 acc_var_vec = _mm512_setzero_ps();
        j = 0;
        for (; j <= d_model - 16; j += 16) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            __m512 v = _mm512_load_ps(in_ptr_token + j);
            __m512 diff = _mm512_sub_ps(v, mean_vec);
            acc_var_vec = _mm512_fmadd_ps(diff, diff, acc_var_vec);
        }
        float var = _mm512_reduce_add_ps(acc_var_vec);
        for (; j < d_model; ++j) {
            float diff = in_ptr_token[j] - mean;
            var += diff * diff;
        }
        var = var / (float)d_model + eps;
        double var_double = (double)var;
        float inv_std = (float)(1.0 / sqrt(var_double));
        __m512 inv_std_vec = _mm512_set1_ps(inv_std);

        if (mean_cache_slice) {
            mean_cache_slice[t] = mean;
        }
        if (rstd_cache_slice) {
            rstd_cache_slice[t] = inv_std;
        }

        j = 0;
        for (; j <= d_model - 16; j += 16) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(gamma + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(beta + j + 128), _MM_HINT_T0);

            __m512 v = _mm512_load_ps(in_ptr_token + j);
            __m512 g = _mm512_load_ps(gamma + j);
            __m512 b = _mm512_load_ps(beta + j);

            __m512 n = _mm512_mul_ps(_mm512_sub_ps(v, mean_vec), inv_std_vec);
            __m512 o = _mm512_fmadd_ps(n, g, b);

            _mm512_store_ps(out_ptr_token + j, o);
        }
        for (; j < d_model; ++j) {
            float normed = (in_ptr_token[j] - mean) * inv_std;
            out_ptr_token[j] = normed * gamma[j] + beta[j];
        }

        if (aligned_embed_dim > d_model) {
            /* Keep the padded lanes zeroed so later GEMMs see deterministic memory. */
            zero_layernorm_padding(out_ptr_token, d_model, aligned_embed_dim);
        }
    }
}
#elif defined(__AVX2__) || defined(__AVX__)
// AVX/AVX2 rolled slice kernel (8-float vectors).
static void layernorm_forward_rolled_slice_avx256(const float *__restrict input_slice_base,
                                                  const float *__restrict gamma,
                                                  const float *__restrict beta,
                                                  float *__restrict output_slice_base,
                                                  float *__restrict mean_cache_slice,
                                                  float *__restrict rstd_cache_slice,
                                                  int num_tokens_in_slice,
                                                  int d_model,
                                                  int aligned_embed_dim,
                                                  float eps)
{
    for (int t = 0; t < num_tokens_in_slice; ++t) {
        const float *in_ptr_token = input_slice_base + t * aligned_embed_dim;
        float *out_ptr_token = output_slice_base + t * aligned_embed_dim;

        __m256 acc_sum_vec = _mm256_setzero_ps();
        int j = 0;
        for (; j <= d_model - 8; j += 8) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            __m256 v = _mm256_load_ps(in_ptr_token + j);
            acc_sum_vec = _mm256_add_ps(acc_sum_vec, v);
        }
        float mean = hsum256_ps(acc_sum_vec);
        for (; j < d_model; ++j) {
            mean += in_ptr_token[j];
        }
        mean /= (float)d_model;
        __m256 mean_vec = _mm256_set1_ps(mean);

        __m256 acc_var_vec = _mm256_setzero_ps();
        j = 0;
        for (; j <= d_model - 8; j += 8) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            __m256 v = _mm256_load_ps(in_ptr_token + j);
            __m256 diff = _mm256_sub_ps(v, mean_vec);
#if defined(__FMA__)
            acc_var_vec = _mm256_fmadd_ps(diff, diff, acc_var_vec);
#else
            acc_var_vec = _mm256_add_ps(acc_var_vec, _mm256_mul_ps(diff, diff));
#endif
        }
        float var = hsum256_ps(acc_var_vec);
        for (; j < d_model; ++j) {
            float diff = in_ptr_token[j] - mean;
            var += diff * diff;
        }
        var = var / (float)d_model + eps;
        double var_double = (double)var;
        float inv_std = (float)(1.0 / sqrt(var_double));
        __m256 inv_std_vec = _mm256_set1_ps(inv_std);

        if (mean_cache_slice) {
            mean_cache_slice[t] = mean;
        }
        if (rstd_cache_slice) {
            rstd_cache_slice[t] = inv_std;
        }

        j = 0;
        for (; j <= d_model - 8; j += 8) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(gamma + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(beta + j + 128), _MM_HINT_T0);

            __m256 v = _mm256_load_ps(in_ptr_token + j);
            __m256 g = _mm256_load_ps(gamma + j);
            __m256 b = _mm256_load_ps(beta + j);

            __m256 n = _mm256_mul_ps(_mm256_sub_ps(v, mean_vec), inv_std_vec);
#if defined(__FMA__)
            __m256 o = _mm256_fmadd_ps(n, g, b);
#else
            __m256 o = _mm256_add_ps(_mm256_mul_ps(n, g), b);
#endif

            _mm256_store_ps(out_ptr_token + j, o);
        }
        for (; j < d_model; ++j) {
            float normed = (in_ptr_token[j] - mean) * inv_std;
            out_ptr_token[j] = normed * gamma[j] + beta[j];
        }

        if (aligned_embed_dim > d_model) {
            zero_layernorm_padding(out_ptr_token, d_model, aligned_embed_dim);
        }
    }
}
#endif

void layernorm_forward_rolled_slice(const float *__restrict input_slice_base,
                                    const float *__restrict gamma,
                                    const float *__restrict beta,
                                    float *__restrict output_slice_base,
                                    float *__restrict mean_cache_slice,
                                    float *__restrict rstd_cache_slice,
                                    int num_tokens_in_slice,
                                    int d_model,
                                    int aligned_embed_dim,
                                    float eps)
{
#if defined(__AVX512F__)
    layernorm_forward_rolled_slice_avx512(input_slice_base, gamma, beta,
                                          output_slice_base, mean_cache_slice, rstd_cache_slice,
                                          num_tokens_in_slice, d_model, aligned_embed_dim, eps);
#elif defined(__AVX2__) || defined(__AVX__)
    layernorm_forward_rolled_slice_avx256(input_slice_base, gamma, beta,
                                          output_slice_base, mean_cache_slice, rstd_cache_slice,
                                          num_tokens_in_slice, d_model, aligned_embed_dim, eps);
#else
    layernorm_naive_serial(input_slice_base, gamma, beta,
                           output_slice_base, mean_cache_slice, rstd_cache_slice,
                           num_tokens_in_slice, d_model, aligned_embed_dim, eps);
#endif
}
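
/*
 * Slice-dispatch sketch: rule 2 forbids OpenMP inside kernels, so an
 * orchestrator splits the token range and calls the rolled slice kernel
 * once per worker. The function below is hypothetical harness code
 * (guarded by the made-up macro CK_LAYERNORM_EXAMPLE) showing only the
 * pointer arithmetic for one worker's slice.
 */
#ifdef CK_LAYERNORM_EXAMPLE
static void example_dispatch_rolled(const float *input, const float *gamma,
                                    const float *beta, float *output,
                                    float *mean_cache, float *rstd_cache,
                                    int tokens, int d_model,
                                    int aligned_embed_dim, float eps,
                                    int num_workers, int worker_id)
{
    /* Even split with the remainder spread over the leading workers. */
    int base = tokens / num_workers;
    int rem = tokens % num_workers;
    int start = worker_id * base + (worker_id < rem ? worker_id : rem);
    int count = base + (worker_id < rem ? 1 : 0);

    layernorm_forward_rolled_slice(input + start * aligned_embed_dim,
                                   gamma, beta,
                                   output + start * aligned_embed_dim,
                                   mean_cache ? mean_cache + start : (float *)0,
                                   rstd_cache ? rstd_cache + start : (float *)0,
                                   count, d_model, aligned_embed_dim, eps);
}
#endif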

#if defined(__AVX512F__)
// AVX-512 unrolled slice kernel, copied from C-Transformer (model-agnostic).
static void layernorm_forward_unrolled_slice_avx512(const float *__restrict input_slice_base,
                                                    const float *__restrict gamma,
                                                    const float *__restrict beta,
                                                    float *__restrict output_slice_base,
                                                    float *__restrict mean_cache_slice,
                                                    float *__restrict rstd_cache_slice,
                                                    int num_tokens_in_slice,
                                                    int d_model,
                                                    float eps)
{
    for (int t = 0; t < num_tokens_in_slice; ++t) {
        const float *in_ptr_token = input_slice_base + t * d_model;
        float *out_ptr_token = output_slice_base + t * d_model;

        __m512 acc0 = _mm512_setzero_ps();
        __m512 acc1 = _mm512_setzero_ps();
        __m512 acc2 = _mm512_setzero_ps();
        __m512 acc3 = _mm512_setzero_ps();

        int j = 0;
        int unroll_factor_floats = 64;

        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);

            __m512 v0 = _mm512_load_ps(in_ptr_token + j);
            __m512 v1 = _mm512_load_ps(in_ptr_token + j + 16);
            __m512 v2 = _mm512_load_ps(in_ptr_token + j + 32);
            __m512 v3 = _mm512_load_ps(in_ptr_token + j + 48);

            acc0 = _mm512_add_ps(acc0, v0);
            acc1 = _mm512_add_ps(acc1, v1);
            acc2 = _mm512_add_ps(acc2, v2);
            acc3 = _mm512_add_ps(acc3, v3);
        }
        __m512 acc_sum = _mm512_add_ps(_mm512_add_ps(acc0, acc1),
                                       _mm512_add_ps(acc2, acc3));
        float mean = _mm512_reduce_add_ps(acc_sum);

        for (; j < d_model; ++j) {
            mean += in_ptr_token[j];
        }
        mean /= (float)d_model;
        __m512 mean_vec = _mm512_set1_ps(mean);

        acc0 = _mm512_setzero_ps();
        acc1 = _mm512_setzero_ps();
        acc2 = _mm512_setzero_ps();
        acc3 = _mm512_setzero_ps();

        j = 0;
        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);

            __m512 v0 = _mm512_load_ps(in_ptr_token + j);
            __m512 v1 = _mm512_load_ps(in_ptr_token + j + 16);
            __m512 v2 = _mm512_load_ps(in_ptr_token + j + 32);
            __m512 v3 = _mm512_load_ps(in_ptr_token + j + 48);

            __m512 d0 = _mm512_sub_ps(v0, mean_vec);
            __m512 d1 = _mm512_sub_ps(v1, mean_vec);
            __m512 d2 = _mm512_sub_ps(v2, mean_vec);
            __m512 d3 = _mm512_sub_ps(v3, mean_vec);

            acc0 = _mm512_fmadd_ps(d0, d0, acc0);
            acc1 = _mm512_fmadd_ps(d1, d1, acc1);
            acc2 = _mm512_fmadd_ps(d2, d2, acc2);
            acc3 = _mm512_fmadd_ps(d3, d3, acc3);
        }
        acc_sum = _mm512_add_ps(_mm512_add_ps(acc0, acc1),
                                _mm512_add_ps(acc2, acc3));
        float var = _mm512_reduce_add_ps(acc_sum);

        for (; j < d_model; ++j) {
            float diff = in_ptr_token[j] - mean;
            var += diff * diff;
        }
        var = var / (float)d_model + eps;
        double var_double = (double)var;
        float inv_std = (float)(1.0 / sqrt(var_double));
        __m512 inv_std_vec = _mm512_set1_ps(inv_std);

        if (mean_cache_slice) {
            mean_cache_slice[t] = mean;
        }
        if (rstd_cache_slice) {
            rstd_cache_slice[t] = inv_std;
        }

        j = 0;
        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(gamma + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(beta + j + 128), _MM_HINT_T0);

            __m512 v0 = _mm512_load_ps(in_ptr_token + j);
            __m512 v1 = _mm512_load_ps(in_ptr_token + j + 16);
            __m512 v2 = _mm512_load_ps(in_ptr_token + j + 32);
            __m512 v3 = _mm512_load_ps(in_ptr_token + j + 48);

            __m512 g0 = _mm512_load_ps(gamma + j);
            __m512 g1 = _mm512_load_ps(gamma + j + 16);
            __m512 g2 = _mm512_load_ps(gamma + j + 32);
            __m512 g3 = _mm512_load_ps(gamma + j + 48);

            __m512 b0 = _mm512_load_ps(beta + j);
            __m512 b1 = _mm512_load_ps(beta + j + 16);
            __m512 b2 = _mm512_load_ps(beta + j + 32);
            __m512 b3 = _mm512_load_ps(beta + j + 48);

            __m512 n0 = _mm512_mul_ps(_mm512_sub_ps(v0, mean_vec), inv_std_vec);
            __m512 n1 = _mm512_mul_ps(_mm512_sub_ps(v1, mean_vec), inv_std_vec);
            __m512 n2 = _mm512_mul_ps(_mm512_sub_ps(v2, mean_vec), inv_std_vec);
            __m512 n3 = _mm512_mul_ps(_mm512_sub_ps(v3, mean_vec), inv_std_vec);

            __m512 o0 = _mm512_fmadd_ps(n0, g0, b0);
            __m512 o1 = _mm512_fmadd_ps(n1, g1, b1);
            __m512 o2 = _mm512_fmadd_ps(n2, g2, b2);
            __m512 o3 = _mm512_fmadd_ps(n3, g3, b3);

            _mm512_store_ps(out_ptr_token + j, o0);
            _mm512_store_ps(out_ptr_token + j + 16, o1);
            _mm512_store_ps(out_ptr_token + j + 32, o2);
            _mm512_store_ps(out_ptr_token + j + 48, o3);
        }
        for (; j < d_model; ++j) {
            float normed = (in_ptr_token[j] - mean) * inv_std;
            out_ptr_token[j] = normed * gamma[j] + beta[j];
        }
    }
}
#elif defined(__AVX2__) || defined(__AVX__)
// AVX/AVX2 unrolled slice kernel (8-float vectors).
static void layernorm_forward_unrolled_slice_avx256(const float *__restrict input_slice_base,
                                                    const float *__restrict gamma,
                                                    const float *__restrict beta,
                                                    float *__restrict output_slice_base,
                                                    float *__restrict mean_cache_slice,
                                                    float *__restrict rstd_cache_slice,
                                                    int num_tokens_in_slice,
                                                    int d_model,
                                                    float eps)
{
    for (int t = 0; t < num_tokens_in_slice; ++t) {
        const float *in_ptr_token = input_slice_base + t * d_model;
        float *out_ptr_token = output_slice_base + t * d_model;

        __m256 acc0 = _mm256_setzero_ps();
        __m256 acc1 = _mm256_setzero_ps();
        __m256 acc2 = _mm256_setzero_ps();
        __m256 acc3 = _mm256_setzero_ps();

        int j = 0;
        int unroll_factor_floats = 32;

        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);

            __m256 v0 = _mm256_load_ps(in_ptr_token + j);
            __m256 v1 = _mm256_load_ps(in_ptr_token + j + 8);
            __m256 v2 = _mm256_load_ps(in_ptr_token + j + 16);
            __m256 v3 = _mm256_load_ps(in_ptr_token + j + 24);

            acc0 = _mm256_add_ps(acc0, v0);
            acc1 = _mm256_add_ps(acc1, v1);
            acc2 = _mm256_add_ps(acc2, v2);
            acc3 = _mm256_add_ps(acc3, v3);
        }
        __m256 acc_sum = _mm256_add_ps(_mm256_add_ps(acc0, acc1),
                                       _mm256_add_ps(acc2, acc3));
        float mean = hsum256_ps(acc_sum);

        for (; j < d_model; ++j) {
            mean += in_ptr_token[j];
        }
        mean /= (float)d_model;
        __m256 mean_vec = _mm256_set1_ps(mean);

        acc0 = _mm256_setzero_ps();
        acc1 = _mm256_setzero_ps();
        acc2 = _mm256_setzero_ps();
        acc3 = _mm256_setzero_ps();

        j = 0;
        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);

            __m256 v0 = _mm256_load_ps(in_ptr_token + j);
            __m256 v1 = _mm256_load_ps(in_ptr_token + j + 8);
            __m256 v2 = _mm256_load_ps(in_ptr_token + j + 16);
            __m256 v3 = _mm256_load_ps(in_ptr_token + j + 24);

            __m256 d0 = _mm256_sub_ps(v0, mean_vec);
            __m256 d1 = _mm256_sub_ps(v1, mean_vec);
            __m256 d2 = _mm256_sub_ps(v2, mean_vec);
            __m256 d3 = _mm256_sub_ps(v3, mean_vec);

#if defined(__FMA__)
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
            acc2 = _mm256_fmadd_ps(d2, d2, acc2);
            acc3 = _mm256_fmadd_ps(d3, d3, acc3);
#else
            acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(d0, d0));
            acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(d1, d1));
            acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(d2, d2));
            acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(d3, d3));
#endif
        }
        acc_sum = _mm256_add_ps(_mm256_add_ps(acc0, acc1),
                                _mm256_add_ps(acc2, acc3));
        float var = hsum256_ps(acc_sum);

        for (; j < d_model; ++j) {
            float diff = in_ptr_token[j] - mean;
            var += diff * diff;
        }
        var = var / (float)d_model + eps;
        double var_double = (double)var;
        float inv_std = (float)(1.0 / sqrt(var_double));
        __m256 inv_std_vec = _mm256_set1_ps(inv_std);

        if (mean_cache_slice) {
            mean_cache_slice[t] = mean;
        }
        if (rstd_cache_slice) {
            rstd_cache_slice[t] = inv_std;
        }

        j = 0;
        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(gamma + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(beta + j + 128), _MM_HINT_T0);

            __m256 v0 = _mm256_load_ps(in_ptr_token + j);
            __m256 v1 = _mm256_load_ps(in_ptr_token + j + 8);
            __m256 v2 = _mm256_load_ps(in_ptr_token + j + 16);
            __m256 v3 = _mm256_load_ps(in_ptr_token + j + 24);

            __m256 g0 = _mm256_load_ps(gamma + j);
            __m256 g1 = _mm256_load_ps(gamma + j + 8);
            __m256 g2 = _mm256_load_ps(gamma + j + 16);
            __m256 g3 = _mm256_load_ps(gamma + j + 24);

            __m256 b0 = _mm256_load_ps(beta + j);
            __m256 b1 = _mm256_load_ps(beta + j + 8);
            __m256 b2 = _mm256_load_ps(beta + j + 16);
            __m256 b3 = _mm256_load_ps(beta + j + 24);

            __m256 n0 = _mm256_mul_ps(_mm256_sub_ps(v0, mean_vec), inv_std_vec);
            __m256 n1 = _mm256_mul_ps(_mm256_sub_ps(v1, mean_vec), inv_std_vec);
            __m256 n2 = _mm256_mul_ps(_mm256_sub_ps(v2, mean_vec), inv_std_vec);
            __m256 n3 = _mm256_mul_ps(_mm256_sub_ps(v3, mean_vec), inv_std_vec);

#if defined(__FMA__)
            __m256 o0 = _mm256_fmadd_ps(n0, g0, b0);
            __m256 o1 = _mm256_fmadd_ps(n1, g1, b1);
            __m256 o2 = _mm256_fmadd_ps(n2, g2, b2);
            __m256 o3 = _mm256_fmadd_ps(n3, g3, b3);
#else
            __m256 o0 = _mm256_add_ps(_mm256_mul_ps(n0, g0), b0);
            __m256 o1 = _mm256_add_ps(_mm256_mul_ps(n1, g1), b1);
            __m256 o2 = _mm256_add_ps(_mm256_mul_ps(n2, g2), b2);
            __m256 o3 = _mm256_add_ps(_mm256_mul_ps(n3, g3), b3);
#endif

            _mm256_store_ps(out_ptr_token + j, o0);
            _mm256_store_ps(out_ptr_token + j + 8, o1);
            _mm256_store_ps(out_ptr_token + j + 16, o2);
            _mm256_store_ps(out_ptr_token + j + 24, o3);
        }
        for (; j < d_model; ++j) {
            float normed = (in_ptr_token[j] - mean) * inv_std;
            out_ptr_token[j] = normed * gamma[j] + beta[j];
        }
    }
}
#else
// Scalar fallback when no AVX support is available.
static void layernorm_forward_unrolled_slice_scalar(const float *__restrict input_slice_base,
                                                    const float *__restrict gamma,
                                                    const float *__restrict beta,
                                                    float *__restrict output_slice_base,
                                                    float *__restrict mean_cache_slice,
                                                    float *__restrict rstd_cache_slice,
                                                    int num_tokens_in_slice,
                                                    int d_model,
                                                    float eps)
{
    layernorm_naive_serial_matched_precision(input_slice_base, gamma, beta,
                                             output_slice_base, mean_cache_slice, rstd_cache_slice,
                                             num_tokens_in_slice, d_model, eps);
}
#endif

void layernorm_forward_unrolled_slice(const float *__restrict input_slice_base,
                                      const float *__restrict gamma,
                                      const float *__restrict beta,
                                      float *__restrict output_slice_base,
                                      float *__restrict mean_cache_slice,
                                      float *__restrict rstd_cache_slice,
                                      int num_tokens_in_slice,
                                      int d_model,
                                      float eps)
{
#if defined(__AVX512F__)
    layernorm_forward_unrolled_slice_avx512(input_slice_base, gamma, beta,
                                            output_slice_base, mean_cache_slice, rstd_cache_slice,
                                            num_tokens_in_slice, d_model, eps);
#elif defined(__AVX2__) || defined(__AVX__)
    layernorm_forward_unrolled_slice_avx256(input_slice_base, gamma, beta,
                                            output_slice_base, mean_cache_slice, rstd_cache_slice,
                                            num_tokens_in_slice, d_model, eps);
#else
    layernorm_forward_unrolled_slice_scalar(input_slice_base, gamma, beta,
                                            output_slice_base, mean_cache_slice, rstd_cache_slice,
                                            num_tokens_in_slice, d_model, eps);
#endif
}
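
/*
 * Layout note (descriptive, derived from the code above): unlike the
 * rolled variants, the unrolled slice kernels take no aligned_embed_dim;
 * rows are packed contiguously with stride d_model. Because the AVX paths
 * still use aligned loads, d_model itself is presumably required to keep
 * every row on a 32-/64-byte boundary (e.g. a multiple of 16 floats for
 * the AVX-512 path).
 */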

// Precision-matched naive LayerNorm used for benchmarking, copied from C-Transformer.
void layernorm_naive_serial_matched_precision(const float *input,
                                              const float *gamma,
                                              const float *beta,
                                              float *output,
                                              float *mean_cache,
                                              float *rstd_cache,
                                              int tokens, int d_model, float eps)
{
    for (int t = 0; t < tokens; ++t) {
        const float *in_ptr = input + t * d_model;
        float *out_ptr = output + t * d_model;

        float sum_val = 0.0f;
        for (int i = 0; i < d_model; ++i) {
            sum_val += in_ptr[i];
        }
        float mean = sum_val / (float)d_model;

        float sum_sq_diff = 0.0f;
        for (int i = 0; i < d_model; ++i) {
            float diff = in_ptr[i] - mean;
            sum_sq_diff += diff * diff;
        }
        float variance = sum_sq_diff / (float)d_model + eps;

        double var_double = (double)variance;
        float inv_std = (float)(1.0 / sqrt(var_double));

        for (int i = 0; i < d_model; ++i) {
            float normalized_val = (in_ptr[i] - mean) * inv_std;
            out_ptr[i] = normalized_val * gamma[i] + beta[i];
        }

        if (mean_cache) {
            mean_cache[t] = mean;
        }
        if (rstd_cache) {
            rstd_cache[t] = inv_std;
        }
    }
}
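
/*
 * Parity-check sketch in the spirit of `make llamacpp-parity-full`
 * (hypothetical harness code; CK_LAYERNORM_SELFTEST is a made-up macro):
 * run the SIMD rolled kernel and the precision-matched reference on the
 * same activations and report the max absolute difference. Passing
 * aligned_embed_dim == d_model assumes d_model is already a multiple of
 * the vector width, so no padding lanes exist.
 */
#ifdef CK_LAYERNORM_SELFTEST
static float example_parity_max_abs_diff(const float *input, const float *gamma,
                                         const float *beta, float *out_simd,
                                         float *out_ref, int tokens,
                                         int d_model, float eps)
{
    layernorm_forward_rolled_slice(input, gamma, beta, out_simd,
                                   (float *)0, (float *)0,
                                   tokens, d_model, d_model, eps);
    layernorm_naive_serial_matched_precision(input, gamma, beta, out_ref,
                                             (float *)0, (float *)0,
                                             tokens, d_model, eps);

    float max_diff = 0.0f;
    for (int i = 0; i < tokens * d_model; ++i) {
        float diff = fabsf(out_simd[i] - out_ref[i]); /* fabsf from <math.h> */
        if (diff > max_diff) {
            max_diff = diff;
        }
    }
    return max_diff;
}
#endif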

// LayerNorm backward kernel (model-agnostic), adapted from C-Transformer's
// backward_layernorm. Computes gradients w.r.t. input, gamma, and beta.
void layernorm_backward_kernel(const float *d_output, // [T×aligned_D]
                               const float *input,    // [T×aligned_D]
                               const float *gamma,    // [D]
                               const float *mean,     // [T]
                               const float *rstd,     // [T]
                               float *d_input,        // [T×aligned_D]
                               float *d_gamma,        // [D] (accumulated)
                               float *d_beta,         // [D] (accumulated)
                               int tokens, int d_model, int aligned_embed_dim)
{
    int T = tokens;
    int D = d_model;
    int aligned_D = aligned_embed_dim;

    // Per-token input gradients
    for (int t = 0; t < T; ++t) {
        float mean_t = mean[t];
        float rstd_t = rstd[t];

        float d_y_gamma_sum = 0.0f;
        float d_y_gamma_xhat_sum = 0.0f;

        // First pass: compute sums
        for (int d = 0; d < D; ++d) {
            float x = input[t * aligned_D + d];
            float x_hat = (x - mean_t) * rstd_t;
            float d_y = d_output[t * aligned_D + d];
            float d_y_gamma = d_y * gamma[d];

            d_y_gamma_sum += d_y_gamma;
            d_y_gamma_xhat_sum += d_y_gamma * x_hat;
        }

        // Second pass: compute input gradients
        float scale = rstd_t / (float)D;
        for (int d = 0; d < D; ++d) {
            float x = input[t * aligned_D + d];
            float x_hat = (x - mean_t) * rstd_t;
            float d_y = d_output[t * aligned_D + d];

            d_input[t * aligned_D + d] =
                scale * ((float)D * d_y * gamma[d] - d_y_gamma_sum - x_hat * d_y_gamma_xhat_sum);
        }

        // Zero padding for aligned dimension beyond D
        for (int d = D; d < aligned_D; ++d) {
            d_input[t * aligned_D + d] = 0.0f;
        }
    }

    // Parameter gradients (gamma, beta)
    for (int d = 0; d < D; ++d) {
        float gamma_grad = 0.0f;
        float beta_grad = 0.0f;

        for (int t = 0; t < T; ++t) {
            float x = input[t * aligned_D + d];
            float x_hat = (x - mean[t]) * rstd[t];
            float d_y = d_output[t * aligned_D + d];

            gamma_grad += d_y * x_hat;
            beta_grad += d_y;
        }

        d_gamma[d] += gamma_grad;
        d_beta[d] += beta_grad;
    }
}
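
/*
 * Gradient identity implemented above (standard LayerNorm backward; per
 * token, with x_hat[d] = (x[d] - mean) * rstd and D = d_model):
 *
 *   d_input[d] = (rstd / D) * ( D * d_y[d] * gamma[d]
 *                               - sum_k d_y[k] * gamma[k]
 *                               - x_hat[d] * sum_k d_y[k] * gamma[k] * x_hat[k] )
 *
 * A finite-difference spot check is a cheap way to validate it. The sketch
 * below is hypothetical harness code (again under the made-up
 * CK_LAYERNORM_SELFTEST macro), using the scalar loss
 * L = sum(output * d_output) so that d_output is exactly the upstream
 * gradient. Float precision makes this a coarse check; a double-precision
 * harness would be tighter.
 */
#ifdef CK_LAYERNORM_SELFTEST
/* Returns the relative error between the analytic and numeric gradient for
 * one (token t, dim d) coordinate. Caller provides all scratch buffers and
 * must zero-init d_gamma/d_beta, since the backward kernel accumulates. */
static float example_gradcheck_one(float *input, const float *gamma,
                                   const float *beta, const float *d_output,
                                   float *out, float *mean_buf, float *rstd_buf,
                                   float *d_input, float *d_gamma, float *d_beta,
                                   int tokens, int d_model, int t, int d,
                                   float eps)
{
    const float h = 1e-3f;
    int n = tokens * d_model;

    /* Analytic gradient via the backward kernel (aligned_embed_dim ==
     * d_model here, assuming d_model needs no padding). */
    layernorm_naive_serial(input, gamma, beta, out, mean_buf, rstd_buf,
                           tokens, d_model, d_model, eps);
    layernorm_backward_kernel(d_output, input, gamma, mean_buf, rstd_buf,
                              d_input, d_gamma, d_beta,
                              tokens, d_model, d_model);
    float analytic = d_input[t * d_model + d];

    /* Numeric central difference of L = sum_i out[i] * d_output[i]. */
    float saved = input[t * d_model + d];
    float loss_plus = 0.0f;
    float loss_minus = 0.0f;

    input[t * d_model + d] = saved + h;
    layernorm_naive_serial(input, gamma, beta, out, mean_buf, rstd_buf,
                           tokens, d_model, d_model, eps);
    for (int i = 0; i < n; ++i) {
        loss_plus += out[i] * d_output[i];
    }

    input[t * d_model + d] = saved - h;
    layernorm_naive_serial(input, gamma, beta, out, mean_buf, rstd_buf,
                           tokens, d_model, d_model, eps);
    for (int i = 0; i < n; ++i) {
        loss_minus += out[i] * d_output[i];
    }

    input[t * d_model + d] = saved; /* restore the perturbed element */
    float numeric = (loss_plus - loss_minus) / (2.0f * h);
    return fabsf(analytic - numeric) / (fabsf(numeric) + 1e-8f);
}
#endif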