← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gelu_kernels.c
Go to the documentation of this file.
1 /**
2  * @file gelu_kernels.c
3  * @brief GELU activation kernels with SIMD (SSE/AVX/AVX512)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * GELU: y = x * 0.5 * (1 + erf(x / sqrt(2)))
15  * Fast approx: y = x * sigmoid(1.702 * x)
16  */
17 
18 #include <math.h>
19 #include <stddef.h>
20 #include <stdint.h>
21 #include <string.h>
22 
23 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
24 #include <immintrin.h>
25 #endif
26 
27 #include "bf16_utils.h"
28 
29 /* Fast vectorized exp approximation (same as softmax_kernels.c) */
30 #if defined(__AVX512F__)
/*
 * Fast vectorized expf approximation, 16 lanes.
 *
 * Range reduction: exp(x) = 2^k * e^r with k = round(x * log2(e)) and
 * r = x - k*ln2 (ln2 split into hi/lo parts so the subtraction stays
 * accurate). e^r is a degree-4 Taylor polynomial over |r| <= ln2/2, and
 * 2^k is reconstructed by writing k+127 straight into the float exponent
 * field.
 *
 * FIX(review): the previous coefficients (ln2^n / n!) are the expansion of
 * 2^u, which pairs with the base-2 fractional argument u = t - round(t).
 * Feeding them the r-based reduction above yielded 2^(ln2*u) instead of
 * 2^u — up to ~11% relative error (exp(1) came out ~2.474). With this
 * reduction the polynomial must approximate e^r, i.e. coefficients 1/n!.
 * NOTE(review): softmax_kernels.c reportedly shares this helper — apply
 * the same coefficient fix there and re-run llamacpp parity.
 */
