SwiGLU activation kernels with SIMD (SSE/AVX/AVX512)
Functions
| Return type | Function |
|---|---|
| void | swiglu_backward (const float *input, const float *d_output, float *d_input, int tokens, int dim) |
| void | swiglu_backward_exact (const float *input, const float *d_output, float *d_input, int tokens, int dim) |
| void | swiglu_forward (const float *input, float *output, int tokens, int dim) |
| void | swiglu_forward_exact (const float *input, float *output, int tokens, int dim) |
After making changes, run: make test && make llamacpp-parity-full
SwiGLU: y = silu(gate) * up = (gate * sigmoid(gate)) * up
Definition in file swiglu_kernels.c.
void swiglu_backward (const float *input, const float *d_output, float *d_input, int tokens, int dim)
SwiGLU backward pass
Tests:
test_swiglu.py::TestSwiGLUBackward::test_backward_tokens
test_swiglu.py::TestSwiGLUBackward::test_backward_single
test_parity.py::test_swiglu_backward_parity
Computes dGate and dUp from dY, where a is the gate value and b is the up value: dGate = dY * b * silu'(a) and dUp = dY * silu(a).
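The two gradient formulas follow from the product and chain rules. With a the gate value, b the up value, and σ the logistic sigmoid (so silu(a) = a σ(a)), the derivation is:

```latex
\begin{aligned}
y &= \operatorname{silu}(a)\, b = a\,\sigma(a)\, b,
  \qquad \sigma'(a) = \sigma(a)\bigl(1 - \sigma(a)\bigr) \\
\operatorname{silu}'(a) &= \sigma(a) + a\,\sigma(a)\bigl(1 - \sigma(a)\bigr)
  = \sigma(a)\bigl(1 + a\,(1 - \sigma(a))\bigr) \\
\frac{\partial L}{\partial a} &= \frac{\partial L}{\partial y}\, b\, \operatorname{silu}'(a),
  \qquad
\frac{\partial L}{\partial b} = \frac{\partial L}{\partial y}\, \operatorname{silu}(a)
\end{aligned}
```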
After making changes, run: make test && make llamacpp-parity-full
Definition at line 215 of file swiglu_kernels.c.
References sigmoid_scalar() and silu().
Referenced by ck_layer_backward_rmsnorm_swiglu().
void swiglu_backward_exact (const float *input, const float *d_output, float *d_input, int tokens, int dim)
SwiGLU backward pass (exact version using stdlib sigmoid)
Tests:
test_swiglu.py::TestSwiGLUBackward::test_exact_vs_fast
test_swiglu.py::TestSwiGLUBackward::test_exact_single
Uses the standard library expf() so that it can serve as a numerical accuracy reference.
After making changes, run: make test
Definition at line 373 of file swiglu_kernels.c.
References sigmoid_scalar() and silu().
void swiglu_forward (const float *input, float *output, int tokens, int dim)
SwiGLU forward pass
Tests:
test_swiglu.py::TestSwiGLUForward::test_forward_tokens
test_swiglu.py::TestSwiGLUForward::test_forward_single
test_mlp.py::TestMLPForward::test_swiglu_mlp
test_fused_swiglu_decode.py::TestFusedSwiGLUDecode::test_fused_swiglu_decode
test_parity.py::test_swiglu_parity
SwiGLU: y = silu(gate) * up where silu(x) = x * sigmoid(x)
After making changes, run: make test && make llamacpp-parity-full
Definition at line 131 of file swiglu_kernels.c.
References sigmoid_scalar() and silu().
Referenced by ck_mlp_swiglu_forward(), ck_mlp_swiglu_forward_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_mlp_swiglu_forward_quant(), ck_mlp_swiglu_forward_ref(), ck_test_swiglu(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), 
qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().
void swiglu_forward_exact (const float *input, float *output, int tokens, int dim)
SwiGLU forward pass (exact version using stdlib sigmoid)
Tests:
test_swiglu.py::TestSwiGLUForward::test_exact_vs_fast
test_swiglu.py::TestSwiGLUForward::test_exact_single
Uses the standard library expf() so that it can serve as a numerical accuracy reference.
After making changes, run: make test
Definition at line 339 of file swiglu_kernels.c.
References sigmoid_scalar() and silu().