23 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
24 #include <immintrin.h>
30 #if defined(__AVX512F__)
/*
 * Fast vectorized e^x for 16 floats (AVX-512F).
 *
 * Range reduction: x = ti*ln2 + rx with ti = round(x * log2(e)), so
 * e^x = 2^ti * e^rx and rx lies in [-ln2/2, ln2/2].  ln2 is split into a
 * high part (c1) and a low part (c2) for an accurate Cody-Waite reduction.
 * e^rx is approximated by its degree-4 Taylor polynomial and 2^ti is built
 * by writing (ti + 127) directly into the float exponent field.
 *
 * Input is clamped to [-88, 88] so the exponent construction cannot
 * overflow.  Accuracy ~3e-5 relative — adequate for activation functions.
 */
static inline __m512 exp512_fast(__m512 x) {
    x = _mm512_max_ps(x, _mm512_set1_ps(-88.0f));
    x = _mm512_min_ps(x, _mm512_set1_ps(88.0f));

    const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
    const __m512 c1 = _mm512_set1_ps(0.693359375f);    /* ln2 high part */
    const __m512 c2 = _mm512_set1_ps(-2.12194440e-4f); /* ln2 low part  */

    __m512 t = _mm512_mul_ps(x, log2e);
    __m512 ti = _mm512_roundscale_ps(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

    /* rx = x - ti*ln2, subtracted in two steps to preserve precision. */
    __m512 rx = _mm512_sub_ps(x, _mm512_mul_ps(ti, c1));
    rx = _mm512_sub_ps(rx, _mm512_mul_ps(ti, c2));

    /* Taylor coefficients of e^r: 1, 1, 1/2, 1/6, 1/24.
     * BUGFIX: the previous coefficients (ln2^k / k!) are the expansion of
     * 2^r, but rx is a natural-log remainder; the mismatch produced
     * 2^ti * 2^rx instead of 2^ti * e^rx — up to ~10% relative error. */
    const __m512 p0 = _mm512_set1_ps(1.0f);
    const __m512 p1 = _mm512_set1_ps(1.0f);
    const __m512 p2 = _mm512_set1_ps(0.5f);
    const __m512 p3 = _mm512_set1_ps(0.16666666666f);
    const __m512 p4 = _mm512_set1_ps(0.041666666666f);

    __m512 poly = _mm512_fmadd_ps(p4, rx, p3);
    poly = _mm512_fmadd_ps(poly, rx, p2);
    poly = _mm512_fmadd_ps(poly, rx, p1);
    poly = _mm512_fmadd_ps(poly, rx, p0);

    /* scale = 2^ti via exponent-field construction. */
    __m512i ti_int = _mm512_cvtps_epi32(ti);
    ti_int = _mm512_add_epi32(ti_int, _mm512_set1_epi32(127));
    ti_int = _mm512_slli_epi32(ti_int, 23);
    __m512 scale = _mm512_castsi512_ps(ti_int);

    return _mm512_mul_ps(poly, scale);
}
67 static inline __m512 tanh512_fast(__m512 x) {
68 __m512 two = _mm512_set1_ps(2.0f);
69 __m512 one = _mm512_set1_ps(1.0f);
70 __m512 exp2x = exp512_fast(_mm512_mul_ps(two, x));
71 __m512 num = _mm512_sub_ps(exp2x, one);
72 __m512 den = _mm512_add_ps(exp2x, one);
73 return _mm512_div_ps(num, den);
/*
 * Fast vectorized e^x for 8 floats (AVX2 + FMA).  Mirrors exp512_fast:
 * Cody-Waite range reduction x = ti*ln2 + rx, degree-4 Taylor polynomial
 * for e^rx, and 2^ti built in the float exponent field.  Input clamped to
 * [-88, 88] so the exponent construction cannot overflow.
 */
static inline __m256 exp256_fast(__m256 x) {
    x = _mm256_max_ps(x, _mm256_set1_ps(-88.0f));
    x = _mm256_min_ps(x, _mm256_set1_ps(88.0f));

    const __m256 log2e = _mm256_set1_ps(1.4426950408889634f);
    const __m256 c1 = _mm256_set1_ps(0.693359375f);    /* ln2 high part */
    const __m256 c2 = _mm256_set1_ps(-2.12194440e-4f); /* ln2 low part  */

    __m256 t = _mm256_mul_ps(x, log2e);
    __m256 ti = _mm256_round_ps(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

    /* rx = x - ti*ln2, subtracted in two steps to preserve precision. */
    __m256 rx = _mm256_sub_ps(x, _mm256_mul_ps(ti, c1));
    rx = _mm256_sub_ps(rx, _mm256_mul_ps(ti, c2));

    /* Taylor coefficients of e^r: 1, 1, 1/2, 1/6, 1/24.
     * BUGFIX: the previous coefficients (ln2^k / k!) expand 2^r, not e^r,
     * yielding 2^ti * 2^rx instead of e^x — up to ~10% relative error. */
    const __m256 p0 = _mm256_set1_ps(1.0f);
    const __m256 p1 = _mm256_set1_ps(1.0f);
    const __m256 p2 = _mm256_set1_ps(0.5f);
    const __m256 p3 = _mm256_set1_ps(0.16666666666f);
    const __m256 p4 = _mm256_set1_ps(0.041666666666f);

    __m256 poly = _mm256_fmadd_ps(p4, rx, p3);
    poly = _mm256_fmadd_ps(poly, rx, p2);
    poly = _mm256_fmadd_ps(poly, rx, p1);
    poly = _mm256_fmadd_ps(poly, rx, p0);

    /* scale = 2^ti via exponent-field construction. */
    __m256i ti_int = _mm256_cvtps_epi32(ti);
    ti_int = _mm256_add_epi32(ti_int, _mm256_set1_epi32(127));
    ti_int = _mm256_slli_epi32(ti_int, 23);
    __m256 scale = _mm256_castsi256_ps(ti_int);

    return _mm256_mul_ps(poly, scale);
}
111 static inline __m256 tanh256_fast(__m256 x) {
112 __m256 two = _mm256_set1_ps(2.0f);
113 __m256 one = _mm256_set1_ps(1.0f);
114 __m256 exp2x = exp256_fast(_mm256_mul_ps(two, x));
115 __m256 num = _mm256_sub_ps(exp2x, one);
116 __m256 den = _mm256_add_ps(exp2x, one);
117 return _mm256_div_ps(num, den);
    /* Interior of gelu_fast_inplace (function opening not visible in this
     * excerpt): in-place tanh-approximation GELU,
     *   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
     * dispatched to AVX-512 / AVX2 / AVX vector paths with scalar remainders.
     * NOTE(review): this excerpt is missing several lines — the `size_t i`
     * declaration, remainder-loop headers, `float x = data[i];`, the aligned
     * scratch-array declarations used by the AVX path, closing braces, and
     * the `#else` before the portable fallback.  Verify against the full
     * file before editing. */
    const float sqrt_2_over_pi = 0.7978845608f; /* sqrt(2/pi) */
    const float coeff = 0.044715f;              /* cubic term of the tanh GELU */

#if defined(__AVX512F__)
    const __m512 sqrt_2_pi_vec = _mm512_set1_ps(sqrt_2_over_pi);
    const __m512 coeff_vec = _mm512_set1_ps(coeff);
    const __m512 half_vec = _mm512_set1_ps(0.5f);
    const __m512 one_vec = _mm512_set1_ps(1.0f);

    /* 16 floats per iteration. */
    for (; i + 16 <= n; i += 16) {
        __m512 x = _mm512_loadu_ps(&data[i]);
        __m512 x2 = _mm512_mul_ps(x, x);
        __m512 x3 = _mm512_mul_ps(x2, x);

        /* inner = sqrt(2/pi) * (x + coeff * x^3) */
        __m512 inner = _mm512_fmadd_ps(coeff_vec, x3, x);
        inner = _mm512_mul_ps(sqrt_2_pi_vec, inner);

        __m512 tanh_val = tanh512_fast(inner);
        __m512 one_plus_tanh = _mm512_add_ps(one_vec, tanh_val);
        __m512 result = _mm512_mul_ps(half_vec, _mm512_mul_ps(x, one_plus_tanh));

        _mm512_storeu_ps(&data[i], result);

    /* Scalar remainder (loop header and `float x = data[i];` not shown). */
    float x3 = x * x * x;
    float inner = sqrt_2_over_pi * (x + coeff * x3);
    data[i] = 0.5f * x * (1.0f + tanhf(inner));

#elif defined(__AVX2__)
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    /* 8 floats per iteration (FMA + fast vector tanh available). */
    for (; i + 8 <= n; i += 8) {
        __m256 x = _mm256_loadu_ps(&data[i]);
        __m256 x2 = _mm256_mul_ps(x, x);
        __m256 x3 = _mm256_mul_ps(x2, x);

        __m256 inner = _mm256_fmadd_ps(coeff_vec, x3, x);
        inner = _mm256_mul_ps(sqrt_2_pi_vec, inner);

        __m256 tanh_val = tanh256_fast(inner);
        __m256 one_plus_tanh = _mm256_add_ps(one_vec, tanh_val);
        __m256 result = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, one_plus_tanh));

        _mm256_storeu_ps(&data[i], result);

    /* Scalar remainder. */
    float x3 = x * x * x;
    float inner = sqrt_2_over_pi * (x + coeff * x3);
    data[i] = 0.5f * x * (1.0f + tanhf(inner));

#elif defined(__AVX__)
    /* AVX without FMA: vector arithmetic, but tanh goes through scalar
     * tanhf via scratch arrays (declarations not visible here). */
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    for (; i + 8 <= n; i += 8) {
        __m256 x = _mm256_loadu_ps(&data[i]);
        __m256 x2 = _mm256_mul_ps(x, x);
        __m256 x3 = _mm256_mul_ps(x2, x);

        __m256 coeff_x3 = _mm256_mul_ps(coeff_vec, x3);
        __m256 inner = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(x, coeff_x3));

        /* NOTE(review): aligned store assumes inner_arr is 32-byte
         * aligned — confirm its declaration in the full file. */
        _mm256_store_ps(inner_arr, inner);
        for (int j = 0; j < 8; ++j) {
            tanh_arr[j] = tanhf(inner_arr[j]);
        __m256 tanh_val = _mm256_load_ps(tanh_arr);

        __m256 one_plus_tanh = _mm256_add_ps(one_vec, tanh_val);
        __m256 result = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, one_plus_tanh));

        _mm256_storeu_ps(&data[i], result);

    /* Scalar remainder. */
    float x3 = x * x * x;
    float inner = sqrt_2_over_pi * (x + coeff * x3);
    data[i] = 0.5f * x * (1.0f + tanhf(inner));

    /* Portable scalar fallback (no AVX; the `#else` is not shown here). */
    for (size_t i = 0; i < n; ++i) {
        float x3 = x * x * x;
        float inner = sqrt_2_over_pi * (x + coeff * x3);
        data[i] = 0.5f * x * (1.0f + tanhf(inner));
    /* Interior of gelu_backward_fast (excerpt starts mid-parameter-list):
     * computes d_input[i] = d_output[i] * gelu'(input[i]) using the tanh
     * approximation, where
     *   g(x)     = sqrt(2/pi) * (x + 0.044715 * x^3)
     *   g'(x)    = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
     *   gelu'(x) = 0.5 * (1 + tanh g) + 0.5 * x * sech^2(g) * g'(x)
     * Vector paths mirror the forward pass.  NOTE(review): loop headers,
     * `float x = input[i];`, `float x2 = x * x;`, scratch-array
     * declarations, closing braces and the `#else` are missing from this
     * excerpt — verify against the full file. */
    const float *d_output,
    const float sqrt_2_over_pi = 0.7978845608f; /* sqrt(2/pi) */
    const float coeff = 0.044715f;

#if defined(__AVX512F__)
    const __m512 sqrt_2_pi_vec = _mm512_set1_ps(sqrt_2_over_pi);
    const __m512 coeff_vec = _mm512_set1_ps(coeff);
    const __m512 coeff3_vec = _mm512_set1_ps(3.0f * coeff); /* for g'(x) */
    const __m512 half_vec = _mm512_set1_ps(0.5f);
    const __m512 one_vec = _mm512_set1_ps(1.0f);

    for (; i + 16 <= n; i += 16) {
        __m512 x = _mm512_loadu_ps(&input[i]);
        __m512 dy = _mm512_loadu_ps(&d_output[i]);

        __m512 x2 = _mm512_mul_ps(x, x);
        __m512 x3 = _mm512_mul_ps(x2, x);

        /* g = sqrt(2/pi) * (x + coeff * x^3) */
        __m512 g = _mm512_fmadd_ps(coeff_vec, x3, x);
        g = _mm512_mul_ps(sqrt_2_pi_vec, g);

        __m512 tanh_g = tanh512_fast(g);

        /* g' = sqrt(2/pi) * (1 + 3*coeff*x^2) */
        __m512 g_prime = _mm512_fmadd_ps(coeff3_vec, x2, one_vec);
        g_prime = _mm512_mul_ps(sqrt_2_pi_vec, g_prime);

        /* sech^2(g) = 1 - tanh^2(g), via fused negate-multiply-add. */
        __m512 sech2_g = _mm512_fnmadd_ps(tanh_g, tanh_g, one_vec);

        __m512 term1 = _mm512_mul_ps(half_vec, _mm512_add_ps(one_vec, tanh_g));
        __m512 term2 = _mm512_mul_ps(half_vec, _mm512_mul_ps(x, _mm512_mul_ps(sech2_g, g_prime)));
        __m512 gelu_deriv = _mm512_add_ps(term1, term2);

        __m512 result = _mm512_mul_ps(dy, gelu_deriv);
        _mm512_storeu_ps(&d_input[i], result);

    /* Scalar remainder (loop header and x/x2 declarations not shown). */
    float x3 = x * x * x;
    float g = sqrt_2_over_pi * (x + coeff * x3);
    float tanh_g = tanhf(g);

    float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
    float sech2_g = 1.0f - tanh_g * tanh_g;
    float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
    d_input[i] = d_output[i] * gelu_derivative;

#elif defined(__AVX2__)
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 coeff3_vec = _mm256_set1_ps(3.0f * coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    for (; i + 8 <= n; i += 8) {
        __m256 x = _mm256_loadu_ps(&input[i]);
        __m256 dy = _mm256_loadu_ps(&d_output[i]);

        __m256 x2 = _mm256_mul_ps(x, x);
        __m256 x3 = _mm256_mul_ps(x2, x);

        __m256 g = _mm256_fmadd_ps(coeff_vec, x3, x);
        g = _mm256_mul_ps(sqrt_2_pi_vec, g);

        __m256 tanh_g = tanh256_fast(g);

        __m256 g_prime = _mm256_fmadd_ps(coeff3_vec, x2, one_vec);
        g_prime = _mm256_mul_ps(sqrt_2_pi_vec, g_prime);

        __m256 sech2_g = _mm256_fnmadd_ps(tanh_g, tanh_g, one_vec);

        __m256 term1 = _mm256_mul_ps(half_vec, _mm256_add_ps(one_vec, tanh_g));
        __m256 term2 = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, _mm256_mul_ps(sech2_g, g_prime)));
        __m256 gelu_deriv = _mm256_add_ps(term1, term2);

        __m256 result = _mm256_mul_ps(dy, gelu_deriv);
        _mm256_storeu_ps(&d_input[i], result);

    /* Scalar remainder. */
    float x3 = x * x * x;
    float g = sqrt_2_over_pi * (x + coeff * x3);
    float tanh_g = tanhf(g);

    float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
    float sech2_g = 1.0f - tanh_g * tanh_g;
    float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
    d_input[i] = d_output[i] * gelu_derivative;

#elif defined(__AVX__)
    /* AVX without FMA: tanh through scalar tanhf via scratch arrays
     * (g_arr / tanh_arr declarations not visible here). */
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 coeff3_vec = _mm256_set1_ps(3.0f * coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    for (; i + 8 <= n; i += 8) {
        __m256 x = _mm256_loadu_ps(&input[i]);
        __m256 dy = _mm256_loadu_ps(&d_output[i]);

        __m256 x2 = _mm256_mul_ps(x, x);
        __m256 x3 = _mm256_mul_ps(x2, x);

        __m256 coeff_x3 = _mm256_mul_ps(coeff_vec, x3);
        __m256 g = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(x, coeff_x3));

        /* NOTE(review): aligned store assumes g_arr is 32-byte aligned. */
        _mm256_store_ps(g_arr, g);
        for (int j = 0; j < 8; ++j) {
            tanh_arr[j] = tanhf(g_arr[j]);
        __m256 tanh_g = _mm256_load_ps(tanh_arr);

        __m256 coeff3_x2 = _mm256_mul_ps(coeff3_vec, x2);
        __m256 g_prime = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(one_vec, coeff3_x2));

        __m256 tanh_g_sq = _mm256_mul_ps(tanh_g, tanh_g);
        __m256 sech2_g = _mm256_sub_ps(one_vec, tanh_g_sq);

        __m256 term1 = _mm256_mul_ps(half_vec, _mm256_add_ps(one_vec, tanh_g));
        __m256 term2 = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, _mm256_mul_ps(sech2_g, g_prime)));
        __m256 gelu_deriv = _mm256_add_ps(term1, term2);

        __m256 result = _mm256_mul_ps(dy, gelu_deriv);
        _mm256_storeu_ps(&d_input[i], result);

    /* Scalar remainder. */
    float x3 = x * x * x;
    float g = sqrt_2_over_pi * (x + coeff * x3);
    float tanh_g = tanhf(g);

    float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
    float sech2_g = 1.0f - tanh_g * tanh_g;
    float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
    d_input[i] = d_output[i] * gelu_derivative;

    /* Portable scalar fallback (the `#else` is not shown here). */
    for (size_t i = 0; i < n; ++i) {
        float x3 = x * x * x;
        float g = sqrt_2_over_pi * (x + coeff * x3);
        float tanh_g = tanhf(g);

        float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);

        float sech2_g = 1.0f - tanh_g * tanh_g;
        float gelu_derivative =
            0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;

        d_input[i] = d_output[i] * gelu_derivative;
    /* Interior of gelu_exact_inplace: scalar-only GELU using libm tanhf.
     * Note it still uses the tanh *approximation* formula (same as the
     * fast paths), just without vectorization or the fast exp.
     * NOTE(review): `float x = data[i];` inside the loop is missing from
     * this excerpt. */
    const float sqrt_2_over_pi = 0.7978845608f; /* sqrt(2/pi) */
    const float coeff = 0.044715f;

    for (size_t i = 0; i < n; ++i) {
        float x3 = x * x * x;
        float inner = sqrt_2_over_pi * (x + coeff * x3);
        data[i] = 0.5f * x * (1.0f + tanhf(inner));
    /* Interior of gelu_backward_exact (excerpt starts mid-parameter-list):
     * scalar backward pass for the tanh-approximation GELU,
     *   d_input[i] = d_output[i] * [0.5(1 + tanh g) + 0.5 x sech^2(g) g'(x)].
     * NOTE(review): `float x = input[i];` and `float x2 = x * x;` inside the
     * loop are missing from this excerpt. */
    const float *d_output,
    const float sqrt_2_over_pi = 0.7978845608f; /* sqrt(2/pi) */
    const float coeff = 0.044715f;

    for (size_t i = 0; i < n; ++i) {
        float x3 = x * x * x;
        float g = sqrt_2_over_pi * (x + coeff * x3);
        float tanh_g = tanhf(g);

        float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
        float sech2_g = 1.0f - tanh_g * tanh_g;
        float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
        d_input[i] = d_output[i] * gelu_derivative;
    /* Interior of a sigmoid-approximation GELU backward pass (function name
     * not visible in this excerpt; excerpt starts mid-parameter-list).
     * Uses gelu(x) ~= x * sigmoid(beta * x) with beta = 1.702, whose
     * derivative is s * (1 + beta * x * (1 - s)) where s = sigmoid(beta*x).
     * NOTE(review): loop headers, `float x = input[i];`, scratch-array
     * declarations, closing braces and the scalar `#else` fallback are
     * missing from this excerpt. */
    const float *d_output,
    const float beta = 1.702f; /* sigmoid-GELU steepness constant */

#if defined(__AVX512F__)
    const __m512 beta_vec = _mm512_set1_ps(beta);
    const __m512 one_vec = _mm512_set1_ps(1.0f);
    const __m512 neg_beta_vec = _mm512_set1_ps(-beta);

    for (; i + 16 <= n; i += 16) {
        __m512 x = _mm512_loadu_ps(&input[i]);
        __m512 dy = _mm512_loadu_ps(&d_output[i]);

        /* s = 1 / (1 + e^{-beta x}) */
        __m512 neg_beta_x = _mm512_mul_ps(neg_beta_vec, x);
        __m512 exp_neg = exp512_fast(neg_beta_x);
        __m512 s = _mm512_div_ps(one_vec, _mm512_add_ps(one_vec, exp_neg));

        /* gelu' = s * (1 + beta * x * (1 - s)) */
        __m512 one_minus_s = _mm512_sub_ps(one_vec, s);
        __m512 inner = _mm512_fmadd_ps(_mm512_mul_ps(x, one_minus_s), beta_vec, one_vec);
        __m512 gelu_deriv = _mm512_mul_ps(s, inner);

        __m512 result = _mm512_mul_ps(dy, gelu_deriv);
        _mm512_storeu_ps(&d_input[i], result);

    /* Scalar remainder (loop header and `float x = input[i];` not shown). */
    float s = 1.0f / (1.0f + expf(-beta * x));
    float gelu_derivative = s * (1.0f + x * (1.0f - s) * beta);
    d_input[i] = d_output[i] * gelu_derivative;

#elif defined(__AVX2__)
    const __m256 beta_vec = _mm256_set1_ps(beta);
    const __m256 one_vec = _mm256_set1_ps(1.0f);
    const __m256 neg_beta_vec = _mm256_set1_ps(-beta);

    for (; i + 8 <= n; i += 8) {
        __m256 x = _mm256_loadu_ps(&input[i]);
        __m256 dy = _mm256_loadu_ps(&d_output[i]);

        __m256 neg_beta_x = _mm256_mul_ps(neg_beta_vec, x);
        __m256 exp_neg = exp256_fast(neg_beta_x);
        __m256 s = _mm256_div_ps(one_vec, _mm256_add_ps(one_vec, exp_neg));

        __m256 one_minus_s = _mm256_sub_ps(one_vec, s);
        __m256 inner = _mm256_fmadd_ps(_mm256_mul_ps(x, one_minus_s), beta_vec, one_vec);
        __m256 gelu_deriv = _mm256_mul_ps(s, inner);

        __m256 result = _mm256_mul_ps(dy, gelu_deriv);
        _mm256_storeu_ps(&d_input[i], result);

    /* Scalar remainder. */
    float s = 1.0f / (1.0f + expf(-beta * x));
    float gelu_derivative = s * (1.0f + x * (1.0f - s) * beta);
    d_input[i] = d_output[i] * gelu_derivative;

#elif defined(__AVX__)
    /* AVX without FMA: exp through scalar expf via scratch arrays
     * (neg_beta_x_arr / exp_arr declarations not visible here). */
    const __m256 beta_vec = _mm256_set1_ps(beta);
    const __m256 one_vec = _mm256_set1_ps(1.0f);
    const __m256 neg_beta_vec = _mm256_set1_ps(-beta);

    for (; i + 8 <= n; i += 8) {
        __m256 x = _mm256_loadu_ps(&input[i]);
        __m256 dy = _mm256_loadu_ps(&d_output[i]);

        __m256 neg_beta_x = _mm256_mul_ps(neg_beta_vec, x);

        /* NOTE(review): aligned store assumes neg_beta_x_arr is 32-byte
         * aligned. */
        _mm256_store_ps(neg_beta_x_arr, neg_beta_x);
        for (int j = 0; j < 8; ++j) {
            exp_arr[j] = expf(neg_beta_x_arr[j]);
        __m256 exp_neg = _mm256_load_ps(exp_arr);

        __m256 s = _mm256_div_ps(one_vec, _mm256_add_ps(one_vec, exp_neg));

        __m256 one_minus_s = _mm256_sub_ps(one_vec, s);
        __m256 x_one_minus_s = _mm256_mul_ps(x, one_minus_s);
        __m256 x_one_minus_s_beta = _mm256_mul_ps(x_one_minus_s, beta_vec);
        __m256 inner = _mm256_add_ps(one_vec, x_one_minus_s_beta);
        __m256 gelu_deriv = _mm256_mul_ps(s, inner);

        __m256 result = _mm256_mul_ps(dy, gelu_deriv);
        _mm256_storeu_ps(&d_input[i], result);

    /* Scalar remainder. */
    float s = 1.0f / (1.0f + expf(-beta * x));
    float gelu_derivative = s * (1.0f + x * (1.0f - s) * beta);
    d_input[i] = d_output[i] * gelu_derivative;
    /* Interior of geglu_forward_fp32: gated GELU forward pass.  Each token
     * row of x holds 2*dim floats, split as [a | b]; the output is
     *   out[d] = gelu(a[d]) * b[d]
     * with the tanh-approximation GELU.  AVX-512 / AVX2 paths process two
     * vectors per iteration (2x unroll); AVX falls back to scalar tanhf.
     * NOTE(review): the function signature, `int d = 0;` declarations,
     * `float a = x_ptr[d];` lines, scratch-array declarations, closing
     * braces and the scalar `#else` are missing from this excerpt. */
    const float sqrt_2_over_pi = 0.7978845608f; /* sqrt(2/pi) */
    const float coeff = 0.044715f;

    const int inner_dim = dim * 2; /* row stride of x: [a | b] halves */

#if defined(__AVX512F__)
    const __m512 sqrt_2_pi_vec = _mm512_set1_ps(sqrt_2_over_pi);
    const __m512 coeff_vec = _mm512_set1_ps(coeff);
    const __m512 half_vec = _mm512_set1_ps(0.5f);
    const __m512 one_vec = _mm512_set1_ps(1.0f);

    for (int t = 0; t < tokens; ++t) {
        const float *x_ptr = x + (size_t)t * inner_dim;
        float *out_ptr = out + (size_t)t * dim;

        /* 2x unrolled: 32 floats of the gate half per iteration. */
        for (; d + 32 <= dim; d += 32) {
            __m512 a0 = _mm512_loadu_ps(&x_ptr[d]);
            __m512 a1 = _mm512_loadu_ps(&x_ptr[d + 16]);

            __m512 a0_sq = _mm512_mul_ps(a0, a0);
            __m512 a0_cu = _mm512_mul_ps(a0_sq, a0);
            __m512 a1_sq = _mm512_mul_ps(a1, a1);
            __m512 a1_cu = _mm512_mul_ps(a1_sq, a1);

            /* inner = sqrt(2/pi) * (a + coeff * a^3) */
            __m512 inner0 = _mm512_fmadd_ps(coeff_vec, a0_cu, a0);
            __m512 inner1 = _mm512_fmadd_ps(coeff_vec, a1_cu, a1);
            inner0 = _mm512_mul_ps(sqrt_2_pi_vec, inner0);
            inner1 = _mm512_mul_ps(sqrt_2_pi_vec, inner1);

            __m512 tanh0 = tanh512_fast(inner0);
            __m512 tanh1 = tanh512_fast(inner1);

            __m512 gelu0 = _mm512_mul_ps(half_vec, _mm512_mul_ps(a0, _mm512_add_ps(one_vec, tanh0)));
            __m512 gelu1 = _mm512_mul_ps(half_vec, _mm512_mul_ps(a1, _mm512_add_ps(one_vec, tanh1)));

            /* b half lives dim floats past the a half. */
            __m512 b0 = _mm512_loadu_ps(&x_ptr[dim + d]);
            __m512 b1 = _mm512_loadu_ps(&x_ptr[dim + d + 16]);

            _mm512_storeu_ps(&out_ptr[d], _mm512_mul_ps(gelu0, b0));
            _mm512_storeu_ps(&out_ptr[d + 16], _mm512_mul_ps(gelu1, b1));

        /* Scalar remainder (`float a = x_ptr[d];` not shown). */
        for (; d < dim; ++d) {
            float b = x_ptr[dim + d];
            float a3 = a * a * a;
            float inner = sqrt_2_over_pi * (a + coeff * a3);
            float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
            out_ptr[d] = gelu_a * b;

#elif defined(__AVX2__)
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    for (int t = 0; t < tokens; ++t) {
        const float *x_ptr = x + (size_t)t * inner_dim;
        float *out_ptr = out + (size_t)t * dim;

        /* 2x unrolled: 16 floats per iteration. */
        for (; d + 16 <= dim; d += 16) {
            __m256 a0 = _mm256_loadu_ps(&x_ptr[d]);
            __m256 a1 = _mm256_loadu_ps(&x_ptr[d + 8]);

            __m256 a0_sq = _mm256_mul_ps(a0, a0);
            __m256 a0_cu = _mm256_mul_ps(a0_sq, a0);
            __m256 a1_sq = _mm256_mul_ps(a1, a1);
            __m256 a1_cu = _mm256_mul_ps(a1_sq, a1);

            __m256 inner0 = _mm256_fmadd_ps(coeff_vec, a0_cu, a0);
            __m256 inner1 = _mm256_fmadd_ps(coeff_vec, a1_cu, a1);
            inner0 = _mm256_mul_ps(sqrt_2_pi_vec, inner0);
            inner1 = _mm256_mul_ps(sqrt_2_pi_vec, inner1);

            __m256 tanh0 = tanh256_fast(inner0);
            __m256 tanh1 = tanh256_fast(inner1);

            __m256 gelu0 = _mm256_mul_ps(half_vec, _mm256_mul_ps(a0, _mm256_add_ps(one_vec, tanh0)));
            __m256 gelu1 = _mm256_mul_ps(half_vec, _mm256_mul_ps(a1, _mm256_add_ps(one_vec, tanh1)));

            __m256 b0 = _mm256_loadu_ps(&x_ptr[dim + d]);
            __m256 b1 = _mm256_loadu_ps(&x_ptr[dim + d + 8]);

            _mm256_storeu_ps(&out_ptr[d], _mm256_mul_ps(gelu0, b0));
            _mm256_storeu_ps(&out_ptr[d + 8], _mm256_mul_ps(gelu1, b1));

        /* Scalar remainder. */
        for (; d < dim; ++d) {
            float b = x_ptr[dim + d];
            float a3 = a * a * a;
            float inner = sqrt_2_over_pi * (a + coeff * a3);
            float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
            out_ptr[d] = gelu_a * b;

#elif defined(__AVX__)
    /* AVX without FMA: tanh through scalar tanhf via scratch arrays
     * (inner_arr / tanh_arr declarations not visible here). */
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    for (int t = 0; t < tokens; ++t) {
        const float *x_ptr = x + (size_t)t * inner_dim;
        float *out_ptr = out + (size_t)t * dim;

        for (; d + 8 <= dim; d += 8) {
            __m256 a = _mm256_loadu_ps(&x_ptr[d]);
            __m256 a_sq = _mm256_mul_ps(a, a);
            __m256 a_cu = _mm256_mul_ps(a_sq, a);

            __m256 coeff_a_cu = _mm256_mul_ps(coeff_vec, a_cu);
            __m256 inner = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(a, coeff_a_cu));

            _mm256_store_ps(inner_arr, inner);
            for (int j = 0; j < 8; ++j) {
                tanh_arr[j] = tanhf(inner_arr[j]);
            __m256 tanh_val = _mm256_load_ps(tanh_arr);

            __m256 gelu = _mm256_mul_ps(half_vec, _mm256_mul_ps(a, _mm256_add_ps(one_vec, tanh_val)));
            __m256 b = _mm256_loadu_ps(&x_ptr[dim + d]);

            _mm256_storeu_ps(&out_ptr[d], _mm256_mul_ps(gelu, b));

        /* Scalar remainder. */
        for (; d < dim; ++d) {
            float b = x_ptr[dim + d];
            float a3 = a * a * a;
            float inner = sqrt_2_over_pi * (a + coeff * a3);
            float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
            out_ptr[d] = gelu_a * b;

    /* Portable scalar fallback (the `#else` is not shown here). */
    for (int t = 0; t < tokens; ++t) {
        const float *x_ptr = x + (size_t)t * inner_dim;
        float *out_ptr = out + (size_t)t * dim;

        for (int d = 0; d < dim; ++d) {
            float b = x_ptr[dim + d];
            float a3 = a * a * a;
            float inner = sqrt_2_over_pi * (a + coeff * a3);
            float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
            out_ptr[d] = gelu_a * b;
    /* Fragment of geglu_forward_bf16: bf16 wrapper that converts through
     * fp32 using a caller-provided scratch buffer.  Scratch layout (from
     * the pointer math below): [0, tokens*dim*2) holds the fp32 input,
     * [tokens*dim*2, tokens*dim*3) the fp32 output.
     * NOTE(review): the conversion and compute calls that follow are not
     * visible in this excerpt — presumably bf16_tensor_to_float /
     * geglu_forward_fp32 / float_tensor_to_bf16; confirm in the full file. */
    if (!x || !out || !scratch)
        return;
    /* Output element count per the row layout: tokens * dim. */
    const size_t fp32_size = (size_t)tokens * (size_t)dim;
    /* Input is twice as large: each row holds [a | b] = 2 * dim floats. */
    const size_t input_size = fp32_size * 2;
    float *fp32_input = scratch;
    float *fp32_output = scratch + input_size;
    /* Interior of geglu_backward_fp32: scalar backward pass for gated GELU.
     * Forward was out[d] = gelu(a[d]) * b[d] with a = x_ptr[0..dim),
     * b = x_ptr[dim..2*dim), so:
     *   d a[d] = dout * gelu'(a) * b
     *   d b[d] = dout * gelu(a)
     * NOTE(review): `float a = x_ptr[d];` and the a2/a3 declarations used
     * below are missing from this excerpt — verify against the full file. */
    const float sqrt_2_over_pi = 0.7978845608f; /* sqrt(2/pi) */
    const float coeff = 0.044715f;

    const int inner_dim = dim * 2; /* row stride of x and d_x */

    for (int t = 0; t < tokens; ++t) {
        const float *x_ptr = x + (size_t)t * inner_dim;
        const float *d_out_ptr = d_out + (size_t)t * dim;
        float *d_x_ptr = d_x + (size_t)t * inner_dim;

        for (int d = 0; d < dim; ++d) {
            float b = x_ptr[dim + d];
            float dout = d_out_ptr[d];

            /* g(a) = sqrt(2/pi) * (a + coeff * a^3) */
            float g = sqrt_2_over_pi * (a + coeff * a3);
            float tanh_g = tanhf(g);
            float sech2_g = 1.0f - tanh_g * tanh_g;
            float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * a2);

            /* gelu'(a) = 0.5 (1 + tanh g) + 0.5 a sech^2(g) g'(a) */
            float d_gelu = 0.5f * (1.0f + tanh_g) + 0.5f * a * sech2_g * g_prime;

            /* Gradient w.r.t. the gate half a. */
            d_x_ptr[d] = dout * d_gelu * b;

            /* Gradient w.r.t. the linear half b. */
            float gelu_a = 0.5f * a * (1.0f + tanh_g);
            d_x_ptr[dim + d] = dout * gelu_a;
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
void gelu_backward_exact(const float *input, const float *d_output, float *d_input, size_t n)
void gelu_exact_inplace(float *data, size_t n)
void geglu_backward_fp32(const float *x, const float *d_out, float *d_x, int tokens, int dim)
void geglu_forward_fp32(const float *x, float *out, int tokens, int dim)
void gelu_fast_inplace(float *data, size_t n)
void geglu_forward_bf16(const uint16_t *x, uint16_t *out, int tokens, int dim, float *scratch)
void gelu_backward_scalar(const float *input, const float *d_output, float *d_input, size_t n)
void gelu_backward_fast(const float *input, const float *d_output, float *d_input, size_t n)
__attribute__((visibility("default"))) CKTokenizer *ck_tokenizer_create(CKTokenizerType type)