static inline __m512 exp512_fast(__m512 x) {
    // Clamp so exponent reconstruction cannot overflow/underflow
    x = _mm512_max_ps(x, _mm512_set1_ps(-88.0f));
    x = _mm512_min_ps(x, _mm512_set1_ps(88.0f));

    const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
    const __m512 c1 = _mm512_set1_ps(0.693359375f);    /* ln2 high part */
    const __m512 c2 = _mm512_set1_ps(-2.12194440e-4f); /* ln2 low part */

    __m512 t = _mm512_mul_ps(x, log2e);
    __m512 ti = _mm512_roundscale_ps(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

    // r = x - ti*ln2, done in two steps for precision
    __m512 rx = _mm512_sub_ps(x, _mm512_mul_ps(ti, c1));
    rx = _mm512_sub_ps(rx, _mm512_mul_ps(ti, c2));

    // e^r Taylor coefficients 1/n! (Horner with FMAs)
    const __m512 p0 = _mm512_set1_ps(1.0f);
    const __m512 p1 = _mm512_set1_ps(1.0f);
    const __m512 p2 = _mm512_set1_ps(0.5f);
    const __m512 p3 = _mm512_set1_ps(0.16666667f);
    const __m512 p4 = _mm512_set1_ps(0.041666668f);

    __m512 poly = _mm512_fmadd_ps(p4, rx, p3);
    poly = _mm512_fmadd_ps(poly, rx, p2);
    poly = _mm512_fmadd_ps(poly, rx, p1);
    poly = _mm512_fmadd_ps(poly, rx, p0);

    // 2^k via direct exponent-field construction
    __m512i ti_int = _mm512_cvtps_epi32(ti);
    ti_int = _mm512_add_epi32(ti_int, _mm512_set1_epi32(127));
    ti_int = _mm512_slli_epi32(ti_int, 23);
    __m512 scale = _mm512_castsi512_ps(ti_int);

    return _mm512_mul_ps(poly, scale);
}
65 
66 // Fast vectorized tanh: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
67 static inline __m512 tanh512_fast(__m512 x) {
68  __m512 two = _mm512_set1_ps(2.0f);
69  __m512 one = _mm512_set1_ps(1.0f);
70  __m512 exp2x = exp512_fast(_mm512_mul_ps(two, x));
71  __m512 num = _mm512_sub_ps(exp2x, one);
72  __m512 den = _mm512_add_ps(exp2x, one);
73  return _mm512_div_ps(num, den);
74 }
75 #endif
76 
77 #if defined(__AVX2__)
/*
 * Fast vectorized expf approximation, 8 lanes (AVX2 counterpart of
 * exp512_fast — see its header for the reduction scheme).
 *
 * FIX(review): same coefficient bug as exp512_fast — the old ln2^n/n!
 * coefficients belong to a 2^u polynomial on u = t - round(t), but the
 * reduction here produces r = x - k*ln2, so the polynomial must be the
 * Taylor series of e^r (coefficients 1/n!). The old pairing was off by
 * up to ~11% relative error.
 */
static inline __m256 exp256_fast(__m256 x) {
    // Clamp so exponent reconstruction cannot overflow/underflow
    x = _mm256_max_ps(x, _mm256_set1_ps(-88.0f));
    x = _mm256_min_ps(x, _mm256_set1_ps(88.0f));

    const __m256 log2e = _mm256_set1_ps(1.4426950408889634f);
    const __m256 c1 = _mm256_set1_ps(0.693359375f);    /* ln2 high part */
    const __m256 c2 = _mm256_set1_ps(-2.12194440e-4f); /* ln2 low part */

    __m256 t = _mm256_mul_ps(x, log2e);
    __m256 ti = _mm256_round_ps(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

    // r = x - ti*ln2, done in two steps for precision
    __m256 rx = _mm256_sub_ps(x, _mm256_mul_ps(ti, c1));
    rx = _mm256_sub_ps(rx, _mm256_mul_ps(ti, c2));

    // e^r Taylor coefficients 1/n! (Horner with FMAs)
    const __m256 p0 = _mm256_set1_ps(1.0f);
    const __m256 p1 = _mm256_set1_ps(1.0f);
    const __m256 p2 = _mm256_set1_ps(0.5f);
    const __m256 p3 = _mm256_set1_ps(0.16666667f);
    const __m256 p4 = _mm256_set1_ps(0.041666668f);

    __m256 poly = _mm256_fmadd_ps(p4, rx, p3);
    poly = _mm256_fmadd_ps(poly, rx, p2);
    poly = _mm256_fmadd_ps(poly, rx, p1);
    poly = _mm256_fmadd_ps(poly, rx, p0);

    // 2^k via direct exponent-field construction
    __m256i ti_int = _mm256_cvtps_epi32(ti);
    ti_int = _mm256_add_epi32(ti_int, _mm256_set1_epi32(127));
    ti_int = _mm256_slli_epi32(ti_int, 23);
    __m256 scale = _mm256_castsi256_ps(ti_int);

    return _mm256_mul_ps(poly, scale);
}
110 
111 static inline __m256 tanh256_fast(__m256 x) {
112  __m256 two = _mm256_set1_ps(2.0f);
113  __m256 one = _mm256_set1_ps(1.0f);
114  __m256 exp2x = exp256_fast(_mm256_mul_ps(two, x));
115  __m256 num = _mm256_sub_ps(exp2x, one);
116  __m256 den = _mm256_add_ps(exp2x, one);
117  return _mm256_div_ps(num, den);
118 }
119 #endif
120 
/**
 * GELU activation forward (fast approximation, in-place)
 * @test test_gelu.py::TestGELUForward::test_gelu_fast_inplace
 * @test test_gelu.py::TestGELUForward::test_gelu_vs_exact
 * @test test_parity.py::test_gelu_parity
 *
 * Fast GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
 * In-place on contiguous buffer.
 *
 * @param data contiguous fp32 buffer; each element is overwritten with GELU(data[i])
 * @param n    number of elements
 *
 * Compile-time dispatch: AVX-512 (16 lanes, tanh512_fast), AVX2 (8 lanes,
 * tanh256_fast), AVX1 (vector arithmetic + scalar libm tanhf — no FMA used),
 * else pure scalar. Vector paths finish the tail with scalar libm tanhf, so
 * tail elements may differ from vector-body elements by the tanh-approximation
 * error.
 *
 * After changes: make test && make llamacpp-parity-full
 */
void gelu_fast_inplace(float *data, size_t n)
{
    const float sqrt_2_over_pi = 0.7978845608f; /* sqrt(2/pi) */
    const float coeff = 0.044715f;              /* cubic coefficient of the tanh form */

#if defined(__AVX512F__)
    const __m512 sqrt_2_pi_vec = _mm512_set1_ps(sqrt_2_over_pi);
    const __m512 coeff_vec = _mm512_set1_ps(coeff);
    const __m512 half_vec = _mm512_set1_ps(0.5f);
    const __m512 one_vec = _mm512_set1_ps(1.0f);

    size_t i = 0;
    for (; i + 16 <= n; i += 16) {
        __m512 x = _mm512_loadu_ps(&data[i]);
        __m512 x2 = _mm512_mul_ps(x, x);
        __m512 x3 = _mm512_mul_ps(x2, x);

        // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
        __m512 inner = _mm512_fmadd_ps(coeff_vec, x3, x);
        inner = _mm512_mul_ps(sqrt_2_pi_vec, inner);

        // result = 0.5 * x * (1 + tanh(inner))
        __m512 tanh_val = tanh512_fast(inner);
        __m512 one_plus_tanh = _mm512_add_ps(one_vec, tanh_val);
        __m512 result = _mm512_mul_ps(half_vec, _mm512_mul_ps(x, one_plus_tanh));

        _mm512_storeu_ps(&data[i], result);
    }
    // Handle remaining elements (scalar tail, libm tanhf)
    for (; i < n; ++i) {
        float x = data[i];
        float x3 = x * x * x;
        float inner = sqrt_2_over_pi * (x + coeff * x3);
        data[i] = 0.5f * x * (1.0f + tanhf(inner));
    }

#elif defined(__AVX2__)
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m256 x = _mm256_loadu_ps(&data[i]);
        __m256 x2 = _mm256_mul_ps(x, x);
        __m256 x3 = _mm256_mul_ps(x2, x);

        // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
        __m256 inner = _mm256_fmadd_ps(coeff_vec, x3, x);
        inner = _mm256_mul_ps(sqrt_2_pi_vec, inner);

        // result = 0.5 * x * (1 + tanh(inner))
        __m256 tanh_val = tanh256_fast(inner);
        __m256 one_plus_tanh = _mm256_add_ps(one_vec, tanh_val);
        __m256 result = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, one_plus_tanh));

        _mm256_storeu_ps(&data[i], result);
    }
    // Handle remaining elements (scalar tail, libm tanhf)
    for (; i < n; ++i) {
        float x = data[i];
        float x3 = x * x * x;
        float inner = sqrt_2_over_pi * (x + coeff * x3);
        data[i] = 0.5f * x * (1.0f + tanhf(inner));
    }

