← Back to C-Kernel-Engine Docs Doxygen Source Documentation
swiglu_kernels.c
Go to the documentation of this file.
1 /**
2  * @file swiglu_kernels.c
3  * @brief SwiGLU activation kernels with SIMD (SSE/AVX/AVX512)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * SwiGLU: y = silu(gate) * up = (gate * sigmoid(gate)) * up
15  */
16 
17 #include "ckernel_engine.h"
18 #include <math.h>
19 #include <stddef.h>
20 
21 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
22 #include <immintrin.h>
23 #endif
24 
25 /* ========================================================================== */
26 /* Fast exp approximation for SIMD */
27 /* ========================================================================== */
28 
#if defined(__AVX512F__)
/* AVX-512 polynomial exp approximation.
 * Scheme: exp(v) = 2^(v*log2(e)); split into integer exponent n and
 * fraction f in [-0.5, 0.5], evaluate a degree-4 polynomial for 2^f,
 * then scale by 2^n built directly in the IEEE-754 exponent bits. */
static inline __m512 exp512_fast(__m512 v) {
    /* Clamp so neither 2^n construction nor the result overflows. */
    v = _mm512_max_ps(v, _mm512_set1_ps(-88.0f));
    v = _mm512_min_ps(v, _mm512_set1_ps(88.0f));

    /* t = v * log2(e), so exp(v) == 2^t. */
    __m512 t = _mm512_mul_ps(v, _mm512_set1_ps(1.4426950408889634f));

    /* n = round(t); f = t - n. */
    __m512 n = _mm512_roundscale_ps(t, _MM_FROUND_TO_NEAREST_INT);
    __m512 f = _mm512_sub_ps(t, n);

    /* Horner evaluation of 2^f ~= sum_k ln(2)^k/k! * f^k, k = 0..4. */
    __m512 p = _mm512_fmadd_ps(f, _mm512_set1_ps(0.009618129107628478f),
                               _mm512_set1_ps(0.05550410866482158f));
    p = _mm512_fmadd_ps(f, p, _mm512_set1_ps(0.2402265069591007f));
    p = _mm512_fmadd_ps(f, p, _mm512_set1_ps(0.6931471805599453f));
    p = _mm512_fmadd_ps(f, p, _mm512_set1_ps(1.0f));

    /* Build 2^n by placing (n + 127) into the float exponent field. */
    __m512i e = _mm512_cvtps_epi32(n);
    e = _mm512_slli_epi32(_mm512_add_epi32(e, _mm512_set1_epi32(127)), 23);

    return _mm512_mul_ps(p, _mm512_castsi512_ps(e));
}

/* AVX-512 logistic sigmoid: 1 / (1 + exp(-v)), via exp512_fast. */
static inline __m512 sigmoid512_fast(__m512 v) {
    const __m512 one = _mm512_set1_ps(1.0f);
    __m512 en = exp512_fast(_mm512_sub_ps(_mm512_setzero_ps(), v));
    return _mm512_div_ps(one, _mm512_add_ps(one, en));
}
#endif
73 
#if defined(__AVX2__)
/* AVX2 polynomial exp approximation — same scheme as the AVX-512 one:
 * exp(v) = 2^(v*log2(e)); round/remainder split, degree-4 polynomial for
 * the fractional part, exponent-bit trick for the integer part.
 * NOTE(review): uses _mm256_fmadd_ps, i.e. also needs the FMA extension;
 * assumes builds defining __AVX2__ enable FMA too — confirm build flags. */
