21 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
22 #include <immintrin.h>
30 #if defined(__AVX512F__)
31 __m512 vzero = _mm512_setzero_ps();
32 for (; i + 15 < n; i += 16) {
33 __m512 vx = _mm512_loadu_ps(input + i);
34 __m512 vy = _mm512_max_ps(vx, vzero);
35 _mm512_storeu_ps(output + i, vy);
37 #elif defined(__AVX2__) || defined(__AVX__)
38 __m256 vzero = _mm256_setzero_ps();
39 for (; i + 7 < n; i += 8) {
40 __m256 vx = _mm256_loadu_ps(input + i);
41 __m256 vy = _mm256_max_ps(vx, vzero);
42 _mm256_storeu_ps(output + i, vy);
49 output[i] = (x > 0.0f) ? x : 0.0f;
58 #if defined(__AVX512F__)
59 __m512 vzero = _mm512_setzero_ps();
60 for (; i + 15 < n; i += 16) {
61 __m512 vx = _mm512_loadu_ps(data + i);
62 __m512 vy = _mm512_max_ps(vx, vzero);
63 _mm512_storeu_ps(data + i, vy);
65 #elif defined(__AVX2__) || defined(__AVX__)
66 __m256 vzero = _mm256_setzero_ps();
67 for (; i + 7 < n; i += 8) {
68 __m256 vx = _mm256_loadu_ps(data + i);
69 __m256 vy = _mm256_max_ps(vx, vzero);
70 _mm256_storeu_ps(data + i, vy);
85 const float *d_output,
91 #if defined(__AVX512F__)
92 __m512 vzero = _mm512_setzero_ps();
93 for (; i + 15 < n; i += 16) {
94 __m512 vx = _mm512_loadu_ps(input + i);
95 __m512 vdy = _mm512_loadu_ps(d_output + i);
96 __mmask16
mask = _mm512_cmp_ps_mask(vx, vzero, _CMP_GT_OQ);
97 __m512 vdx = _mm512_maskz_mov_ps(
mask, vdy);
98 _mm512_storeu_ps(d_input + i, vdx);
100 #elif defined(__AVX2__) || defined(__AVX__)
101 __m256 vzero = _mm256_setzero_ps();
102 for (; i + 7 < n; i += 8) {
103 __m256 vx = _mm256_loadu_ps(input + i);
104 __m256 vdy = _mm256_loadu_ps(d_output + i);
106 __m256
mask = _mm256_cmp_ps(vx, vzero, _CMP_GT_OQ);
107 __m256 vdx = _mm256_and_ps(
mask, vdy);
108 _mm256_storeu_ps(d_input + i, vdx);
114 d_input[i] = (input[i] > 0.0f) ? d_output[i] : 0.0f;
void relu_backward(const float *input, const float *d_output, float *d_input, size_t n)
void relu_forward(const float *input, float *output, size_t n)
void relu_forward_inplace(float *data, size_t n)
int32_t int32_t int32_t int32_t int32_t mask