#elif defined(__AVX__)
    // AVX1: Vectorize arithmetic, use scalar tanh (no FMA available here)
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    size_t i = 0;
    // 32-byte-aligned staging buffers for round-tripping through scalar tanhf
    float inner_arr[8] __attribute__((aligned(32)));
    float tanh_arr[8] __attribute__((aligned(32)));

    for (; i + 8 <= n; i += 8) {
        __m256 x = _mm256_loadu_ps(&data[i]);
        __m256 x2 = _mm256_mul_ps(x, x);
        __m256 x3 = _mm256_mul_ps(x2, x);

        // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
        __m256 coeff_x3 = _mm256_mul_ps(coeff_vec, x3);
        __m256 inner = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(x, coeff_x3));

        // Compute tanh scalarly
        _mm256_store_ps(inner_arr, inner);
        for (int j = 0; j < 8; ++j) {
            tanh_arr[j] = tanhf(inner_arr[j]);
        }
        __m256 tanh_val = _mm256_load_ps(tanh_arr);

        // result = 0.5 * x * (1 + tanh(inner))
        __m256 one_plus_tanh = _mm256_add_ps(one_vec, tanh_val);
        __m256 result = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, one_plus_tanh));

        _mm256_storeu_ps(&data[i], result);
    }
    // Handle remaining elements
    for (; i < n; ++i) {
        float x = data[i];
        float x3 = x * x * x;
        float inner = sqrt_2_over_pi * (x + coeff * x3);
        data[i] = 0.5f * x * (1.0f + tanhf(inner));
    }

#else
    // Scalar fallback
    for (size_t i = 0; i < n; ++i) {
        float x = data[i];
        float x3 = x * x * x;
        float inner = sqrt_2_over_pi * (x + coeff * x3);
        data[i] = 0.5f * x * (1.0f + tanhf(inner));
    }
#endif
}
250 
/**
 * GELU backward using the tanh-based approximation derivative, adapted
 * from C-Transformer's backward_gelu. Operates element-wise on contiguous
 * buffers.
 *
 * Derivative: d/dx GELU(x) = 0.5 * (1 + tanh(g)) + 0.5 * x * sech^2(g) * g'
 *   where g  = sqrt(2/pi) * (x + 0.044715 * x^3)
 *         g' = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
 *
 * @param input    forward-pass activations x, length n
 * @param d_output upstream gradient dL/dy, length n
 * @param d_input  resulting gradient dL/dx, length n (written)
 * @param n        element count
 *
 * Compile-time dispatch mirrors gelu_fast_inplace: AVX-512 / AVX2 use the
 * fast vector tanh; AVX1 vectorizes arithmetic but round-trips tanh through
 * scalar libm; otherwise pure scalar. Vector tails use scalar libm tanhf.
 */
void gelu_backward_exact(const float *input,
                         const float *d_output,
                         float *d_input,
                         size_t n)
{
    const float sqrt_2_over_pi = 0.7978845608f; /* sqrt(2/pi) */
    const float coeff = 0.044715f;

#if defined(__AVX512F__)
    const __m512 sqrt_2_pi_vec = _mm512_set1_ps(sqrt_2_over_pi);
    const __m512 coeff_vec = _mm512_set1_ps(coeff);
    const __m512 coeff3_vec = _mm512_set1_ps(3.0f * coeff);
    const __m512 half_vec = _mm512_set1_ps(0.5f);
    const __m512 one_vec = _mm512_set1_ps(1.0f);

    size_t i = 0;
    for (; i + 16 <= n; i += 16) {
        __m512 x = _mm512_loadu_ps(&input[i]);
        __m512 dy = _mm512_loadu_ps(&d_output[i]);

        __m512 x2 = _mm512_mul_ps(x, x);
        __m512 x3 = _mm512_mul_ps(x2, x);

        // g = sqrt(2/pi) * (x + 0.044715 * x^3)
        __m512 g = _mm512_fmadd_ps(coeff_vec, x3, x);
        g = _mm512_mul_ps(sqrt_2_pi_vec, g);

        __m512 tanh_g = tanh512_fast(g);

        // g' = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
        __m512 g_prime = _mm512_fmadd_ps(coeff3_vec, x2, one_vec);
        g_prime = _mm512_mul_ps(sqrt_2_pi_vec, g_prime);

        // sech^2(g) = 1 - tanh^2(g)  (fnmadd: one_vec - tanh_g*tanh_g)
        __m512 sech2_g = _mm512_fnmadd_ps(tanh_g, tanh_g, one_vec);

        // gelu_derivative = 0.5 * (1 + tanh_g) + 0.5 * x * sech2_g * g_prime
        __m512 term1 = _mm512_mul_ps(half_vec, _mm512_add_ps(one_vec, tanh_g));
        __m512 term2 = _mm512_mul_ps(half_vec, _mm512_mul_ps(x, _mm512_mul_ps(sech2_g, g_prime)));
        __m512 gelu_deriv = _mm512_add_ps(term1, term2);

        __m512 result = _mm512_mul_ps(dy, gelu_deriv);
        _mm512_storeu_ps(&d_input[i], result);
    }
    // Handle remaining elements (scalar tail, libm tanhf)
    for (; i < n; ++i) {
        float x = input[i];
        float x3 = x * x * x;
        float g = sqrt_2_over_pi * (x + coeff * x3);
        float tanh_g = tanhf(g);
        float x2 = x * x;
        float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
        float sech2_g = 1.0f - tanh_g * tanh_g;
        float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
        d_input[i] = d_output[i] * gelu_derivative;
    }

#elif defined(__AVX2__)
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 coeff3_vec = _mm256_set1_ps(3.0f * coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m256 x = _mm256_loadu_ps(&input[i]);
        __m256 dy = _mm256_loadu_ps(&d_output[i]);

        __m256 x2 = _mm256_mul_ps(x, x);
        __m256 x3 = _mm256_mul_ps(x2, x);

        // g = sqrt(2/pi) * (x + 0.044715 * x^3)
        __m256 g = _mm256_fmadd_ps(coeff_vec, x3, x);
        g = _mm256_mul_ps(sqrt_2_pi_vec, g);

        __m256 tanh_g = tanh256_fast(g);

        // g' = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
        __m256 g_prime = _mm256_fmadd_ps(coeff3_vec, x2, one_vec);
        g_prime = _mm256_mul_ps(sqrt_2_pi_vec, g_prime);

        // sech^2(g) = 1 - tanh^2(g)
        __m256 sech2_g = _mm256_fnmadd_ps(tanh_g, tanh_g, one_vec);

        // gelu_derivative = 0.5 * (1 + tanh_g) + 0.5 * x * sech2_g * g_prime
        __m256 term1 = _mm256_mul_ps(half_vec, _mm256_add_ps(one_vec, tanh_g));
        __m256 term2 = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, _mm256_mul_ps(sech2_g, g_prime)));
        __m256 gelu_deriv = _mm256_add_ps(term1, term2);

        __m256 result = _mm256_mul_ps(dy, gelu_deriv);
        _mm256_storeu_ps(&d_input[i], result);
    }
    // Handle remaining elements (scalar tail, libm tanhf)
    for (; i < n; ++i) {
        float x = input[i];
        float x3 = x * x * x;
        float g = sqrt_2_over_pi * (x + coeff * x3);
        float tanh_g = tanhf(g);
        float x2 = x * x;
        float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
        float sech2_g = 1.0f - tanh_g * tanh_g;
        float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
        d_input[i] = d_output[i] * gelu_derivative;
    }

