21 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
22 #include <immintrin.h>
29 #if defined(__AVX512F__)
/*
 * exp512_fast: approximate expf() for 16 packed floats.
 *
 * Scheme: clamp x into [-88, 88] (roughly the finite range of expf for
 * float), write exp(x) = 2^z with z = x*log2(e), split z into an
 * integral part zf and fraction f with |f| <= 0.5, evaluate 2^f by a
 * degree-4 polynomial (coefficients are ln(2)^k / k!), and build 2^zf
 * directly in the float exponent field via the (zf + 127) << 23 bias
 * trick. Accuracy is "fast-math" grade, adequate for sigmoid/SiLU.
 *
 * NOTE(review): at the clamp boundary x = -88 the biased exponent hits
 * 0 and the result flushes to 0 instead of ~1e-38 — acceptable for the
 * sigmoid use below, where it yields exactly 1/(1+0).
 */
static inline __m512 exp512_fast(__m512 x) {
    /* Keep the exponent-bias trick below in range. */
    x = _mm512_max_ps(x, _mm512_set1_ps(-88.0f));
    x = _mm512_min_ps(x, _mm512_set1_ps(88.0f));

    /* z = x * log2(e), so exp(x) == 2^z. */
    const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
    __m512 z = _mm512_mul_ps(x, log2e);

    /* Split z = zf + f, zf integral, |f| <= 0.5.
     * Fix: OR in _MM_FROUND_NO_EXC so the imm8's SAE bit suppresses the
     * precision exception, matching the AVX2 twin exp256_fast. */
    __m512 zf = _mm512_roundscale_ps(z, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m512 f = _mm512_sub_ps(z, zf);

    /* 2^f ~= c0 + c1*f + c2*f^2 + c3*f^3 + c4*f^4 (ck = ln(2)^k / k!). */
    const __m512 c0 = _mm512_set1_ps(1.0f);
    const __m512 c1 = _mm512_set1_ps(0.6931471805599453f);
    const __m512 c2 = _mm512_set1_ps(0.2402265069591007f);
    const __m512 c3 = _mm512_set1_ps(0.05550410866482158f);
    const __m512 c4 = _mm512_set1_ps(0.009618129107628478f);

    /* Horner evaluation, innermost coefficient first. */
    __m512 poly = _mm512_fmadd_ps(f, c4, c3);
    poly = _mm512_fmadd_ps(f, poly, c2);
    poly = _mm512_fmadd_ps(f, poly, c1);
    poly = _mm512_fmadd_ps(f, poly, c0);

    /* 2^zf: place the biased integer exponent into bits [30:23]. */
    __m512i zi = _mm512_cvtps_epi32(zf);
    zi = _mm512_add_epi32(zi, _mm512_set1_epi32(127));
    zi = _mm512_slli_epi32(zi, 23);
    __m512 scale = _mm512_castsi512_ps(zi);

    return _mm512_mul_ps(poly, scale);
}
66 static inline __m512 sigmoid512_fast(__m512 x) {
67 __m512 neg_x = _mm512_sub_ps(_mm512_setzero_ps(), x);
68 __m512 exp_neg = exp512_fast(neg_x);
69 __m512 one = _mm512_set1_ps(1.0f);
70 return _mm512_div_ps(one, _mm512_add_ps(one, exp_neg));
/*
 * exp256_fast: approximate expf() for 8 packed floats.
 * Same scheme as the AVX-512 variant: exp(x) = 2^(x*log2 e), integer /
 * fraction split, degree-4 polynomial for the fractional power of two,
 * exponent-field bit trick for the integral power of two.
 */
static inline __m256 exp256_fast(__m256 x) {
    /* Clamp the argument into expf's usable float range. */
    x = _mm256_min_ps(_mm256_max_ps(x, _mm256_set1_ps(-88.0f)),
                      _mm256_set1_ps(88.0f));

    /* t = x * log2(e); exp(x) == 2^t. */
    __m256 t = _mm256_mul_ps(x, _mm256_set1_ps(1.4426950408889634f));

    /* Round-to-nearest split: t = t_int + frac, |frac| <= 0.5. */
    __m256 t_int = _mm256_round_ps(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m256 frac = _mm256_sub_ps(t, t_int);

    /* Horner polynomial for 2^frac; coefficients are ln(2)^k / k!. */
    __m256 p = _mm256_fmadd_ps(frac, _mm256_set1_ps(0.009618129107628478f),
                               _mm256_set1_ps(0.05550410866482158f));
    p = _mm256_fmadd_ps(frac, p, _mm256_set1_ps(0.2402265069591007f));
    p = _mm256_fmadd_ps(frac, p, _mm256_set1_ps(0.6931471805599453f));
    p = _mm256_fmadd_ps(frac, p, _mm256_set1_ps(1.0f));

    /* 2^t_int: bias the integer exponent and shift it into place. */
    __m256i biased = _mm256_add_epi32(_mm256_cvtps_epi32(t_int),
                                      _mm256_set1_epi32(127));
    __m256 pow2i = _mm256_castsi256_ps(_mm256_slli_epi32(biased, 23));

    return _mm256_mul_ps(p, pow2i);
}
111 static inline __m256 sigmoid256_fast(__m256 x) {
112 __m256 neg_x = _mm256_sub_ps(_mm256_setzero_ps(), x);
113 __m256 exp_neg = exp256_fast(neg_x);
114 __m256 one = _mm256_set1_ps(1.0f);
115 return _mm256_div_ps(one, _mm256_add_ps(one, exp_neg));
/*
 * swiglu_forward, vectorized interior (FRAGMENT).
 * Per token t the input row is laid out as [ a[0..D) | b[0..D) ]
 * (2*D floats); the output row is out[d] = silu(a[d]) * b[d]
 * = a * sigmoid(a) * b. Exactly one ISA branch compiles: AVX-512
 * (16 lanes), AVX2 (8 lanes), or plain AVX (8 lanes with a scalar
 * sigmoid round-trip through stack buffers), followed by a scalar
 * tail for the remaining d < D.
 *
 * NOTE(review): extraction dropped lines here — the function signature,
 * `int d = 0;`, the a_arr/s_arr buffers, loop-closing braces and the
 * scalar-tail `a`/`s`/`silu` locals are missing (gaps in the embedded
 * numbering confirm it). Do not edit this fragment without the full file.
 */
139 for (
int t = 0; t < T; ++t) {
140 const float *row = input + (size_t)t * (2 * D);
141 float *out_row = output + (size_t)t * D;
144 #if defined(__AVX512F__)
/* AVX-512: 16 elements per iteration. */
146 for (; d + 16 <= D; d += 16) {
147 __m512 a = _mm512_loadu_ps(&row[d]);
148 __m512 b = _mm512_loadu_ps(&row[D + d]);
150 __m512 s = sigmoid512_fast(a);
151 __m512
silu = _mm512_mul_ps(a, s);
152 __m512 y = _mm512_mul_ps(
silu, b);
154 _mm512_storeu_ps(&out_row[d], y);
156 #elif defined(__AVX2__)
/* AVX2+FMA: 8 elements per iteration. */
158 for (; d + 8 <= D; d += 8) {
159 __m256 a = _mm256_loadu_ps(&row[d]);
160 __m256 b = _mm256_loadu_ps(&row[D + d]);
162 __m256 s = sigmoid256_fast(a);
163 __m256
silu = _mm256_mul_ps(a, s);
164 __m256 y = _mm256_mul_ps(
silu, b);
166 _mm256_storeu_ps(&out_row[d], y);
168 #elif defined(__AVX__)
/* Plain AVX (no vector sigmoid here): spill a to a_arr, compute
 * sigmoid scalar-wise into s_arr (dropped lines), reload. */
173 for (; d + 8 <= D; d += 8) {
174 __m256 a = _mm256_loadu_ps(&row[d]);
175 __m256 b = _mm256_loadu_ps(&row[D + d]);
178 _mm256_store_ps(a_arr, a);
179 for (
int j = 0; j < 8; ++j) {
182 __m256 s = _mm256_load_ps(s_arr);
184 __m256
silu = _mm256_mul_ps(a, s);
185 __m256 y = _mm256_mul_ps(
silu, b);
187 _mm256_storeu_ps(&out_row[d], y);
/* Scalar tail: remaining d < D (the `a`, `s`, `silu` locals are in
 * the dropped lines). */
194 float b = row[D + d];
199 out_row[d] =
silu * b;
/*
 * swiglu_backward, vectorized interior (FRAGMENT).
 * Given y = silu(a) * b with s = sigmoid(a):
 *   d/da silu(a) = s + a*s*(1-s)            (silu_prime below)
 *   dA = dy * b * silu_prime  -> stored at dx_row[d]
 *   dB = dy * silu(a)         -> stored at dx_row[D + d]
 * so d_input keeps the same [a | b] row layout as input.
 *
 * NOTE(review): extraction dropped lines — the signature head, `int d
 * = 0;`, a_arr/s_arr buffers, loop closers and scalar-tail `a`/`s`/
 * `silu` locals are missing (numbering gaps). Treat as read-only
 * without the full file.
 */
216 const float *d_output,
224 for (
int t = 0; t < T; ++t) {
225 const float *row = input + (size_t)t * (2 * D);
226 const float *dy_row = d_output + (size_t)t * D;
227 float *dx_row = d_input + (size_t)t * (2 * D);
230 #if defined(__AVX512F__)
232 __m512 one = _mm512_set1_ps(1.0f);
233 for (; d + 16 <= D; d += 16) {
234 __m512 a = _mm512_loadu_ps(&row[d]);
235 __m512 b = _mm512_loadu_ps(&row[D + d]);
236 __m512 dy = _mm512_loadu_ps(&dy_row[d]);
238 __m512 s = sigmoid512_fast(a);
239 __m512
silu = _mm512_mul_ps(a, s);
/* sigmoid'(a) = s*(1-s); silu'(a) = s + a*sigmoid'(a). */
240 __m512 s_prime = _mm512_mul_ps(s, _mm512_sub_ps(one, s));
241 __m512 silu_prime = _mm512_fmadd_ps(a, s_prime, s);
244 __m512 dA = _mm512_mul_ps(dy, _mm512_mul_ps(b, silu_prime));
246 __m512 dB = _mm512_mul_ps(dy,
silu);
248 _mm512_storeu_ps(&dx_row[d], dA);
249 _mm512_storeu_ps(&dx_row[D + d], dB);
251 #elif defined(__AVX2__)
253 __m256 one = _mm256_set1_ps(1.0f);
254 for (; d + 8 <= D; d += 8) {
255 __m256 a = _mm256_loadu_ps(&row[d]);
256 __m256 b = _mm256_loadu_ps(&row[D + d]);
257 __m256 dy = _mm256_loadu_ps(&dy_row[d]);
259 __m256 s = sigmoid256_fast(a);
260 __m256
silu = _mm256_mul_ps(a, s);
261 __m256 s_prime = _mm256_mul_ps(s, _mm256_sub_ps(one, s));
262 __m256 silu_prime = _mm256_fmadd_ps(a, s_prime, s);
265 __m256 dA = _mm256_mul_ps(dy, _mm256_mul_ps(b, silu_prime));
267 __m256 dB = _mm256_mul_ps(dy,
silu);
269 _mm256_storeu_ps(&dx_row[d], dA);
270 _mm256_storeu_ps(&dx_row[D + d], dB);
272 #elif defined(__AVX__)
/* Plain AVX path: no FMA, so silu_prime is built with mul + add;
 * sigmoid goes through the scalar spill buffers (dropped lines). */
274 __m256 one = _mm256_set1_ps(1.0f);
278 for (; d + 8 <= D; d += 8) {
279 __m256 a = _mm256_loadu_ps(&row[d]);
280 __m256 b = _mm256_loadu_ps(&row[D + d]);
281 __m256 dy = _mm256_loadu_ps(&dy_row[d]);
284 _mm256_store_ps(a_arr, a);
285 for (
int j = 0; j < 8; ++j) {
288 __m256 s = _mm256_load_ps(s_arr);
290 __m256
silu = _mm256_mul_ps(a, s);
291 __m256 s_prime = _mm256_mul_ps(s, _mm256_sub_ps(one, s));
293 __m256 a_s_prime = _mm256_mul_ps(a, s_prime);
294 __m256 silu_prime = _mm256_add_ps(s, a_s_prime);
297 __m256 dA = _mm256_mul_ps(dy, _mm256_mul_ps(b, silu_prime));
299 __m256 dB = _mm256_mul_ps(dy,
silu);
301 _mm256_storeu_ps(&dx_row[d], dA);
302 _mm256_storeu_ps(&dx_row[D + d], dB);
/* Scalar tail for d < D (locals `a`, `s`, `silu` in dropped lines). */
309 float b = row[D + d];
310 float dy = dy_row[d];
314 float s_prime = s * (1.0f - s);
315 float silu_prime = s + a * s_prime;
317 float dA = dy * b * silu_prime;
318 float dB = dy *
silu;
/*
 * swiglu_forward_exact, scalar interior (FRAGMENT).
 * Reference (non-SIMD) forward pass: out[d] = silu(a[d]) * b[d],
 * presumably via the scalar sigmoid — the `float a = row[d];` and
 * `silu` computation lines were dropped by extraction (numbering gaps).
 */
347 for (
int t = 0; t < T; ++t) {
348 const float *row = input + (size_t)t * (2 * D);
349 float *out_row = output + (size_t)t * D;
351 for (
int d = 0; d < D; ++d) {
353 float b = row[D + d];
359 out_row[d] =
silu * b;
/*
 * swiglu_backward_exact, scalar interior (FRAGMENT).
 * Reference backward pass mirroring the SIMD version:
 *   dA = dy * b * (s + a*s*(1-s)),  dB = dy * silu(a),
 * written back into the [a | b] layout of dx_row.
 * NOTE(review): the `float a = ...`, `float s = ...`, `float silu =
 * ...` lines and the stores into dx_row were dropped by extraction.
 */
374 const float *d_output,
382 for (
int t = 0; t < T; ++t) {
383 const float *row = input + (size_t)t * (2 * D);
384 const float *dy_row = d_output + (size_t)t * D;
385 float *dx_row = d_input + (size_t)t * (2 * D);
387 for (
int d = 0; d < D; ++d) {
389 float b = row[D + d];
390 float dy = dy_row[d];
/* sigmoid'(a) = s*(1-s); silu'(a) = s + a*sigmoid'(a). */
395 float s_prime = s * (1.0f - s);
396 float silu_prime = s + a * s_prime;
398 float dA = dy * b * silu_prime;
399 float dB = dy *
silu;
float sigmoid_scalar(float x)
void swiglu_forward_exact(const float *input, float *output, int tokens, int dim)
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void swiglu_backward(const float *input, const float *d_output, float *d_input, int tokens, int dim)
void swiglu_backward_exact(const float *input, const float *d_output, float *d_input, int tokens, int dim)
__attribute__((visibility("default"))) CKTokenizer *ck_tokenizer_create(CKTokenizerType type)
static void silu(float *x, int n)