static inline __m256 exp256_fast(__m256 v) {
    /* Clamp to a range where the exponent construction stays valid. */
    v = _mm256_max_ps(v, _mm256_set1_ps(-88.0f));
    v = _mm256_min_ps(v, _mm256_set1_ps(88.0f));

    /* t = v * log2(e), so exp(v) == 2^t. */
    __m256 t = _mm256_mul_ps(v, _mm256_set1_ps(1.4426950408889634f));

    /* n = round(t); f = t - n in [-0.5, 0.5]. */
    __m256 n = _mm256_round_ps(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m256 f = _mm256_sub_ps(t, n);

    /* Horner: 2^f ~= sum_k ln(2)^k/k! * f^k, k = 0..4. */
    __m256 p = _mm256_fmadd_ps(f, _mm256_set1_ps(0.009618129107628478f),
                               _mm256_set1_ps(0.05550410866482158f));
    p = _mm256_fmadd_ps(f, p, _mm256_set1_ps(0.2402265069591007f));
    p = _mm256_fmadd_ps(f, p, _mm256_set1_ps(0.6931471805599453f));
    p = _mm256_fmadd_ps(f, p, _mm256_set1_ps(1.0f));

    /* 2^n via (n + 127) << 23 in the IEEE-754 exponent field. */
    __m256i e = _mm256_cvtps_epi32(n);
    e = _mm256_slli_epi32(_mm256_add_epi32(e, _mm256_set1_epi32(127)), 23);

    return _mm256_mul_ps(p, _mm256_castsi256_ps(e));
}

/* AVX2 logistic sigmoid: 1 / (1 + exp(-v)), via exp256_fast. */
static inline __m256 sigmoid256_fast(__m256 v) {
    const __m256 one = _mm256_set1_ps(1.0f);
    __m256 en = exp256_fast(_mm256_sub_ps(_mm256_setzero_ps(), v));
    return _mm256_div_ps(one, _mm256_add_ps(one, en));
}
#endif
118 
/**
 * SwiGLU forward pass
 * @test test_swiglu.py::TestSwiGLUForward::test_forward_tokens
 * @test test_swiglu.py::TestSwiGLUForward::test_forward_single
 * @test test_mlp.py::TestMLPForward::test_swiglu_mlp
 * @test test_fused_swiglu_decode.py::TestFusedSwiGLUDecode::test_fused_swiglu_decode
 * @test test_parity.py::test_swiglu_parity
 *
 * Layout: each input row holds 2*dim floats [gate | up]; each output row
 * holds dim floats. y[d] = silu(gate[d]) * up[d], silu(x) = x * sigmoid(x).
 *
 * After changes: make test && make llamacpp-parity-full
 */
void swiglu_forward(const float *input,
                    float *output,
                    int tokens,
                    int dim)
{
    for (int t = 0; t < tokens; ++t) {
        /* x: [gate | up] halves; y: output row. */
        const float *x = input + (size_t)t * (2 * dim);
        float *y = output + (size_t)t * dim;
        int i = 0;

#if defined(__AVX512F__)
        /* 16 lanes per iteration. */
        for (; i + 16 <= dim; i += 16) {
            __m512 g = _mm512_loadu_ps(&x[i]);       /* gate */
            __m512 u = _mm512_loadu_ps(&x[dim + i]); /* up   */

            __m512 sg = sigmoid512_fast(g);
            __m512 act = _mm512_mul_ps(g, sg);       /* silu(g) */
            _mm512_storeu_ps(&y[i], _mm512_mul_ps(act, u));
        }
#elif defined(__AVX2__)
        /* 8 lanes per iteration. */
        for (; i + 8 <= dim; i += 8) {
            __m256 g = _mm256_loadu_ps(&x[i]);       /* gate */
            __m256 u = _mm256_loadu_ps(&x[dim + i]); /* up   */

            __m256 sg = sigmoid256_fast(g);
            __m256 act = _mm256_mul_ps(g, sg);       /* silu(g) */
            _mm256_storeu_ps(&y[i], _mm256_mul_ps(act, u));
        }
#elif defined(__AVX__)
        /* AVX1 has no integer ops for the fast exp: vectorize the
         * multiplies, compute the sigmoid lane-by-lane in scalar code. */
        float gbuf[8] __attribute__((aligned(32)));
        float sbuf[8] __attribute__((aligned(32)));

        for (; i + 8 <= dim; i += 8) {
            __m256 g = _mm256_loadu_ps(&x[i]);       /* gate */
            __m256 u = _mm256_loadu_ps(&x[dim + i]); /* up   */

            _mm256_store_ps(gbuf, g);
            for (int j = 0; j < 8; ++j) {
                sbuf[j] = sigmoid_scalar(gbuf[j]);
            }
            __m256 sg = _mm256_load_ps(sbuf);

            __m256 act = _mm256_mul_ps(g, sg);       /* silu(g) */
            _mm256_storeu_ps(&y[i], _mm256_mul_ps(act, u));
        }
#endif

        /* Scalar tail (also the whole row when no SIMD path is compiled). */
        for (; i < dim; ++i) {
            float g = x[i];
            float u = x[dim + i];

            float sg = sigmoid_scalar(g);
            y[i] = (g * sg) * u;                     /* silu(g) * u */
        }
    }
}
203 
/**
 * SwiGLU backward pass
 * @test test_swiglu.py::TestSwiGLUBackward::test_backward_tokens
 * @test test_swiglu.py::TestSwiGLUBackward::test_backward_single
 * @test test_parity.py::test_swiglu_backward_parity
 *
 * Given upstream gradient dY for y = silu(a) * b, where a is the gate half
 * and b the up half of each input row, computes both input gradients:
 *   dGate = dy * b * silu'(a),  silu'(a) = s + a*s*(1-s),  s = sigmoid(a)
 *   dUp   = dy * silu(a)
 * Layouts: input and d_input rows are 2*dim floats [gate | up];
 * d_output rows are dim floats.
 *
 * After changes: make test && make llamacpp-parity-full
 */
void swiglu_backward(const float *input,
                     const float *d_output,
                     float *d_input,
                     int tokens,
                     int dim)
{
    int T = tokens;
    int D = dim;

    for (int t = 0; t < T; ++t) {
        /* Row pointers: gate at row[0..D), up at row[D..2D); same split
         * for dx_row. dy_row is the D-wide upstream gradient. */
        const float *row = input + (size_t)t * (2 * D);
        const float *dy_row = d_output + (size_t)t * D;
        float *dx_row = d_input + (size_t)t * (2 * D);
        int d = 0;

#if defined(__AVX512F__)
        /* AVX-512: Process 16 floats at a time */
        __m512 one = _mm512_set1_ps(1.0f);
        for (; d + 16 <= D; d += 16) {
            __m512 a = _mm512_loadu_ps(&row[d]);      /* gate */
            __m512 b = _mm512_loadu_ps(&row[D + d]);  /* value */
            __m512 dy = _mm512_loadu_ps(&dy_row[d]);

            __m512 s = sigmoid512_fast(a);            /* sigmoid(a) */
            __m512 silu = _mm512_mul_ps(a, s);        /* silu(a) = a * s */
            __m512 s_prime = _mm512_mul_ps(s, _mm512_sub_ps(one, s)); /* s * (1 - s) */
            __m512 silu_prime = _mm512_fmadd_ps(a, s_prime, s);       /* s + a * s_prime */

            /* dA = dy * b * silu_prime */
            __m512 dA = _mm512_mul_ps(dy, _mm512_mul_ps(b, silu_prime));
            /* dB = dy * silu */
            __m512 dB = _mm512_mul_ps(dy, silu);

            _mm512_storeu_ps(&dx_row[d], dA);
            _mm512_storeu_ps(&dx_row[D + d], dB);
        }
#elif defined(__AVX2__)
        /* AVX2: Process 8 floats at a time */
        __m256 one = _mm256_set1_ps(1.0f);
        for (; d + 8 <= D; d += 8) {
            __m256 a = _mm256_loadu_ps(&row[d]);      /* gate */
            __m256 b = _mm256_loadu_ps(&row[D + d]);  /* value */
            __m256 dy = _mm256_loadu_ps(&dy_row[d]);

            __m256 s = sigmoid256_fast(a);            /* sigmoid(a) */
            __m256 silu = _mm256_mul_ps(a, s);        /* silu(a) = a * s */
            __m256 s_prime = _mm256_mul_ps(s, _mm256_sub_ps(one, s)); /* s * (1 - s) */
            __m256 silu_prime = _mm256_fmadd_ps(a, s_prime, s);       /* s + a * s_prime */

            /* dA = dy * b * silu_prime */
            __m256 dA = _mm256_mul_ps(dy, _mm256_mul_ps(b, silu_prime));
            /* dB = dy * silu */
            __m256 dB = _mm256_mul_ps(dy, silu);

            _mm256_storeu_ps(&dx_row[d], dA);
            _mm256_storeu_ps(&dx_row[D + d], dB);
        }
#elif defined(__AVX__)
        /* AVX1: Vectorize arithmetic, use scalar sigmoid (AVX1 lacks the
         * integer ops needed by the vector exp approximation). */
        __m256 one = _mm256_set1_ps(1.0f);
        float a_arr[8] __attribute__((aligned(32)));
        float s_arr[8] __attribute__((aligned(32)));

        for (; d + 8 <= D; d += 8) {
            __m256 a = _mm256_loadu_ps(&row[d]);      /* gate */
            __m256 b = _mm256_loadu_ps(&row[D + d]);  /* value */
            __m256 dy = _mm256_loadu_ps(&dy_row[d]);

            /* Compute sigmoid scalarly, one lane at a time. */
            _mm256_store_ps(a_arr, a);
            for (int j = 0; j < 8; ++j) {
                s_arr[j] = sigmoid_scalar(a_arr[j]);
            }
            __m256 s = _mm256_load_ps(s_arr);

            __m256 silu = _mm256_mul_ps(a, s);        /* silu(a) = a * s */
            __m256 s_prime = _mm256_mul_ps(s, _mm256_sub_ps(one, s)); /* s * (1 - s) */
            /* silu_prime = s + a * s_prime (no FMA in AVX1) */
            __m256 a_s_prime = _mm256_mul_ps(a, s_prime);
            __m256 silu_prime = _mm256_add_ps(s, a_s_prime);

            /* dA = dy * b * silu_prime */
            __m256 dA = _mm256_mul_ps(dy, _mm256_mul_ps(b, silu_prime));
            /* dB = dy * silu */
            __m256 dB = _mm256_mul_ps(dy, silu);

            _mm256_storeu_ps(&dx_row[d], dA);
            _mm256_storeu_ps(&dx_row[D + d], dB);
        }
#endif

        /* Scalar fallback for remaining elements */
        for (; d < D; ++d) {
            float a = row[d];      /* gate */
            float b = row[D + d];  /* value */
            float dy = dy_row[d];

            float s = sigmoid_scalar(a);        /* sigmoid(a) */
            float silu = a * s;                 /* silu(a) */
            float s_prime = s * (1.0f - s);     /* sigmoid'(a) */
            float silu_prime = s + a * s_prime; /* silu'(a) */

            float dA = dy * b * silu_prime;
            float dB = dy * silu;

            dx_row[d] = dA;
            dx_row[D + d] = dB;
        }
    }
}
325 
326 // ============================================================================
327 // Exact versions using standard library expf (slower but accurate)
328 // ============================================================================
329 
330 /**
331  * SwiGLU forward pass (exact version using stdlib sigmoid)
332  * @test test_swiglu.py::TestSwiGLUForward::test_exact_vs_fast
333  * @test test_swiglu.py::TestSwiGLUForward::test_exact_single
334  *
335  * Uses standard library expf for numerical accuracy reference.
336  *
337  * After changes: make test
338  */
339 void swiglu_forward_exact(const float *input,
340  float *output,
341  int tokens,
342  int dim)
343 {
344  int T = tokens;
345  int D = dim;
346 
347  for (int t = 0; t < T; ++t) {
348  const float *row = input + (size_t)t * (2 * D);
349  float *out_row = output + (size_t)t * D;
350 
351  for (int d = 0; d < D; ++d) {
352  float a = row[d]; // gate
353  float b = row[D + d]; // value
354 
355  // Use standard library expf via sigmoid_scalar
356  float s = sigmoid_scalar(a); // sigmoid(a) = 1/(1+expf(-a))
357  float silu = a * s; // silu(a) = a * sigmoid(a)
358 
359  out_row[d] = silu * b;
360  }
361  }
362 }
363 
364 /**
365  * SwiGLU backward pass (exact version using stdlib sigmoid)
366  * @test test_swiglu.py::TestSwiGLUBackward::test_exact_vs_fast
367  * @test test_swiglu.py::TestSwiGLUBackward::test_exact_single
368  *
369  * Uses standard library expf for numerical accuracy reference.
370  *
371  * After changes: make test
372  */
373 void swiglu_backward_exact(const float *input,
374  const float *d_output,
375  float *d_input,
376  int tokens,
377  int dim)
378 {
379  int T = tokens;
380  int D = dim;
381 
382  for (int t = 0; t < T; ++t) {
383  const float *row = input + (size_t)t * (2 * D);
384  const float *dy_row = d_output + (size_t)t * D;
385  float *dx_row = d_input + (size_t)t * (2 * D);
386 
387  for (int d = 0; d < D; ++d) {
388  float a = row[d]; // gate
389  float b = row[D + d]; // value
390  float dy = dy_row[d];
391 
392  // Use standard library expf via sigmoid_scalar
393  float s = sigmoid_scalar(a); // sigmoid(a)
394  float silu = a * s; // silu(a)
395  float s_prime = s * (1.0f - s); // sigmoid'(a)
396  float silu_prime = s + a * s_prime; // silu'(a)
397 
398  float dA = dy * b * silu_prime;
399  float dB = dy * silu;
400 
401  dx_row[d] = dA;
402  dx_row[D + d] = dB;
403  }
404  }
405 }
float sigmoid_scalar(float x)
void swiglu_forward_exact(const float *input, float *output, int tokens, int dim)
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void swiglu_backward(const float *input, const float *d_output, float *d_input, int tokens, int dim)
void swiglu_backward_exact(const float *input, const float *d_output, float *d_input, int tokens, int dim)
__attribute__((visibility("default"))) CKTokenizer *ck_tokenizer_create(CKTokenizerType type)
static void silu(float *x, int n)
Definition: v6_simple.c:159