#elif defined(__AVX__)
    // AVX1: Vectorize arithmetic, use scalar tanh (no FMA available here)
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 coeff3_vec = _mm256_set1_ps(3.0f * coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    size_t i = 0;
    // 32-byte-aligned staging buffers for round-tripping through scalar tanhf
    float g_arr[8] __attribute__((aligned(32)));
    float tanh_arr[8] __attribute__((aligned(32)));

    for (; i + 8 <= n; i += 8) {
        __m256 x = _mm256_loadu_ps(&input[i]);
        __m256 dy = _mm256_loadu_ps(&d_output[i]);

        __m256 x2 = _mm256_mul_ps(x, x);
        __m256 x3 = _mm256_mul_ps(x2, x);

        // g = sqrt(2/pi) * (x + 0.044715 * x^3)
        __m256 coeff_x3 = _mm256_mul_ps(coeff_vec, x3);
        __m256 g = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(x, coeff_x3));

        // Compute tanh scalarly
        _mm256_store_ps(g_arr, g);
        for (int j = 0; j < 8; ++j) {
            tanh_arr[j] = tanhf(g_arr[j]);
        }
        __m256 tanh_g = _mm256_load_ps(tanh_arr);

        // g' = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
        __m256 coeff3_x2 = _mm256_mul_ps(coeff3_vec, x2);
        __m256 g_prime = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(one_vec, coeff3_x2));

        // sech^2(g) = 1 - tanh^2(g)
        __m256 tanh_g_sq = _mm256_mul_ps(tanh_g, tanh_g);
        __m256 sech2_g = _mm256_sub_ps(one_vec, tanh_g_sq);

        // gelu_derivative = 0.5 * (1 + tanh_g) + 0.5 * x * sech2_g * g_prime
        __m256 term1 = _mm256_mul_ps(half_vec, _mm256_add_ps(one_vec, tanh_g));
        __m256 term2 = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, _mm256_mul_ps(sech2_g, g_prime)));
        __m256 gelu_deriv = _mm256_add_ps(term1, term2);

        __m256 result = _mm256_mul_ps(dy, gelu_deriv);
        _mm256_storeu_ps(&d_input[i], result);
    }
    // Handle remaining elements
    for (; i < n; ++i) {
        float x = input[i];
        float x3 = x * x * x;
        float g = sqrt_2_over_pi * (x + coeff * x3);
        float tanh_g = tanhf(g);
        float x2 = x * x;
        float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
        float sech2_g = 1.0f - tanh_g * tanh_g;
        float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
        d_input[i] = d_output[i] * gelu_derivative;
    }

#else
    // Scalar fallback
    for (size_t i = 0; i < n; ++i) {
        float x = input[i];

        float x3 = x * x * x;
        float g = sqrt_2_over_pi * (x + coeff * x3);
        float tanh_g = tanhf(g);

        float x2 = x * x;
        float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);

        float sech2_g = 1.0f - tanh_g * tanh_g;
        float gelu_derivative =
            0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;

        d_input[i] = d_output[i] * gelu_derivative;
    }
