23 #if defined(__AVX512F__)
24 #include <immintrin.h>
/*
 * exp512_fast_bf16: vectorized approximation of e^x for 16 float lanes.
 *
 * Method: e^x == 2^(x * log2(e)). The scaled exponent z is split into an
 * integer part zf (rounded to nearest) and a fraction f = z - zf in about
 * [-0.5, 0.5]. 2^f is approximated with a degree-4 polynomial, 2^zf is
 * built directly from float exponent bits, and the two are multiplied.
 *
 * The input is clamped to [-88, 88] first so the constructed exponent
 * stays in the finite float range (at x = 88, zf = 127 -> 2^127).
 * At the low end the biased exponent reaches 0 once x is below about
 * -87.7, so those lanes return 0.0f rather than a subnormal value.
 *
 * NOTE(review): _mm512_max_ps returns its second operand when the first is
 * NaN, so a NaN lane appears to be mapped to the -88 clamp (result ~0)
 * instead of propagating NaN -- confirm this is acceptable for callers.
 * The "bf16" suffix presumably means accuracy is tuned for bf16 consumers
 * rather than full fp32 precision -- TODO confirm.
 */
27 static inline __m512 exp512_fast_bf16(__m512 x) {
// Clamp so that 2^zf below fits in the float exponent field.
29 x = _mm512_max_ps(x, _mm512_set1_ps(-88.0f));
30 x = _mm512_min_ps(x, _mm512_set1_ps(88.0f));
// z = x * log2(e), so e^x == 2^z.
32 const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
33 __m512 z = _mm512_mul_ps(x, log2e);
// Split z into integer part zf (nearest integer) and fractional part f.
34 __m512 zf = _mm512_roundscale_ps(z, _MM_FROUND_TO_NEAREST_INT);
35 __m512 f = _mm512_sub_ps(z, zf);
// c_k = ln(2)^k / k!: truncated Taylor series of 2^f = e^(f * ln 2).
38 const __m512 c0 = _mm512_set1_ps(1.0f);
39 const __m512 c1 = _mm512_set1_ps(0.6931471805599453f);
40 const __m512 c2 = _mm512_set1_ps(0.2402265069591007f);
41 const __m512 c3 = _mm512_set1_ps(0.05550410866482158f);
42 const __m512 c4 = _mm512_set1_ps(0.009618129107628478f);
// Horner evaluation: poly = (((c4*f + c3)*f + c2)*f + c1)*f + c0 ~= 2^f.
44 __m512 poly = _mm512_fmadd_ps(f, c4, c3);
45 poly = _mm512_fmadd_ps(f, poly, c2);
46 poly = _mm512_fmadd_ps(f, poly, c1);
47 poly = _mm512_fmadd_ps(f, poly, c0);
// Build 2^zf: zf already holds an integer value, so the conversion is exact;
// add the IEEE-754 single-precision bias (127) and shift it into the exponent
// field (bits 23..30), leaving a zero mantissa.
49 __m512i zi = _mm512_cvtps_epi32(zf);
50 zi = _mm512_add_epi32(zi, _mm512_set1_epi32(127));
51 zi = _mm512_slli_epi32(zi, 23);
52 __m512 scale = _mm512_castsi512_ps(zi);
// e^x ~= 2^f * 2^zf.
54 return _mm512_mul_ps(poly, scale);
/*
 * sigmoid512_fast_bf16: logistic sigmoid 1 / (1 + e^-x) for 16 float lanes,
 * with the exponential supplied by exp512_fast_bf16.
 *
 * exp512_fast_bf16 clamps its argument to [-88, 88], so e^-x stays finite
 * (at most about 1.65e38) and the denominator never becomes infinity.
 * The final step is a full-precision vector divide, not a reciprocal estimate.
 */
58 static inline __m512 sigmoid512_fast_bf16(__m512 x) {
// -x computed as 0 - x.
59 __m512 neg_x = _mm512_sub_ps(_mm512_setzero_ps(), x);
60 __m512 exp_neg = exp512_fast_bf16(neg_x);
61 __m512 one = _mm512_set1_ps(1.0f);
// sigmoid(x) = 1 / (1 + e^-x)
62 return _mm512_div_ps(one, _mm512_add_ps(one, exp_neg));
71 if (!input || !output || tokens <= 0 || dim <= 0) {
78 for (
int t = 0; t < T; ++t) {
79 const uint16_t *row = input + (size_t)t * (
size_t)(2 * D);
80 uint16_t *out_row = output + (size_t)t * (
size_t)D;
83 #if defined(__AVX512F__)
85 for (; d + 16 <= D; d += 16) {
86 __m512 a = bf16_loadu_cvt_fp32(&row[d]);
87 __m512 b = bf16_loadu_cvt_fp32(&row[D + d]);
89 __m512 s = sigmoid512_fast_bf16(a);
90 __m512
silu = _mm512_mul_ps(a, s);
91 __m512 y = _mm512_mul_ps(
silu, b);
93 fp32_cvt_storeu_bf16(&out_row[d], y);
109 const uint16_t *d_output,
114 if (!input || !d_output || !d_input || tokens <= 0 || dim <= 0) {
118 const int T = tokens;
121 for (
int t = 0; t < T; ++t) {
122 const uint16_t *row = input + (size_t)t * (
size_t)(2 * D);
123 const uint16_t *dy_row = d_output + (size_t)t * (
size_t)D;
124 uint16_t *dx_row = d_input + (size_t)t * (
size_t)(2 * D);
127 #if defined(__AVX512F__)
129 __m512 one = _mm512_set1_ps(1.0f);
130 for (; d + 16 <= D; d += 16) {
131 __m512 a = bf16_loadu_cvt_fp32(&row[d]);
132 __m512 b = bf16_loadu_cvt_fp32(&row[D + d]);
133 __m512 dy = bf16_loadu_cvt_fp32(&dy_row[d]);
135 __m512 s = sigmoid512_fast_bf16(a);
136 __m512
silu = _mm512_mul_ps(a, s);
137 __m512 s_prime = _mm512_mul_ps(s, _mm512_sub_ps(one, s));
138 __m512 silu_prime = _mm512_fmadd_ps(a, s_prime, s);
141 __m512 dA = _mm512_mul_ps(dy, _mm512_mul_ps(b, silu_prime));
143 __m512 dB = _mm512_mul_ps(dy,
silu);
145 fp32_cvt_storeu_bf16(&dx_row[d], dA);
146 fp32_cvt_storeu_bf16(&dx_row[D + d], dB);
158 float s_prime = s * (1.0f - s);
159 float silu_prime = s + a * s_prime;
161 float dA = dy * b * silu_prime;
162 float dB = dy *
silu;
static uint16_t float_to_bf16(float f)
static float bf16_to_float(uint16_t v)
float sigmoid_scalar(float x)
void swiglu_forward_bf16(const uint16_t *input, uint16_t *output, int tokens, int dim)
void swiglu_backward_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, int tokens, int dim)
static void silu(float *x, int n)