ReLU activation kernels with SIMD (SSE/AVX/AVX512)
More...
#include <math.h>
#include <stddef.h>
#include <stdint.h>
Go to the source code of this file.
ReLU activation kernels with SIMD (SSE/AVX/AVX512)
CK-ENGINE KERNEL RULES:
- NO malloc/free - memory via bump allocator, pointers passed in
- NO OpenMP - parallelization at orchestrator/codegen layer
- API must define: inputs, outputs, workspace, and memory layouts
- Pure computation - deterministic, no side effects
After changes: make test && make llamacpp-parity-full
ReLU: y = max(0, x)
Definition in file relu_kernels.c.
◆ relu_backward()
void relu_backward(const float *input, const float *d_output, float *d_input, size_t n)
Definition at line 84 of file relu_kernels.c.
91 #if defined(__AVX512F__)
92 __m512 vzero = _mm512_setzero_ps();
93 for (; i + 15 < n; i += 16) {
94 __m512 vx = _mm512_loadu_ps(input + i);
95 __m512 vdy = _mm512_loadu_ps(d_output + i);
96 __mmask16
mask = _mm512_cmp_ps_mask(vx, vzero, _CMP_GT_OQ);
97 __m512 vdx = _mm512_maskz_mov_ps(
mask, vdy);
98 _mm512_storeu_ps(d_input + i, vdx);
100 #elif defined(__AVX2__) || defined(__AVX__)
101 __m256 vzero = _mm256_setzero_ps();
102 for (; i + 7 < n; i += 8) {
103 __m256 vx = _mm256_loadu_ps(input + i);
104 __m256 vdy = _mm256_loadu_ps(d_output + i);
106 __m256
mask = _mm256_cmp_ps(vx, vzero, _CMP_GT_OQ);
107 __m256 vdx = _mm256_and_ps(
mask, vdy);
108 _mm256_storeu_ps(d_input + i, vdx);
114 d_input[i] = (input[i] > 0.0f) ? d_output[i] : 0.0f;
int32_t int32_t int32_t int32_t int32_t mask
References mask.
◆ relu_forward()
void relu_forward(const float *input, float *output, size_t n)
Definition at line 26 of file relu_kernels.c.
30 #if defined(__AVX512F__)
31 __m512 vzero = _mm512_setzero_ps();
32 for (; i + 15 < n; i += 16) {
33 __m512 vx = _mm512_loadu_ps(input + i);
34 __m512 vy = _mm512_max_ps(vx, vzero);
35 _mm512_storeu_ps(output + i, vy);
37 #elif defined(__AVX2__) || defined(__AVX__)
38 __m256 vzero = _mm256_setzero_ps();
39 for (; i + 7 < n; i += 8) {
40 __m256 vx = _mm256_loadu_ps(input + i);
41 __m256 vy = _mm256_max_ps(vx, vzero);
42 _mm256_storeu_ps(output + i, vy);
49 output[i] = (x > 0.0f) ? x : 0.0f;
◆ relu_forward_inplace()
void relu_forward_inplace(float *data, size_t n)
Definition at line 54 of file relu_kernels.c.
58 #if defined(__AVX512F__)
59 __m512 vzero = _mm512_setzero_ps();
60 for (; i + 15 < n; i += 16) {
61 __m512 vx = _mm512_loadu_ps(data + i);
62 __m512 vy = _mm512_max_ps(vx, vzero);
63 _mm512_storeu_ps(data + i, vy);
65 #elif defined(__AVX2__) || defined(__AVX__)
66 __m256 vzero = _mm256_setzero_ps();
67 for (; i + 7 < n; i += 8) {
68 __m256 vx = _mm256_loadu_ps(data + i);
69 __m256 vy = _mm256_max_ps(vx, vzero);
70 _mm256_storeu_ps(data + i, vy);