#endif
}
442 
443 // Scalar-only exact GELU forward using standard library tanhf.
444 // This is slower than gelu_fast_inplace but provides maximum accuracy.
445 // Used by BF16 wrapper where conversion overhead dominates anyway.
446 void gelu_exact_inplace(float *data, size_t n)
447 {
448  const float sqrt_2_over_pi = 0.7978845608f;
449  const float coeff = 0.044715f;
450 
451  for (size_t i = 0; i < n; ++i) {
452  float x = data[i];
453  float x3 = x * x * x;
454  float inner = sqrt_2_over_pi * (x + coeff * x3);
455  data[i] = 0.5f * x * (1.0f + tanhf(inner));
456  }
457 }
458 
459 // Scalar-only exact GELU backward using standard library tanhf.
460 // This is slower than gelu_backward_exact but provides maximum accuracy.
461 // Used by BF16 wrapper where conversion overhead dominates anyway.
462 void gelu_backward_scalar(const float *input,
463  const float *d_output,
464  float *d_input,
465  size_t n)
466 {
467  const float sqrt_2_over_pi = 0.7978845608f;
468  const float coeff = 0.044715f;
469 
470  for (size_t i = 0; i < n; ++i) {
471  float x = input[i];
472  float x3 = x * x * x;
473  float g = sqrt_2_over_pi * (x + coeff * x3);
474  float tanh_g = tanhf(g);
475  float x2 = x * x;
476  float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
477  float sech2_g = 1.0f - tanh_g * tanh_g;
478  float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
479  d_input[i] = d_output[i] * gelu_derivative;
480  }
481 }
482 
483 // Fast approximate GELU backward, adapted from C-Transformer's backward_gelu_fast.
484 // Uses sigmoid approximation: GELU(x) ≈ x * sigmoid(1.702 * x)
485 // Derivative: s * (1 + x * (1 - s) * 1.702) where s = sigmoid(1.702 * x)
486 void gelu_backward_fast(const float *input,
487  const float *d_output,
488  float *d_input,
489  size_t n)
490 {
491  const float beta = 1.702f;
492 
493 #if defined(__AVX512F__)
494  const __m512 beta_vec = _mm512_set1_ps(beta);
495  const __m512 one_vec = _mm512_set1_ps(1.0f);
496  const __m512 neg_beta_vec = _mm512_set1_ps(-beta);
497 
498  size_t i = 0;
499  for (; i + 16 <= n; i += 16) {
500  __m512 x = _mm512_loadu_ps(&input[i]);
501  __m512 dy = _mm512_loadu_ps(&d_output[i]);
502 
503  // s = sigmoid(beta * x) = 1 / (1 + exp(-beta * x))
504  __m512 neg_beta_x = _mm512_mul_ps(neg_beta_vec, x);
505  __m512 exp_neg = exp512_fast(neg_beta_x);
506  __m512 s = _mm512_div_ps(one_vec, _mm512_add_ps(one_vec, exp_neg));
507 
508  // gelu_derivative = s * (1 + x * (1 - s) * beta)
509  __m512 one_minus_s = _mm512_sub_ps(one_vec, s);
510  __m512 inner = _mm512_fmadd_ps(_mm512_mul_ps(x, one_minus_s), beta_vec, one_vec);
511  __m512 gelu_deriv = _mm512_mul_ps(s, inner);
512 
513  __m512 result = _mm512_mul_ps(dy, gelu_deriv);
514  _mm512_storeu_ps(&d_input[i], result);
515  }
516  // Handle remaining elements
517  for (; i < n; ++i) {
518  float x = input[i];
519  float s = 1.0f / (1.0f + expf(-beta * x));
520  float gelu_derivative = s * (1.0f + x * (1.0f - s) * beta);
521  d_input[i] = d_output[i] * gelu_derivative;
522  }
523 
524 #elif defined(__AVX2__)
525  const __m256 beta_vec = _mm256_set1_ps(beta);
526  const __m256 one_vec = _mm256_set1_ps(1.0f);
527  const __m256 neg_beta_vec = _mm256_set1_ps(-beta);
528 
529  size_t i = 0;
530  for (; i + 8 <= n; i += 8) {
531  __m256 x = _mm256_loadu_ps(&input[i]);
532  __m256 dy = _mm256_loadu_ps(&d_output[i]);
533 
534  // s = sigmoid(beta * x) = 1 / (1 + exp(-beta * x))
535  __m256 neg_beta_x = _mm256_mul_ps(neg_beta_vec, x);
536  __m256 exp_neg = exp256_fast(neg_beta_x);
537  __m256 s = _mm256_div_ps(one_vec, _mm256_add_ps(one_vec, exp_neg));
538 
539  // gelu_derivative = s * (1 + x * (1 - s) * beta)
540  __m256 one_minus_s = _mm256_sub_ps(one_vec, s);
541  __m256 inner = _mm256_fmadd_ps(_mm256_mul_ps(x, one_minus_s), beta_vec, one_vec);
542  __m256 gelu_deriv = _mm256_mul_ps(s, inner);
543 
544  __m256 result = _mm256_mul_ps(dy, gelu_deriv);
545  _mm256_storeu_ps(&d_input[i], result);
546  }
547  // Handle remaining elements
548  for (; i < n; ++i) {
549  float x = input[i];
550  float s = 1.0f / (1.0f + expf(-beta * x));
551  float gelu_derivative = s * (1.0f + x * (1.0f - s) * beta);
552  d_input[i] = d_output[i] * gelu_derivative;
553  }
554 
555 #elif defined(__AVX__)
556  // AVX1: Vectorize arithmetic, use scalar exp
557  const __m256 beta_vec = _mm256_set1_ps(beta);
558  const __m256 one_vec = _mm256_set1_ps(1.0f);
559  const __m256 neg_beta_vec = _mm256_set1_ps(-beta);
560 
561  size_t i = 0;
562  float neg_beta_x_arr[8] __attribute__((aligned(32)));
563  float exp_arr[8] __attribute__((aligned(32)));
564 
565  for (; i + 8 <= n; i += 8) {
566  __m256 x = _mm256_loadu_ps(&input[i]);
567  __m256 dy = _mm256_loadu_ps(&d_output[i]);
568 
569  // s = sigmoid(beta * x) = 1 / (1 + exp(-beta * x))
570  __m256 neg_beta_x = _mm256_mul_ps(neg_beta_vec, x);
571 
572  // Compute exp scalarly
573  _mm256_store_ps(neg_beta_x_arr, neg_beta_x);
574  for (int j = 0; j < 8; ++j) {
575  exp_arr[j] = expf(neg_beta_x_arr[j]);
576  }
577  __m256 exp_neg = _mm256_load_ps(exp_arr);
578 
579  __m256 s = _mm256_div_ps(one_vec, _mm256_add_ps(one_vec, exp_neg));
580 
581  // gelu_derivative = s * (1 + x * (1 - s) * beta)
582  __m256 one_minus_s = _mm256_sub_ps(one_vec, s);
583  __m256 x_one_minus_s = _mm256_mul_ps(x, one_minus_s);
584  __m256 x_one_minus_s_beta = _mm256_mul_ps(x_one_minus_s, beta_vec);
585  __m256 inner = _mm256_add_ps(one_vec, x_one_minus_s_beta);
586  __m256 gelu_deriv = _mm256_mul_ps(s, inner);
587 
588  __m256 result = _mm256_mul_ps(dy, gelu_deriv);
589  _mm256_storeu_ps(&d_input[i], result);
590  }
591  // Handle remaining elements
592  for (; i < n; ++i) {
593  float x = input[i];
594  float s = 1.0f / (1.0f + expf(-beta * x));
595  float gelu_derivative = s * (1.0f + x * (1.0f - s) * beta);
596  d_input[i] = d_output[i] * gelu_derivative;
597  }
598 #endif
599 }
600 
601 // ============================================================================
602 // GeGLU Kernels - Gated GELU (used in LLaMA, Mistral, etc.)
603 // ============================================================================
604 //
605 // GeGLU: out = GELU(a) * b where input x = [a, b] along the last dimension
606 // Input shape: [tokens, 2 * dim]
607 // Output shape: [tokens, dim]
608 //
609 // The input is split along the last dimension:
610 // - First half: a = x[..., :dim] -> GELU(a)
611 // - Second half: b = x[..., dim:] -> element-wise multiply
612 // - out = GELU(a) * b
613 
/**
 * GeGLU forward pass (fp32)
 * @test test_geglu.py::TestGeGLU::test_geglu_forward_fp32
 *
 * Computes out = GELU(a) * b where x = [a, b] along last dimension.
 * Input shape: [tokens, 2 * dim], Output shape: [tokens, dim]
 *
 * @param x      row-major input [tokens, 2*dim]; per token the first dim
 *               floats are the gated half a, the next dim floats are b
 * @param out    row-major output [tokens, dim] (written)
 * @param tokens number of rows
 * @param dim    output width (half the input row width)
 *
 * GELU uses the tanh approximation 0.5*a*(1 + tanh(sqrt(2/pi)*(a + 0.044715*a^3))).
 * Compile-time dispatch: AVX-512 processes 32 outputs per iteration (2x16),
 * AVX2 16 (2x8), AVX1 8 with scalar libm tanhf; otherwise pure scalar.
 * Vector tails fall back to scalar libm tanhf.
 *
 * After changes: make test
 */
