← Back to C-Kernel-Engine Docs Doxygen Source Documentation
swiglu_kernels.c File Reference

SwiGLU activation kernels with SIMD (SSE/AVX/AVX512) More...

#include "ckernel_engine.h"
#include <math.h>
#include <stddef.h>

Go to the source code of this file.

Functions

void swiglu_backward (const float *input, const float *d_output, float *d_input, int tokens, int dim)
 
void swiglu_backward_exact (const float *input, const float *d_output, float *d_input, int tokens, int dim)
 
void swiglu_forward (const float *input, float *output, int tokens, int dim)
 
void swiglu_forward_exact (const float *input, float *output, int tokens, int dim)
 

Detailed Description

SwiGLU activation kernels with SIMD (SSE/AVX/AVX512)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

SwiGLU: y = silu(gate) * up = (gate * sigmoid(gate)) * up

Definition in file swiglu_kernels.c.

Function Documentation

◆ swiglu_backward()

void swiglu_backward ( const float *  input,
const float *  d_output,
float *  d_input,
int  tokens,
int  dim 
)

SwiGLU backward pass

Test:

test_swiglu.py::TestSwiGLUBackward::test_backward_tokens

test_swiglu.py::TestSwiGLUBackward::test_backward_single

test_parity.py::test_swiglu_backward_parity

Computes dGate and dUp given dY. dGate = dy * b * silu'(a), dUp = dy * silu(a)

After changes: make test && make llamacpp-parity-full

Definition at line 215 of file swiglu_kernels.c.

220 {
221  int T = tokens;
222  int D = dim;
223 
224  for (int t = 0; t < T; ++t) {
225  const float *row = input + (size_t)t * (2 * D);
226  const float *dy_row = d_output + (size_t)t * D;
227  float *dx_row = d_input + (size_t)t * (2 * D);
228  int d = 0;
229 
230 #if defined(__AVX512F__)
231  // AVX-512: Process 16 floats at a time
232  __m512 one = _mm512_set1_ps(1.0f);
233  for (; d + 16 <= D; d += 16) {
234  __m512 a = _mm512_loadu_ps(&row[d]); // gate
235  __m512 b = _mm512_loadu_ps(&row[D + d]); // value
236  __m512 dy = _mm512_loadu_ps(&dy_row[d]);
237 
238  __m512 s = sigmoid512_fast(a); // sigmoid(a)
239  __m512 silu = _mm512_mul_ps(a, s); // silu(a) = a * s
240  __m512 s_prime = _mm512_mul_ps(s, _mm512_sub_ps(one, s)); // s * (1 - s)
241  __m512 silu_prime = _mm512_fmadd_ps(a, s_prime, s); // s + a * s_prime
242 
243  // dA = dy * b * silu_prime
244  __m512 dA = _mm512_mul_ps(dy, _mm512_mul_ps(b, silu_prime));
245  // dB = dy * silu
246  __m512 dB = _mm512_mul_ps(dy, silu);
247 
248  _mm512_storeu_ps(&dx_row[d], dA);
249  _mm512_storeu_ps(&dx_row[D + d], dB);
250  }
251 #elif defined(__AVX2__)
252  // AVX2: Process 8 floats at a time
253  __m256 one = _mm256_set1_ps(1.0f);
254  for (; d + 8 <= D; d += 8) {
255  __m256 a = _mm256_loadu_ps(&row[d]); // gate
256  __m256 b = _mm256_loadu_ps(&row[D + d]); // value
257  __m256 dy = _mm256_loadu_ps(&dy_row[d]);
258 
259  __m256 s = sigmoid256_fast(a); // sigmoid(a)
260  __m256 silu = _mm256_mul_ps(a, s); // silu(a) = a * s
261  __m256 s_prime = _mm256_mul_ps(s, _mm256_sub_ps(one, s)); // s * (1 - s)
262  __m256 silu_prime = _mm256_fmadd_ps(a, s_prime, s); // s + a * s_prime
263 
264  // dA = dy * b * silu_prime
265  __m256 dA = _mm256_mul_ps(dy, _mm256_mul_ps(b, silu_prime));
266  // dB = dy * silu
267  __m256 dB = _mm256_mul_ps(dy, silu);
268 
269  _mm256_storeu_ps(&dx_row[d], dA);
270  _mm256_storeu_ps(&dx_row[D + d], dB);
271  }
272 #elif defined(__AVX__)
273  // AVX1: Vectorize arithmetic, use scalar sigmoid
274  __m256 one = _mm256_set1_ps(1.0f);
275  float a_arr[8] __attribute__((aligned(32)));
276  float s_arr[8] __attribute__((aligned(32)));
277 
278  for (; d + 8 <= D; d += 8) {
279  __m256 a = _mm256_loadu_ps(&row[d]); // gate
280  __m256 b = _mm256_loadu_ps(&row[D + d]); // value
281  __m256 dy = _mm256_loadu_ps(&dy_row[d]);
282 
283  // Compute sigmoid scalarly
284  _mm256_store_ps(a_arr, a);
285  for (int j = 0; j < 8; ++j) {
286  s_arr[j] = sigmoid_scalar(a_arr[j]);
287  }
288  __m256 s = _mm256_load_ps(s_arr);
289 
290  __m256 silu = _mm256_mul_ps(a, s); // silu(a) = a * s
291  __m256 s_prime = _mm256_mul_ps(s, _mm256_sub_ps(one, s)); // s * (1 - s)
292  // silu_prime = s + a * s_prime (no FMA in AVX1)
293  __m256 a_s_prime = _mm256_mul_ps(a, s_prime);
294  __m256 silu_prime = _mm256_add_ps(s, a_s_prime);
295 
296  // dA = dy * b * silu_prime
297  __m256 dA = _mm256_mul_ps(dy, _mm256_mul_ps(b, silu_prime));
298  // dB = dy * silu
299  __m256 dB = _mm256_mul_ps(dy, silu);
300 
301  _mm256_storeu_ps(&dx_row[d], dA);
302  _mm256_storeu_ps(&dx_row[D + d], dB);
303  }
304 #endif
305 
306  // Scalar fallback for remaining elements
307  for (; d < D; ++d) {
308  float a = row[d]; // gate
309  float b = row[D + d]; // value
310  float dy = dy_row[d];
311 
312  float s = sigmoid_scalar(a); // sigmoid(a)
313  float silu = a * s; // silu(a)
314  float s_prime = s * (1.0f - s); // sigmoid'(a)
315  float silu_prime = s + a * s_prime; // silu'(a)
316 
317  float dA = dy * b * silu_prime;
318  float dB = dy * silu;
319 
320  dx_row[d] = dA;
321  dx_row[D + d] = dB;
322  }
323  }
324 }
float sigmoid_scalar(float x)

