#include <math.h>

#if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
#include <immintrin.h>
#endif

/* Zero the tail of a padded row so the floats between d_model and
 * aligned_embed_dim never hold garbage. */
static void zero_layernorm_padding(float *out_ptr, int d_model,
                                   int aligned_embed_dim)
{
    for (int idx = d_model; idx < aligned_embed_dim; ++idx) {
        out_ptr[idx] = 0.0f;
    }
}
/* Defined near the bottom of this file; forward-declared here because the
 * scalar fallback for the unrolled slice kernels calls it. */
void layernorm_naive_serial_matched_precision(
    const float *input, const float *gamma, const float *beta,
    float *output, float *mean_cache, float *rstd_cache,
    int tokens, int d_model, float eps);
#if defined(__AVX2__) || defined(__AVX__)
/* Horizontal sum of all eight lanes of a __m256. */
static inline float hsum256_ps(__m256 v)
{
    __m128 low = _mm256_castps256_ps128(v);
    __m128 high = _mm256_extractf128_ps(v, 1);
    __m128 sum = _mm_add_ps(low, high); /* fold 8 lanes to 4 */
    sum = _mm_hadd_ps(sum, sum);        /* 4 lanes to 2 */
    sum = _mm_hadd_ps(sum, sum);        /* 2 lanes to 1 */
    return _mm_cvtss_f32(sum);
}
#endif
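/* Design note: the two _mm_hadd_ps steps keep the reduction compact. A
 * shuffle-and-add ladder is a common alternative where haddps latency
 * matters, but this helper runs only once per token per pass, so it is
 * well off the hot path. */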
void layernorm_naive_serial(const float *input, const float *gamma,
                            const float *beta, float *output,
                            float *mean_cache, float *rstd_cache,
                            int tokens, int d_model,
                            int aligned_embed_dim, float eps)
{
    for (int t = 0; t < tokens; ++t) {
        const float *in_ptr = input + t * aligned_embed_dim;
        float *out_ptr = output + t * aligned_embed_dim;

        /* Pass 1: mean. */
        float sum_val = 0.0f;
        for (int i = 0; i < d_model; ++i) {
            sum_val += in_ptr[i];
        }
        float mean = sum_val / (float)d_model;

        /* Pass 2: variance. */
        float sum_sq_diff = 0.0f;
        for (int i = 0; i < d_model; ++i) {
            float diff = in_ptr[i] - mean;
            sum_sq_diff += diff * diff;
        }
        float variance = sum_sq_diff / (float)d_model + eps;

        double var_double = (double)variance;
        float inv_std = (float)(1.0 / sqrt(var_double));

        /* Pass 3: normalize, scale, and shift. */
        for (int i = 0; i < d_model; ++i) {
            float normalized_val = (in_ptr[i] - mean) * inv_std;
            out_ptr[i] = normalized_val * gamma[i] + beta[i];
        }

        if (mean_cache) {
            mean_cache[t] = mean;
        }
        if (rstd_cache) {
            rstd_cache[t] = inv_std;
        }

        if (aligned_embed_dim > d_model) {
            for (int i = d_model; i < aligned_embed_dim; ++i) {
                out_ptr[i] = 0.0f;
            }
        }
    }
}
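/* All forward kernels in this file compute, per token t (D = d_model):
 *     mean_t = (1/D) * sum_i x[t,i]
 *     var_t  = (1/D) * sum_i (x[t,i] - mean_t)^2
 *     rstd_t = 1 / sqrt(var_t + eps)
 *     y[t,i] = gamma[i] * (x[t,i] - mean_t) * rstd_t + beta[i]
 * rstd is taken through a double-precision sqrt and rounded back to float,
 * so the scalar and SIMD paths apply the same rounding at that step. */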
#if defined(__AVX512F__)

static void layernorm_forward_rolled_slice_avx512(
    const float *__restrict input_slice_base,
    const float *__restrict gamma,
    const float *__restrict beta,
    float *__restrict output_slice_base,
    float *__restrict mean_cache_slice,
    float *__restrict rstd_cache_slice,
    int num_tokens_in_slice,
    int d_model,
    int aligned_embed_dim,
    float eps)
{
    for (int t = 0; t < num_tokens_in_slice; ++t) {
        const float *in_ptr_token = input_slice_base + t * aligned_embed_dim;
        float *out_ptr_token = output_slice_base + t * aligned_embed_dim;

        /* Pass 1: mean. */
        __m512 acc_sum_vec = _mm512_setzero_ps();
        int j = 0;
        for (; j <= d_model - 16; j += 16) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            __m512 v = _mm512_load_ps(in_ptr_token + j);
            acc_sum_vec = _mm512_add_ps(acc_sum_vec, v);
        }
        float mean = _mm512_reduce_add_ps(acc_sum_vec);
        for (; j < d_model; ++j) {
            mean += in_ptr_token[j];
        }
        mean /= (float)d_model;
        __m512 mean_vec = _mm512_set1_ps(mean);

        /* Pass 2: variance. */
        __m512 acc_var_vec = _mm512_setzero_ps();
        j = 0;
        for (; j <= d_model - 16; j += 16) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            __m512 v = _mm512_load_ps(in_ptr_token + j);
            __m512 diff = _mm512_sub_ps(v, mean_vec);
            acc_var_vec = _mm512_fmadd_ps(diff, diff, acc_var_vec);
        }
        float var = _mm512_reduce_add_ps(acc_var_vec);
        for (; j < d_model; ++j) {
            float diff = in_ptr_token[j] - mean;
            var += diff * diff;
        }
        var = var / (float)d_model + eps;
        double var_double = (double)var;
        float inv_std = (float)(1.0 / sqrt(var_double));
        __m512 inv_std_vec = _mm512_set1_ps(inv_std);

        if (mean_cache_slice) {
            mean_cache_slice[t] = mean;
        }
        if (rstd_cache_slice) {
            rstd_cache_slice[t] = inv_std;
        }

        /* Pass 3: normalize, scale, and shift. */
        j = 0;
        for (; j <= d_model - 16; j += 16) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(gamma + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(beta + j + 128), _MM_HINT_T0);

            __m512 v = _mm512_load_ps(in_ptr_token + j);
            __m512 g = _mm512_load_ps(gamma + j);
            __m512 b = _mm512_load_ps(beta + j);

            __m512 n = _mm512_mul_ps(_mm512_sub_ps(v, mean_vec), inv_std_vec);
            __m512 o = _mm512_fmadd_ps(n, g, b);

            _mm512_store_ps(out_ptr_token + j, o);
        }
        for (; j < d_model; ++j) {
            float normed = (in_ptr_token[j] - mean) * inv_std;
            out_ptr_token[j] = normed * gamma[j] + beta[j];
        }

        if (aligned_embed_dim > d_model) {
            zero_layernorm_padding(out_ptr_token, d_model, aligned_embed_dim);
        }
    }
}
#elif defined(__AVX2__) || defined(__AVX__)

static void layernorm_forward_rolled_slice_avx256(
    const float *__restrict input_slice_base,
    const float *__restrict gamma,
    const float *__restrict beta,
    float *__restrict output_slice_base,
    float *__restrict mean_cache_slice,
    float *__restrict rstd_cache_slice,
    int num_tokens_in_slice,
    int d_model,
    int aligned_embed_dim,
    float eps)
{
    for (int t = 0; t < num_tokens_in_slice; ++t) {
        const float *in_ptr_token = input_slice_base + t * aligned_embed_dim;
        float *out_ptr_token = output_slice_base + t * aligned_embed_dim;

        /* Pass 1: mean. */
        __m256 acc_sum_vec = _mm256_setzero_ps();
        int j = 0;
        for (; j <= d_model - 8; j += 8) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            __m256 v = _mm256_load_ps(in_ptr_token + j);
            acc_sum_vec = _mm256_add_ps(acc_sum_vec, v);
        }
        float mean = hsum256_ps(acc_sum_vec);
        for (; j < d_model; ++j) {
            mean += in_ptr_token[j];
        }
        mean /= (float)d_model;
        __m256 mean_vec = _mm256_set1_ps(mean);

        /* Pass 2: variance. */
        __m256 acc_var_vec = _mm256_setzero_ps();
        j = 0;
        for (; j <= d_model - 8; j += 8) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            __m256 v = _mm256_load_ps(in_ptr_token + j);
            __m256 diff = _mm256_sub_ps(v, mean_vec);
#if defined(__FMA__)
            acc_var_vec = _mm256_fmadd_ps(diff, diff, acc_var_vec);
#else
            acc_var_vec = _mm256_add_ps(acc_var_vec, _mm256_mul_ps(diff, diff));
#endif
        }
        float var = hsum256_ps(acc_var_vec);
        for (; j < d_model; ++j) {
            float diff = in_ptr_token[j] - mean;
            var += diff * diff;
        }
        var = var / (float)d_model + eps;
        double var_double = (double)var;
        float inv_std = (float)(1.0 / sqrt(var_double));
        __m256 inv_std_vec = _mm256_set1_ps(inv_std);

        if (mean_cache_slice) {
            mean_cache_slice[t] = mean;
        }
        if (rstd_cache_slice) {
            rstd_cache_slice[t] = inv_std;
        }

        /* Pass 3: normalize, scale, and shift. */
        j = 0;
        for (; j <= d_model - 8; j += 8) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(gamma + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(beta + j + 128), _MM_HINT_T0);

            __m256 v = _mm256_load_ps(in_ptr_token + j);
            __m256 g = _mm256_load_ps(gamma + j);
            __m256 b = _mm256_load_ps(beta + j);

            __m256 n = _mm256_mul_ps(_mm256_sub_ps(v, mean_vec), inv_std_vec);
#if defined(__FMA__)
            __m256 o = _mm256_fmadd_ps(n, g, b);
#else
            __m256 o = _mm256_add_ps(_mm256_mul_ps(n, g), b);
#endif
            _mm256_store_ps(out_ptr_token + j, o);
        }
        for (; j < d_model; ++j) {
            float normed = (in_ptr_token[j] - mean) * inv_std;
            out_ptr_token[j] = normed * gamma[j] + beta[j];
        }

        if (aligned_embed_dim > d_model) {
            zero_layernorm_padding(out_ptr_token, d_model, aligned_embed_dim);
        }
    }
}

#endif
/* Public entry point: normalize a slice of tokens, selecting the widest
 * SIMD implementation available at compile time.  The aligned loads in the
 * SIMD kernels require 64-byte (AVX-512) or 32-byte (AVX) aligned input,
 * output, gamma, and beta, with rows padded to aligned_embed_dim floats. */
void layernorm_forward_rolled_slice(
    const float *__restrict input_slice_base,
    const float *__restrict gamma,
    const float *__restrict beta,
    float *__restrict output_slice_base,
    float *__restrict mean_cache_slice,
    float *__restrict rstd_cache_slice,
    int num_tokens_in_slice,
    int d_model,
    int aligned_embed_dim,
    float eps)
{
#if defined(__AVX512F__)
    layernorm_forward_rolled_slice_avx512(input_slice_base, gamma, beta,
        output_slice_base, mean_cache_slice, rstd_cache_slice,
        num_tokens_in_slice, d_model, aligned_embed_dim, eps);
#elif defined(__AVX2__) || defined(__AVX__)
    layernorm_forward_rolled_slice_avx256(input_slice_base, gamma, beta,
        output_slice_base, mean_cache_slice, rstd_cache_slice,
        num_tokens_in_slice, d_model, aligned_embed_dim, eps);
#else
    layernorm_naive_serial(input_slice_base, gamma, beta,
        output_slice_base, mean_cache_slice, rstd_cache_slice,
        num_tokens_in_slice, d_model, aligned_embed_dim, eps);
#endif
}
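/* The "rolled" kernels above consume one vector per loop iteration and walk
 * rows padded to aligned_embed_dim.  The "unrolled" kernels below instead
 * keep four independent accumulators per pass to hide add/FMA latency and
 * walk densely packed rows of width d_model. */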
#if defined(__AVX512F__)

static void layernorm_forward_unrolled_slice_avx512(
    const float *__restrict input_slice_base,
    const float *__restrict gamma,
    const float *__restrict beta,
    float *__restrict output_slice_base,
    float *__restrict mean_cache_slice,
    float *__restrict rstd_cache_slice,
    int num_tokens_in_slice,
    int d_model,
    float eps)
{
    for (int t = 0; t < num_tokens_in_slice; ++t) {
        const float *in_ptr_token = input_slice_base + t * d_model;
        float *out_ptr_token = output_slice_base + t * d_model;

        /* Pass 1: mean, with four independent accumulators. */
        __m512 acc0 = _mm512_setzero_ps();
        __m512 acc1 = _mm512_setzero_ps();
        __m512 acc2 = _mm512_setzero_ps();
        __m512 acc3 = _mm512_setzero_ps();

        int unroll_factor_floats = 64; /* 4 x 16 floats per iteration */
        int j = 0;
        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);

            __m512 v0 = _mm512_load_ps(in_ptr_token + j);
            __m512 v1 = _mm512_load_ps(in_ptr_token + j + 16);
            __m512 v2 = _mm512_load_ps(in_ptr_token + j + 32);
            __m512 v3 = _mm512_load_ps(in_ptr_token + j + 48);

            acc0 = _mm512_add_ps(acc0, v0);
            acc1 = _mm512_add_ps(acc1, v1);
            acc2 = _mm512_add_ps(acc2, v2);
            acc3 = _mm512_add_ps(acc3, v3);
        }
        __m512 acc_sum = _mm512_add_ps(_mm512_add_ps(acc0, acc1),
                                       _mm512_add_ps(acc2, acc3));
        float mean = _mm512_reduce_add_ps(acc_sum);

        for (; j < d_model; ++j) {
            mean += in_ptr_token[j];
        }
        mean /= (float)d_model;
        __m512 mean_vec = _mm512_set1_ps(mean);

        /* Pass 2: variance. */
        acc0 = _mm512_setzero_ps();
        acc1 = _mm512_setzero_ps();
        acc2 = _mm512_setzero_ps();
        acc3 = _mm512_setzero_ps();

        j = 0;
        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);

            __m512 v0 = _mm512_load_ps(in_ptr_token + j);
            __m512 v1 = _mm512_load_ps(in_ptr_token + j + 16);
            __m512 v2 = _mm512_load_ps(in_ptr_token + j + 32);
            __m512 v3 = _mm512_load_ps(in_ptr_token + j + 48);

            __m512 d0 = _mm512_sub_ps(v0, mean_vec);
            __m512 d1 = _mm512_sub_ps(v1, mean_vec);
            __m512 d2 = _mm512_sub_ps(v2, mean_vec);
            __m512 d3 = _mm512_sub_ps(v3, mean_vec);

            acc0 = _mm512_fmadd_ps(d0, d0, acc0);
            acc1 = _mm512_fmadd_ps(d1, d1, acc1);
            acc2 = _mm512_fmadd_ps(d2, d2, acc2);
            acc3 = _mm512_fmadd_ps(d3, d3, acc3);
        }
        acc_sum = _mm512_add_ps(_mm512_add_ps(acc0, acc1),
                                _mm512_add_ps(acc2, acc3));
        float var = _mm512_reduce_add_ps(acc_sum);

        for (; j < d_model; ++j) {
            float diff = in_ptr_token[j] - mean;
            var += diff * diff;
        }
        var = var / (float)d_model + eps;
        double var_double = (double)var;
        float inv_std = (float)(1.0 / sqrt(var_double));
        __m512 inv_std_vec = _mm512_set1_ps(inv_std);

        if (mean_cache_slice) {
            mean_cache_slice[t] = mean;
        }
        if (rstd_cache_slice) {
            rstd_cache_slice[t] = inv_std;
        }

        /* Pass 3: normalize, scale, and shift. */
        j = 0;
        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(gamma + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(beta + j + 128), _MM_HINT_T0);

            __m512 v0 = _mm512_load_ps(in_ptr_token + j);
            __m512 v1 = _mm512_load_ps(in_ptr_token + j + 16);
            __m512 v2 = _mm512_load_ps(in_ptr_token + j + 32);
            __m512 v3 = _mm512_load_ps(in_ptr_token + j + 48);

            __m512 g0 = _mm512_load_ps(gamma + j);
            __m512 g1 = _mm512_load_ps(gamma + j + 16);
            __m512 g2 = _mm512_load_ps(gamma + j + 32);
            __m512 g3 = _mm512_load_ps(gamma + j + 48);

            __m512 b0 = _mm512_load_ps(beta + j);
            __m512 b1 = _mm512_load_ps(beta + j + 16);
            __m512 b2 = _mm512_load_ps(beta + j + 32);
            __m512 b3 = _mm512_load_ps(beta + j + 48);

            __m512 n0 = _mm512_mul_ps(_mm512_sub_ps(v0, mean_vec), inv_std_vec);
            __m512 n1 = _mm512_mul_ps(_mm512_sub_ps(v1, mean_vec), inv_std_vec);
            __m512 n2 = _mm512_mul_ps(_mm512_sub_ps(v2, mean_vec), inv_std_vec);
            __m512 n3 = _mm512_mul_ps(_mm512_sub_ps(v3, mean_vec), inv_std_vec);

            __m512 o0 = _mm512_fmadd_ps(n0, g0, b0);
            __m512 o1 = _mm512_fmadd_ps(n1, g1, b1);
            __m512 o2 = _mm512_fmadd_ps(n2, g2, b2);
            __m512 o3 = _mm512_fmadd_ps(n3, g3, b3);

            _mm512_store_ps(out_ptr_token + j, o0);
            _mm512_store_ps(out_ptr_token + j + 16, o1);
            _mm512_store_ps(out_ptr_token + j + 32, o2);
            _mm512_store_ps(out_ptr_token + j + 48, o3);
        }
        for (; j < d_model; ++j) {
            float normed = (in_ptr_token[j] - mean) * inv_std;
            out_ptr_token[j] = normed * gamma[j] + beta[j];
        }
    }
}
#elif defined(__AVX2__) || defined(__AVX__)

static void layernorm_forward_unrolled_slice_avx256(
    const float *__restrict input_slice_base,
    const float *__restrict gamma,
    const float *__restrict beta,
    float *__restrict output_slice_base,
    float *__restrict mean_cache_slice,
    float *__restrict rstd_cache_slice,
    int num_tokens_in_slice,
    int d_model,
    float eps)
{
    for (int t = 0; t < num_tokens_in_slice; ++t) {
        const float *in_ptr_token = input_slice_base + t * d_model;
        float *out_ptr_token = output_slice_base + t * d_model;

        /* Pass 1: mean, with four independent accumulators. */
        __m256 acc0 = _mm256_setzero_ps();
        __m256 acc1 = _mm256_setzero_ps();
        __m256 acc2 = _mm256_setzero_ps();
        __m256 acc3 = _mm256_setzero_ps();

        int unroll_factor_floats = 32; /* 4 x 8 floats per iteration */
        int j = 0;
        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);

            __m256 v0 = _mm256_load_ps(in_ptr_token + j);
            __m256 v1 = _mm256_load_ps(in_ptr_token + j + 8);
            __m256 v2 = _mm256_load_ps(in_ptr_token + j + 16);
            __m256 v3 = _mm256_load_ps(in_ptr_token + j + 24);

            acc0 = _mm256_add_ps(acc0, v0);
            acc1 = _mm256_add_ps(acc1, v1);
            acc2 = _mm256_add_ps(acc2, v2);
            acc3 = _mm256_add_ps(acc3, v3);
        }
        __m256 acc_sum = _mm256_add_ps(_mm256_add_ps(acc0, acc1),
                                       _mm256_add_ps(acc2, acc3));
        float mean = hsum256_ps(acc_sum);

        for (; j < d_model; ++j) {
            mean += in_ptr_token[j];
        }
        mean /= (float)d_model;
        __m256 mean_vec = _mm256_set1_ps(mean);

        /* Pass 2: variance. */
        acc0 = _mm256_setzero_ps();
        acc1 = _mm256_setzero_ps();
        acc2 = _mm256_setzero_ps();
        acc3 = _mm256_setzero_ps();

        j = 0;
        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);

            __m256 v0 = _mm256_load_ps(in_ptr_token + j);
            __m256 v1 = _mm256_load_ps(in_ptr_token + j + 8);
            __m256 v2 = _mm256_load_ps(in_ptr_token + j + 16);
            __m256 v3 = _mm256_load_ps(in_ptr_token + j + 24);

            __m256 d0 = _mm256_sub_ps(v0, mean_vec);
            __m256 d1 = _mm256_sub_ps(v1, mean_vec);
            __m256 d2 = _mm256_sub_ps(v2, mean_vec);
            __m256 d3 = _mm256_sub_ps(v3, mean_vec);

#if defined(__FMA__)
            acc0 = _mm256_fmadd_ps(d0, d0, acc0);
            acc1 = _mm256_fmadd_ps(d1, d1, acc1);
            acc2 = _mm256_fmadd_ps(d2, d2, acc2);
            acc3 = _mm256_fmadd_ps(d3, d3, acc3);
#else
            acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(d0, d0));
            acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(d1, d1));
            acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(d2, d2));
            acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(d3, d3));
#endif
        }
        acc_sum = _mm256_add_ps(_mm256_add_ps(acc0, acc1),
                                _mm256_add_ps(acc2, acc3));
        float var = hsum256_ps(acc_sum);

        for (; j < d_model; ++j) {
            float diff = in_ptr_token[j] - mean;
            var += diff * diff;
        }
        var = var / (float)d_model + eps;
        double var_double = (double)var;
        float inv_std = (float)(1.0 / sqrt(var_double));
        __m256 inv_std_vec = _mm256_set1_ps(inv_std);

        if (mean_cache_slice) {
            mean_cache_slice[t] = mean;
        }
        if (rstd_cache_slice) {
            rstd_cache_slice[t] = inv_std;
        }

        /* Pass 3: normalize, scale, and shift. */
        j = 0;
        for (; j <= d_model - unroll_factor_floats; j += unroll_factor_floats) {
            _mm_prefetch((const char *)(in_ptr_token + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(gamma + j + 128), _MM_HINT_T0);
            _mm_prefetch((const char *)(beta + j + 128), _MM_HINT_T0);

            __m256 v0 = _mm256_load_ps(in_ptr_token + j);
            __m256 v1 = _mm256_load_ps(in_ptr_token + j + 8);
            __m256 v2 = _mm256_load_ps(in_ptr_token + j + 16);
            __m256 v3 = _mm256_load_ps(in_ptr_token + j + 24);

            __m256 g0 = _mm256_load_ps(gamma + j);
            __m256 g1 = _mm256_load_ps(gamma + j + 8);
            __m256 g2 = _mm256_load_ps(gamma + j + 16);
            __m256 g3 = _mm256_load_ps(gamma + j + 24);

            __m256 b0 = _mm256_load_ps(beta + j);
            __m256 b1 = _mm256_load_ps(beta + j + 8);
            __m256 b2 = _mm256_load_ps(beta + j + 16);
            __m256 b3 = _mm256_load_ps(beta + j + 24);

            __m256 n0 = _mm256_mul_ps(_mm256_sub_ps(v0, mean_vec), inv_std_vec);
            __m256 n1 = _mm256_mul_ps(_mm256_sub_ps(v1, mean_vec), inv_std_vec);
            __m256 n2 = _mm256_mul_ps(_mm256_sub_ps(v2, mean_vec), inv_std_vec);
            __m256 n3 = _mm256_mul_ps(_mm256_sub_ps(v3, mean_vec), inv_std_vec);

#if defined(__FMA__)
            __m256 o0 = _mm256_fmadd_ps(n0, g0, b0);
            __m256 o1 = _mm256_fmadd_ps(n1, g1, b1);
            __m256 o2 = _mm256_fmadd_ps(n2, g2, b2);
            __m256 o3 = _mm256_fmadd_ps(n3, g3, b3);
#else
            __m256 o0 = _mm256_add_ps(_mm256_mul_ps(n0, g0), b0);
            __m256 o1 = _mm256_add_ps(_mm256_mul_ps(n1, g1), b1);
            __m256 o2 = _mm256_add_ps(_mm256_mul_ps(n2, g2), b2);
            __m256 o3 = _mm256_add_ps(_mm256_mul_ps(n3, g3), b3);
#endif

            _mm256_store_ps(out_ptr_token + j, o0);
            _mm256_store_ps(out_ptr_token + j + 8, o1);
            _mm256_store_ps(out_ptr_token + j + 16, o2);
            _mm256_store_ps(out_ptr_token + j + 24, o3);
        }
        for (; j < d_model; ++j) {
            float normed = (in_ptr_token[j] - mean) * inv_std;
            out_ptr_token[j] = normed * gamma[j] + beta[j];
        }
    }
}
#else

/* No AVX at compile time: delegate to the matched-precision scalar kernel. */
static void layernorm_forward_unrolled_slice_scalar(
    const float *__restrict input_slice_base,
    const float *__restrict gamma,
    const float *__restrict beta,
    float *__restrict output_slice_base,
    float *__restrict mean_cache_slice,
    float *__restrict rstd_cache_slice,
    int num_tokens_in_slice,
    int d_model,
    float eps)
{
    layernorm_naive_serial_matched_precision(input_slice_base, gamma, beta,
        output_slice_base, mean_cache_slice, rstd_cache_slice,
        num_tokens_in_slice, d_model, eps);
}

#endif
/* Public entry point for the unrolled variant; same dispatch pattern as
 * layernorm_forward_rolled_slice. */
void layernorm_forward_unrolled_slice(
    const float *__restrict input_slice_base,
    const float *__restrict gamma,
    const float *__restrict beta,
    float *__restrict output_slice_base,
    float *__restrict mean_cache_slice,
    float *__restrict rstd_cache_slice,
    int num_tokens_in_slice,
    int d_model,
    float eps)
{
#if defined(__AVX512F__)
    layernorm_forward_unrolled_slice_avx512(input_slice_base, gamma, beta,
        output_slice_base, mean_cache_slice, rstd_cache_slice,
        num_tokens_in_slice, d_model, eps);
#elif defined(__AVX2__) || defined(__AVX__)
    layernorm_forward_unrolled_slice_avx256(input_slice_base, gamma, beta,
        output_slice_base, mean_cache_slice, rstd_cache_slice,
        num_tokens_in_slice, d_model, eps);
#else
    layernorm_forward_unrolled_slice_scalar(input_slice_base, gamma, beta,
        output_slice_base, mean_cache_slice, rstd_cache_slice,
        num_tokens_in_slice, d_model, eps);
#endif
}
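/* Caveat: the unrolled kernels index rows as base + t * d_model and use
 * aligned loads, so as written they assume d_model itself is a multiple of
 * the vector width (16 floats for AVX-512, 8 for AVX) on top of the base
 * pointers being suitably aligned; otherwise every row after the first
 * would fault on _mm512_load_ps / _mm256_load_ps. */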
/* Scalar reference with rows packed contiguously at width d_model (no
 * padding).  "Matched precision" because it uses the same arithmetic as
 * the SIMD paths, including the double-precision sqrt for rstd. */
void layernorm_naive_serial_matched_precision(
    const float *input, const float *gamma, const float *beta,
    float *output, float *mean_cache, float *rstd_cache,
    int tokens, int d_model, float eps)
{
    for (int t = 0; t < tokens; ++t) {
        const float *in_ptr = input + t * d_model;
        float *out_ptr = output + t * d_model;

        float sum_val = 0.0f;
        for (int i = 0; i < d_model; ++i) {
            sum_val += in_ptr[i];
        }
        float mean = sum_val / (float)d_model;

        float sum_sq_diff = 0.0f;
        for (int i = 0; i < d_model; ++i) {
            float diff = in_ptr[i] - mean;
            sum_sq_diff += diff * diff;
        }
        float variance = sum_sq_diff / (float)d_model + eps;

        double var_double = (double)variance;
        float inv_std = (float)(1.0 / sqrt(var_double));

        for (int i = 0; i < d_model; ++i) {
            float normalized_val = (in_ptr[i] - mean) * inv_std;
            out_ptr[i] = normalized_val * gamma[i] + beta[i];
        }

        if (mean_cache) {
            mean_cache[t] = mean;
        }
        if (rstd_cache) {
            rstd_cache[t] = inv_std;
        }
    }
}
/* Backward pass.  Given dL/dy, the cached per-token mean and rstd, and
 * x_hat = (x - mean) * rstd, this computes
 *     dL/dx = (rstd / D) * (D * dL/dy * gamma
 *                           - sum(dL/dy * gamma)
 *                           - x_hat * sum(dL/dy * gamma * x_hat))
 * along with the parameter gradients dL/dgamma and dL/dbeta, which are
 * accumulated (+=) so callers can sum across slices. */
void layernorm_backward_kernel(
    const float *d_output, const float *input, const float *gamma,
    const float *mean, const float *rstd,
    float *d_input, float *d_gamma, float *d_beta,
    int tokens, int d_model, int aligned_embed_dim)
{
    int T = tokens;
    int D = d_model;
    int aligned_D = aligned_embed_dim;

    /* Per-token input gradients. */
    for (int t = 0; t < T; ++t) {
        float mean_t = mean[t];
        float rstd_t = rstd[t];

        float d_y_gamma_sum = 0.0f;
        float d_y_gamma_xhat_sum = 0.0f;

        /* First sweep: the two reductions over the row. */
        for (int d = 0; d < D; ++d) {
            float x = input[t * aligned_D + d];
            float x_hat = (x - mean_t) * rstd_t;
            float d_y = d_output[t * aligned_D + d];
            float d_y_gamma = d_y * gamma[d];

            d_y_gamma_sum += d_y_gamma;
            d_y_gamma_xhat_sum += d_y_gamma * x_hat;
        }

        /* Second sweep: apply the formula. */
        float scale = rstd_t / (float)D;
        for (int d = 0; d < D; ++d) {
            float x = input[t * aligned_D + d];
            float x_hat = (x - mean_t) * rstd_t;
            float d_y = d_output[t * aligned_D + d];

            d_input[t * aligned_D + d] =
                scale * ((float)D * d_y * gamma[d] - d_y_gamma_sum -
                         x_hat * d_y_gamma_xhat_sum);
        }

        /* Keep the padding region of d_input zeroed. */
        for (int d = D; d < aligned_D; ++d) {
            d_input[t * aligned_D + d] = 0.0f;
        }
    }

    /* Parameter gradients, reduced over all tokens. */
    for (int d = 0; d < D; ++d) {
        float gamma_grad = 0.0f;
        float beta_grad = 0.0f;

        for (int t = 0; t < T; ++t) {
            float x = input[t * aligned_D + d];
            float x_hat = (x - mean[t]) * rstd[t];
            float d_y = d_output[t * aligned_D + d];

            gamma_grad += d_y * x_hat;
            beta_grad += d_y;
        }

        d_gamma[d] += gamma_grad;
        d_beta[d] += beta_grad;
    }
}
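/* ------------------------------------------------------------------ */
/* Illustrative harness: a minimal sketch of how these kernels might be
 * driven.  Not part of the kernel API; the LAYERNORM_DEMO_MAIN guard, the
 * sizes, and the use of C11 aligned_alloc are assumptions made for this
 * example only. */
#ifdef LAYERNORM_DEMO_MAIN
#include <stdio.h>
#include <stdlib.h>

int main(void)
{
    const int tokens = 4;    /* hypothetical token count */
    const int d_model = 768; /* hypothetical width, multiple of 16 floats */
    const int aligned = 768; /* no padding needed at this width */
    const float eps = 1e-5f;

    /* 64-byte alignment satisfies _mm512_load_ps and _mm256_load_ps. */
    float *x = aligned_alloc(64, sizeof(float) * tokens * aligned);
    float *y = aligned_alloc(64, sizeof(float) * tokens * aligned);
    float *gamma = aligned_alloc(64, sizeof(float) * aligned);
    float *beta = aligned_alloc(64, sizeof(float) * aligned);
    float *mean = malloc(sizeof(float) * tokens);
    float *rstd = malloc(sizeof(float) * tokens);
    if (!x || !y || !gamma || !beta || !mean || !rstd) return 1;

    for (int i = 0; i < tokens * aligned; ++i) x[i] = (float)(i % 7) - 3.0f;
    for (int i = 0; i < aligned; ++i) { gamma[i] = 1.0f; beta[i] = 0.0f; }

    layernorm_forward_rolled_slice(x, gamma, beta, y, mean, rstd,
                                   tokens, d_model, aligned, eps);

    printf("token 0: mean=%f rstd=%f y[0]=%f\n", mean[0], rstd[0], y[0]);

    free(x); free(y); free(gamma); free(beta); free(mean); free(rstd);
    return 0;
}
#endif /* LAYERNORM_DEMO_MAIN */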