void geglu_forward_fp32(const float *x, float *out, int tokens, int dim)
{
    const float sqrt_2_over_pi = 0.7978845608f; /* sqrt(2/pi) */
    const float coeff = 0.044715f;

    const int inner_dim = dim * 2; /* input row stride: [a, b] per token */

#if defined(__AVX512F__)
    const __m512 sqrt_2_pi_vec = _mm512_set1_ps(sqrt_2_over_pi);
    const __m512 coeff_vec = _mm512_set1_ps(coeff);
    const __m512 half_vec = _mm512_set1_ps(0.5f);
    const __m512 one_vec = _mm512_set1_ps(1.0f);

    for (int t = 0; t < tokens; ++t) {
        const float *x_ptr = x + (size_t)t * inner_dim;
        float *out_ptr = out + (size_t)t * dim;

        int d = 0;
        // Process first half (a) with GELU, second half (b) directly
        for (; d + 32 <= dim; d += 32) {
            // Load a (first half of inner_dim)
            __m512 a0 = _mm512_loadu_ps(&x_ptr[d]);
            __m512 a1 = _mm512_loadu_ps(&x_ptr[d + 16]);

            // Compute GELU(a)
            __m512 a0_sq = _mm512_mul_ps(a0, a0);
            __m512 a0_cu = _mm512_mul_ps(a0_sq, a0);
            __m512 a1_sq = _mm512_mul_ps(a1, a1);
            __m512 a1_cu = _mm512_mul_ps(a1_sq, a1);

            // inner = sqrt(2/pi) * (a + 0.044715 * a^3)
            __m512 inner0 = _mm512_fmadd_ps(coeff_vec, a0_cu, a0);
            __m512 inner1 = _mm512_fmadd_ps(coeff_vec, a1_cu, a1);
            inner0 = _mm512_mul_ps(sqrt_2_pi_vec, inner0);
            inner1 = _mm512_mul_ps(sqrt_2_pi_vec, inner1);

            // tanh(inner)
            __m512 tanh0 = tanh512_fast(inner0);
            __m512 tanh1 = tanh512_fast(inner1);

            // GELU = 0.5 * a * (1 + tanh)
            __m512 gelu0 = _mm512_mul_ps(half_vec, _mm512_mul_ps(a0, _mm512_add_ps(one_vec, tanh0)));
            __m512 gelu1 = _mm512_mul_ps(half_vec, _mm512_mul_ps(a1, _mm512_add_ps(one_vec, tanh1)));

            // Load b (second half of inner_dim)
            __m512 b0 = _mm512_loadu_ps(&x_ptr[dim + d]);
            __m512 b1 = _mm512_loadu_ps(&x_ptr[dim + d + 16]);

            // out = GELU(a) * b
            _mm512_storeu_ps(&out_ptr[d], _mm512_mul_ps(gelu0, b0));
            _mm512_storeu_ps(&out_ptr[d + 16], _mm512_mul_ps(gelu1, b1));
        }
        // Handle remaining (scalar tail, libm tanhf)
        for (; d < dim; ++d) {
            float a = x_ptr[d];
            float b = x_ptr[dim + d];
            float a3 = a * a * a;
            float inner = sqrt_2_over_pi * (a + coeff * a3);
            float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
            out_ptr[d] = gelu_a * b;
        }
    }

#elif defined(__AVX2__)
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    for (int t = 0; t < tokens; ++t) {
        const float *x_ptr = x + (size_t)t * inner_dim;
        float *out_ptr = out + (size_t)t * dim;

        int d = 0;
        for (; d + 16 <= dim; d += 16) {
            // Load a
            __m256 a0 = _mm256_loadu_ps(&x_ptr[d]);
            __m256 a1 = _mm256_loadu_ps(&x_ptr[d + 8]);

            // GELU(a)
            __m256 a0_sq = _mm256_mul_ps(a0, a0);
            __m256 a0_cu = _mm256_mul_ps(a0_sq, a0);
            __m256 a1_sq = _mm256_mul_ps(a1, a1);
            __m256 a1_cu = _mm256_mul_ps(a1_sq, a1);

            __m256 inner0 = _mm256_fmadd_ps(coeff_vec, a0_cu, a0);
            __m256 inner1 = _mm256_fmadd_ps(coeff_vec, a1_cu, a1);
            inner0 = _mm256_mul_ps(sqrt_2_pi_vec, inner0);
            inner1 = _mm256_mul_ps(sqrt_2_pi_vec, inner1);

            __m256 tanh0 = tanh256_fast(inner0);
            __m256 tanh1 = tanh256_fast(inner1);

            __m256 gelu0 = _mm256_mul_ps(half_vec, _mm256_mul_ps(a0, _mm256_add_ps(one_vec, tanh0)));
            __m256 gelu1 = _mm256_mul_ps(half_vec, _mm256_mul_ps(a1, _mm256_add_ps(one_vec, tanh1)));

            // b
            __m256 b0 = _mm256_loadu_ps(&x_ptr[dim + d]);
            __m256 b1 = _mm256_loadu_ps(&x_ptr[dim + d + 8]);

            _mm256_storeu_ps(&out_ptr[d], _mm256_mul_ps(gelu0, b0));
            _mm256_storeu_ps(&out_ptr[d + 8], _mm256_mul_ps(gelu1, b1));
        }
        // Handle remaining (scalar tail, libm tanhf)
        for (; d < dim; ++d) {
            float a = x_ptr[d];
            float b = x_ptr[dim + d];
            float a3 = a * a * a;
            float inner = sqrt_2_over_pi * (a + coeff * a3);
            float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
            out_ptr[d] = gelu_a * b;
        }
    }

#elif defined(__AVX__)
    // AVX1: vector arithmetic, scalar libm tanhf
    const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
    const __m256 coeff_vec = _mm256_set1_ps(coeff);
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec = _mm256_set1_ps(1.0f);

    // 32-byte-aligned staging buffers for round-tripping through scalar tanhf
    float inner_arr[8] __attribute__((aligned(32)));
    float tanh_arr[8] __attribute__((aligned(32)));

    for (int t = 0; t < tokens; ++t) {
        const float *x_ptr = x + (size_t)t * inner_dim;
        float *out_ptr = out + (size_t)t * dim;

        int d = 0;
        for (; d + 8 <= dim; d += 8) {
            __m256 a = _mm256_loadu_ps(&x_ptr[d]);
            __m256 a_sq = _mm256_mul_ps(a, a);
            __m256 a_cu = _mm256_mul_ps(a_sq, a);

            __m256 coeff_a_cu = _mm256_mul_ps(coeff_vec, a_cu);
            __m256 inner = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(a, coeff_a_cu));

            _mm256_store_ps(inner_arr, inner);
            for (int j = 0; j < 8; ++j) {
                tanh_arr[j] = tanhf(inner_arr[j]);
            }
            __m256 tanh_val = _mm256_load_ps(tanh_arr);

            __m256 gelu = _mm256_mul_ps(half_vec, _mm256_mul_ps(a, _mm256_add_ps(one_vec, tanh_val)));
            __m256 b = _mm256_loadu_ps(&x_ptr[dim + d]);

            _mm256_storeu_ps(&out_ptr[d], _mm256_mul_ps(gelu, b));
        }
        for (; d < dim; ++d) {
            float a = x_ptr[d];
            float b = x_ptr[dim + d];
            float a3 = a * a * a;
            float inner = sqrt_2_over_pi * (a + coeff * a3);
            float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
            out_ptr[d] = gelu_a * b;
        }
    }