References sigmoid_scalar(). (Note: "silu" in this listing is a local variable, not a call to the static silu() helper; the __attribute__ entry produced by the cross-referencer is an artifact of the aligned-array declarations.)

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ swiglu_backward_exact()

void swiglu_backward_exact ( const float *  input,
const float *  d_output,
float *  d_input,
int  tokens,
int  dim 
)

SwiGLU backward pass (exact version using stdlib sigmoid)

Test:

test_swiglu.py::TestSwiGLUBackward::test_exact_vs_fast

test_swiglu.py::TestSwiGLUBackward::test_exact_single

Uses standard library expf for numerical accuracy reference.

After changes: make test

Definition at line 373 of file swiglu_kernels.c.

378 {
379  int T = tokens;
380  int D = dim;
381 
382  for (int t = 0; t < T; ++t) {
383  const float *row = input + (size_t)t * (2 * D);
384  const float *dy_row = d_output + (size_t)t * D;
385  float *dx_row = d_input + (size_t)t * (2 * D);
386 
387  for (int d = 0; d < D; ++d) {
388  float a = row[d]; // gate
389  float b = row[D + d]; // value
390  float dy = dy_row[d];
391 
392  // Use standard library expf via sigmoid_scalar
393  float s = sigmoid_scalar(a); // sigmoid(a)
394  float silu = a * s; // silu(a)
395  float s_prime = s * (1.0f - s); // sigmoid'(a)
396  float silu_prime = s + a * s_prime; // silu'(a)
397 
398  float dA = dy * b * silu_prime;
399  float dB = dy * silu;
400 
401  dx_row[d] = dA;
402  dx_row[D + d] = dB;
403  }
404  }
405 }

References sigmoid_scalar().

◆ swiglu_forward()

void swiglu_forward ( const float *  input,
float *  output,
int  tokens,
int  dim 
)

SwiGLU forward pass

Test:

test_swiglu.py::TestSwiGLUForward::test_forward_tokens

test_swiglu.py::TestSwiGLUForward::test_forward_single

test_mlp.py::TestMLPForward::test_swiglu_mlp

test_fused_swiglu_decode.py::TestFusedSwiGLUDecode::test_fused_swiglu_decode

test_parity.py::test_swiglu_parity

SwiGLU: y = silu(gate) * up where silu(x) = x * sigmoid(x)

