← Back to C-Kernel-Engine Docs Doxygen Source Documentation
relu_kernels.c
Go to the documentation of this file.
1 /**
2  * @file relu_kernels.c
3  * @brief ReLU activation kernels with SIMD (AVX/AVX2/AVX-512) and a portable scalar fallback
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * ReLU: y = max(0, x)
15  */
16 
17 #include <math.h>
18 #include <stddef.h>
19 #include <stdint.h>
20 
21 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
22 #include <immintrin.h>
23 #endif
24 
/**
 * @brief ReLU forward pass: output[k] = max(0, input[k]) for every k < n.
 *
 * @param input  Source activations (n floats, read only).
 * @param output Destination (n floats). output == input is fine: every
 *               chunk/element is loaded before the same position is stored.
 * @param n      Element count; n == 0 touches no memory.
 *
 * Values that are not strictly positive (negatives, -0.0f, NaN) come out as
 * +0.0f, both in the vector body and in the scalar tail.
 */
void relu_forward(const float *input, float *output, size_t n)
{
    size_t k = 0;

#if defined(__AVX512F__)
    const __m512 zero16 = _mm512_setzero_ps();
    /* Full 16-lane chunks; k never exceeds n, so n - k cannot wrap. */
    while (n - k >= 16) {
        const __m512 x16 = _mm512_loadu_ps(input + k);
        _mm512_storeu_ps(output + k, _mm512_max_ps(x16, zero16));
        k += 16;
    }
#elif defined(__AVX2__) || defined(__AVX__)
    const __m256 zero8 = _mm256_setzero_ps();
    /* Full 8-lane chunks; k never exceeds n, so n - k cannot wrap. */
    while (n - k >= 8) {
        const __m256 x8 = _mm256_loadu_ps(input + k);
        _mm256_storeu_ps(output + k, _mm256_max_ps(x8, zero8));
        k += 8;
    }
#endif

    /* Remaining elements (or all of them when no SIMD path is compiled in). */
    for (; k < n; ++k) {
        const float v = input[k];
        if (v > 0.0f) {
            output[k] = v;
        } else {
            output[k] = 0.0f;
        }
    }
}
52 
/**
 * @brief In-place ReLU: data[i] = max(0, data[i]) for every i < n.
 *
 * @param data Buffer of n floats, overwritten with the result.
 * @param n    Element count; n == 0 touches no memory.
 *
 * Every element that is not strictly greater than zero (negatives, -0.0f,
 * NaN) becomes +0.0f. The vector max(x, 0) below already behaves this way,
 * because MAXPS returns its second operand (+0) for NaN or equal zeros. The
 * scalar tail therefore uses the same "x > 0 ? x : 0" rule as relu_forward().
 * That way the result no longer depends on whether an element lands in the
 * SIMD body or the tail.
 */
void relu_forward_inplace(float *data, size_t n)
{
    size_t i = 0;

#if defined(__AVX512F__)
    __m512 vzero = _mm512_setzero_ps();
    for (; i + 15 < n; i += 16) {
        __m512 vx = _mm512_loadu_ps(data + i);
        __m512 vy = _mm512_max_ps(vx, vzero);
        _mm512_storeu_ps(data + i, vy);
    }
#elif defined(__AVX2__) || defined(__AVX__)
    __m256 vzero = _mm256_setzero_ps();
    for (; i + 7 < n; i += 8) {
        __m256 vx = _mm256_loadu_ps(data + i);
        __m256 vy = _mm256_max_ps(vx, vzero);
        _mm256_storeu_ps(data + i, vy);
    }
#endif

    /* Scalar tail / fallback. The old "if (x < 0) data[i] = 0" form left NaN
     * and -0.0f untouched, which disagreed with the vector path above. */
    for (; i < n; ++i) {
        const float x = data[i];
        data[i] = (x > 0.0f) ? x : 0.0f;
    }
}
82 
/**
 * @brief ReLU backward pass: the gradient is passed through where the forward
 *        input was strictly positive and zeroed elsewhere.
 *
 *   d_input[k] = (input[k] > 0) ? d_output[k] : 0
 *
 * @param input    Forward-pass input x (n floats, read only).
 * @param d_output Upstream gradient dy (n floats, read only).
 * @param d_input  Output gradient dx (n floats, written).
 * @param n        Element count; n == 0 touches no memory.
 *
 * The comparison is "greater than, ordered", so x == ±0 and x == NaN both
 * produce a +0.0f gradient.
 */
void relu_backward(const float *input,
                   const float *d_output,
                   float *d_input,
                   size_t n)
{
    size_t k = 0;

#if defined(__AVX512F__)
    const __m512 zero16 = _mm512_setzero_ps();
    while (n - k >= 16) {
        const __m512 x16 = _mm512_loadu_ps(input + k);
        const __m512 g16 = _mm512_loadu_ps(d_output + k);
        /* One bit per lane, set where x > 0; unset lanes are zero-filled. */
        const __mmask16 pos = _mm512_cmp_ps_mask(x16, zero16, _CMP_GT_OQ);
        _mm512_storeu_ps(d_input + k, _mm512_maskz_mov_ps(pos, g16));
        k += 16;
    }
#elif defined(__AVX2__) || defined(__AVX__)
    const __m256 zero8 = _mm256_setzero_ps();
    while (n - k >= 8) {
        const __m256 x8 = _mm256_loadu_ps(input + k);
        const __m256 g8 = _mm256_loadu_ps(d_output + k);
        /* Lanes become all-ones bits where x > 0 and all-zero bits otherwise,
         * so the AND keeps the gradient only in positive lanes. */
        const __m256 pos = _mm256_cmp_ps(x8, zero8, _CMP_GT_OQ);
        _mm256_storeu_ps(d_input + k, _mm256_and_ps(pos, g8));
        k += 8;
    }
#endif

    /* Remaining elements (or all of them when no SIMD path is compiled in). */
    for (; k < n; ++k) {
        if (input[k] > 0.0f) {
            d_input[k] = d_output[k];
        } else {
            d_input[k] = 0.0f;
        }
    }
}
116 }
void relu_backward(const float *input, const float *d_output, float *d_input, size_t n)
Definition: relu_kernels.c:84
void relu_forward(const float *input, float *output, size_t n)
Definition: relu_kernels.c:26
void relu_forward_inplace(float *data, size_t n)
Definition: relu_kernels.c:54
int32_t int32_t int32_t int32_t int32_t mask
Definition: tokenizer.h:233