#else
    // Scalar fallback
    for (int t = 0; t < tokens; ++t) {
        const float *x_ptr = x + (size_t)t * inner_dim;
        float *out_ptr = out + (size_t)t * dim;

        for (int d = 0; d < dim; ++d) {
            float a = x_ptr[d];
            float b = x_ptr[dim + d];
            float a3 = a * a * a;
            float inner = sqrt_2_over_pi * (a + coeff * a3);
            float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
            out_ptr[d] = gelu_a * b;
        }
    }
#endif
}
796 
/**
 * GeGLU forward pass (bf16)
 * @test test_geglu.py::TestGeGLU::test_geglu_forward_bf16
 *
 * Converts BF16 -> FP32, runs the FP32 GeGLU kernel, converts back.
 * Caller provides scratch of 3 * tokens * dim floats:
 *   scratch[0 .. 2*tokens*dim)           FP32 input [a, b]
 *   scratch[2*tokens*dim .. 3*tokens*dim) FP32 output
 *
 * Input and output use separate scratch regions so the in/out rows cannot
 * overlap when tokens > 1 (input rows are 2*dim wide, output rows dim wide).
 *
 * After changes: make test
 */
void geglu_forward_bf16(const uint16_t *x, uint16_t *out, int tokens, int dim, float *scratch)
{
    if (x == NULL || out == NULL || scratch == NULL) {
        return;
    }

    const size_t out_elems = (size_t)tokens * (size_t)dim;
    const size_t in_elems = out_elems * 2; /* [a, b] = 2*dim floats per token */

    float *in_f32 = scratch;
    float *out_f32 = scratch + in_elems;

    // BF16 -> FP32, compute in FP32, FP32 -> BF16
    bf16_tensor_to_float(x, in_f32, in_elems);
    geglu_forward_fp32(in_f32, out_f32, tokens, dim);
    float_tensor_to_bf16(out_f32, out, out_elems);
}
831 
832 /**
833  * GeGLU backward pass (fp32)
834  * @test test_geglu.py::TestGeGLU::test_geglu_backward_fp32
835  *
836  * dL/dx given dL/d(out) where out = GELU(a) * b
837  * Chain rule:
838  * dL/da = dL/dout * d(GELU)/da * b
839  * dL/db = dL/dout * GELU(a)
840  *
841  * After changes: make test
842  */
843 void geglu_backward_fp32(const float *x,
844  const float *d_out,
845  float *d_x,
846  int tokens,
847  int dim)
848 {
849  const float sqrt_2_over_pi = 0.7978845608f;
850  const float coeff = 0.044715f;
851 
852  const int inner_dim = dim * 2;
853 
854  for (int t = 0; t < tokens; ++t) {
855  const float *x_ptr = x + (size_t)t * inner_dim;
856  const float *d_out_ptr = d_out + (size_t)t * dim;
857  float *d_x_ptr = d_x + (size_t)t * inner_dim;
858 
859  for (int d = 0; d < dim; ++d) {
860  float a = x_ptr[d];
861  float b = x_ptr[dim + d];
862  float dout = d_out_ptr[d];
863 
864  // GELU(a) derivative components
865  float a2 = a * a;
866  float a3 = a2 * a;
867  float g = sqrt_2_over_pi * (a + coeff * a3);
868  float tanh_g = tanhf(g);
869  float sech2_g = 1.0f - tanh_g * tanh_g;
870  float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * a2);
871 
872  // d(GELU)/da = 0.5 * (1 + tanh(g)) + 0.5 * a * sech^2(g) * g'
873  float d_gelu = 0.5f * (1.0f + tanh_g) + 0.5f * a * sech2_g * g_prime;
874 
875  // dL/da = dL/dout * d(GELU)/da * b
876  d_x_ptr[d] = dout * d_gelu * b;
877 
878  // dL/db = dL/dout * GELU(a)
879  float gelu_a = 0.5f * a * (1.0f + tanh_g);
880  d_x_ptr[dim + d] = dout * gelu_a;
881  }
882  }
883 }
884 
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
Definition: bf16_utils.h:271
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
Definition: bf16_utils.h:250
void gelu_backward_exact(const float *input, const float *d_output, float *d_input, size_t n)
Definition: gelu_kernels.c:257
void gelu_exact_inplace(float *data, size_t n)
Definition: gelu_kernels.c:446
void geglu_backward_fp32(const float *x, const float *d_out, float *d_x, int tokens, int dim)
Definition: gelu_kernels.c:843
void geglu_forward_fp32(const float *x, float *out, int tokens, int dim)
Definition: gelu_kernels.c:623
void gelu_fast_inplace(float *data, size_t n)
Definition: gelu_kernels.c:132
void geglu_forward_bf16(const uint16_t *x, uint16_t *out, int tokens, int dim, float *scratch)
Definition: gelu_kernels.c:813
void gelu_backward_scalar(const float *input, const float *d_output, float *d_input, size_t n)
Definition: gelu_kernels.c:462
void gelu_backward_fast(const float *input, const float *d_output, float *d_input, size_t n)
Definition: gelu_kernels.c:486
__attribute__((visibility("default"))) CKTokenizer *ck_tokenizer_create(CKTokenizerType type)