After changes: make test && make llamacpp-parity-full

Definition at line 131 of file swiglu_kernels.c.

135 {
136  int T = tokens;
137  int D = dim;
138 
139  for (int t = 0; t < T; ++t) {
140  const float *row = input + (size_t)t * (2 * D);
141  float *out_row = output + (size_t)t * D;
142  int d = 0;
143 
144 #if defined(__AVX512F__)
145  // AVX-512: Process 16 floats at a time
146  for (; d + 16 <= D; d += 16) {
147  __m512 a = _mm512_loadu_ps(&row[d]); // gate
148  __m512 b = _mm512_loadu_ps(&row[D + d]); // value
149 
150  __m512 s = sigmoid512_fast(a); // sigmoid(a)
151  __m512 silu = _mm512_mul_ps(a, s); // silu(a) = a * sigmoid(a)
152  __m512 y = _mm512_mul_ps(silu, b); // y = silu(a) * b
153 
154  _mm512_storeu_ps(&out_row[d], y);
155  }
156 #elif defined(__AVX2__)
157  // AVX2: Process 8 floats at a time
158  for (; d + 8 <= D; d += 8) {
159  __m256 a = _mm256_loadu_ps(&row[d]); // gate
160  __m256 b = _mm256_loadu_ps(&row[D + d]); // value
161 
162  __m256 s = sigmoid256_fast(a); // sigmoid(a)
163  __m256 silu = _mm256_mul_ps(a, s); // silu(a) = a * sigmoid(a)
164  __m256 y = _mm256_mul_ps(silu, b); // y = silu(a) * b
165 
166  _mm256_storeu_ps(&out_row[d], y);
167  }
168 #elif defined(__AVX__)
169  // AVX1: Vectorize arithmetic, use scalar sigmoid
170  float a_arr[8] __attribute__((aligned(32)));
171  float s_arr[8] __attribute__((aligned(32)));
172 
173  for (; d + 8 <= D; d += 8) {
174  __m256 a = _mm256_loadu_ps(&row[d]); // gate
175  __m256 b = _mm256_loadu_ps(&row[D + d]); // value
176 
177  // Compute sigmoid scalarly
178  _mm256_store_ps(a_arr, a);
179  for (int j = 0; j < 8; ++j) {
180  s_arr[j] = sigmoid_scalar(a_arr[j]);
181  }
182  __m256 s = _mm256_load_ps(s_arr);
183 
184  __m256 silu = _mm256_mul_ps(a, s); // silu(a) = a * sigmoid(a)
185  __m256 y = _mm256_mul_ps(silu, b); // y = silu(a) * b
186 
187  _mm256_storeu_ps(&out_row[d], y);
188  }
189 #endif
190 
191  // Scalar fallback for remaining elements
192  for (; d < D; ++d) {
193  float a = row[d]; // gate
194  float b = row[D + d]; // value
195 
196  float s = sigmoid_scalar(a); // sigmoid(a)
197  float silu = a * s; // silu(a) = a * sigmoid(a)
198 
199  out_row[d] = silu * b;
200  }
201  }
202 }

References sigmoid_scalar(). ("silu" in this listing is a local variable, not a call to the static silu() helper.)

Referenced by ck_mlp_swiglu_forward(), ck_mlp_swiglu_forward_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_mlp_swiglu_forward_quant(), ck_mlp_swiglu_forward_ref(), ck_test_swiglu(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), 
qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().

◆ swiglu_forward_exact()

void swiglu_forward_exact ( const float *  input,
float *  output,
int  tokens,
int  dim 
)

SwiGLU forward pass (exact version using stdlib sigmoid)

Test:

test_swiglu.py::TestSwiGLUForward::test_exact_vs_fast

test_swiglu.py::TestSwiGLUForward::test_exact_single

Uses standard library expf for numerical accuracy reference.

After changes: make test

Definition at line 339 of file swiglu_kernels.c.

343 {
344  int T = tokens;
345  int D = dim;
346 
347  for (int t = 0; t < T; ++t) {
348  const float *row = input + (size_t)t * (2 * D);
349  float *out_row = output + (size_t)t * D;
350 
351  for (int d = 0; d < D; ++d) {
352  float a = row[d]; // gate
353  float b = row[D + d]; // value
354 
355  // Use standard library expf via sigmoid_scalar
356  float s = sigmoid_scalar(a); // sigmoid(a) = 1/(1+expf(-a))
357  float silu = a * s; // silu(a) = a * sigmoid(a)
358 
359  out_row[d] = silu * b;
360  }
361  }
362 }

References sigmoid_scalar().