21 #if defined(__AVX512F__)
22 #include <immintrin.h>
28 return 1.0f / (1.0f + expf(-x));
31 #if defined(__AVX512F__)
34 static inline __m512 exp_approx_avx512(__m512 x)
37 const __m512 max_val = _mm512_set1_ps(88.0f);
38 const __m512 min_val = _mm512_set1_ps(-88.0f);
39 x = _mm512_max_ps(_mm512_min_ps(x, max_val), min_val);
42 const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
43 __m512 z = _mm512_mul_ps(x, log2e);
46 __m512 zf = _mm512_roundscale_ps(z, _MM_FROUND_TO_NEAREST_INT);
47 __m512 f = _mm512_sub_ps(z, zf);
52 const __m512 c0 = _mm512_set1_ps(1.0f);
53 const __m512 c1 = _mm512_set1_ps(0.6931471805599453f);
54 const __m512 c2 = _mm512_set1_ps(0.2402265069591007f);
55 const __m512 c3 = _mm512_set1_ps(0.05550410866482158f);
56 const __m512 c4 = _mm512_set1_ps(0.009618129107628478f);
59 __m512 poly = _mm512_fmadd_ps(f, c4, c3);
60 poly = _mm512_fmadd_ps(f, poly, c2);
61 poly = _mm512_fmadd_ps(f, poly, c1);
62 poly = _mm512_fmadd_ps(f, poly, c0);
65 __m512i zi = _mm512_cvtps_epi32(zf);
66 zi = _mm512_add_epi32(zi, _mm512_set1_epi32(127));
67 zi = _mm512_slli_epi32(zi, 23);
68 __m512 scale = _mm512_castsi512_ps(zi);
70 return _mm512_mul_ps(poly, scale);
73 static inline __m512 sigmoid_avx512_vec(__m512 x)
75 __m512 neg = _mm512_sub_ps(_mm512_setzero_ps(), x);
76 __m512 exp_neg = exp_approx_avx512(neg);
77 __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg);
78 return _mm512_div_ps(_mm512_set1_ps(1.0f), denom);
/*
 * AVX-512 forward pass: runs sigmoid_avx512_vec over `input` 16 floats at a
 * time and stores the results to `output`.
 * NOTE(review): the rest of the parameter list, the declaration of `i`, and
 * whatever handles the final n % 16 elements are not visible in this excerpt.
 */
81 static void sigmoid_forward_avx512(
const float *input,
/* Full 16-float chunks only; the loop stops once fewer than 16 remain. */
86 for (; i + 16 <= n; i += 16) {
/* Unaligned load, so this loop imposes no alignment requirement on input. */
87 __m512 in_vec = _mm512_loadu_ps(input + i);
88 __m512 sig = sigmoid_avx512_vec(in_vec);
/* Unaligned store of the 16 results. */
89 _mm512_storeu_ps(output + i, sig);
/*
 * AVX-512 backward pass: d_input[i] = d_output[i] * s * (1 - s), where
 * s = sigmoid(input[i]) is recomputed from `input` inside the loop.
 * NOTE(review): the rest of the parameter list and the declaration of `i`
 * are not visible in this excerpt.
 */
97 static void sigmoid_backward_avx512(
const float *input,
98 const float *d_output,
/* Broadcast 1.0f once, outside the loop, for the (1 - s) term. */
102 const __m512 one = _mm512_set1_ps(1.0f);
/* Full 16-float chunks only; the loop stops once fewer than 16 remain. */
104 for (; i + 16 <= n; i += 16) {
105 __m512 in_vec = _mm512_loadu_ps(input + i);
106 __m512 s = sigmoid_avx512_vec(in_vec);
107 __m512 dout = _mm512_loadu_ps(d_output + i);
/* Sigmoid derivative s * (1 - s), scaled by the incoming gradient. */
108 __m512 grad = _mm512_mul_ps(_mm512_mul_ps(s, _mm512_sub_ps(one, s)), dout);
109 _mm512_storeu_ps(d_input + i, grad);
/*
 * Scalar form of the same formula, one element at a time.
 * NOTE(review): the enclosing loop header and the definition of the scalar
 * `s` are not visible here; presumably this is the tail loop for the last
 * n % 16 elements with s = sigmoid_scalar(input[i]) -- confirm.
 */
115 float s_prime = s * (1.0f - s);
116 d_input[i] = d_output[i] * s_prime;
/* When the build targets AVX-512F, hand the whole buffer to the vector kernel. */
126 #if defined(__AVX512F__)
127 sigmoid_forward_avx512(input, output, n);
/*
 * Scalar loop over all n elements.
 * NOTE(review): the #else that presumably separates this loop from the
 * AVX-512 call, the loop body, and the closing #endif are not visible in
 * this excerpt.
 */
129 for (
size_t i = 0; i < n; ++i) {
/* Gradient flowing in from the next layer; read-only. */
139 const float *d_output,
/* When the build targets AVX-512F, hand the whole buffer to the vector kernel. */
143 #if defined(__AVX512F__)
144 sigmoid_backward_avx512(input, d_output, d_input, n);
/*
 * Scalar loop over all n elements.
 * NOTE(review): the #else that presumably separates this loop from the
 * AVX-512 call, the definition of `s`, and the closing #endif are not
 * visible in this excerpt.
 */
146 for (
size_t i = 0; i < n; ++i) {
/* Sigmoid derivative s * (1 - s), scaled by the incoming gradient. */
149 float s_prime = s * (1.0f - s);
150 d_input[i] = d_output[i] * s_prime;
float sigmoid_scalar(float x)
void sigmoid_backward(const float *input, const float *d_output, float *d_input, size_t n)
void sigmoid_forward(const float *input, float *output, size_t n)