← Back to C-Kernel-Engine Docs Doxygen Source Documentation
add_kernels_bf16.c
Go to the documentation of this file.
1 /**
2  * @file add_kernels_bf16.c
3  * @brief Element-wise addition kernels for BF16 tensors
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Used for residual connections in transformer models:
15  * residual = x + sublayer_output
16  *
17  * Supports:
18  * - Forward: y = a + b
19  * - Forward with scale: y = a + alpha * b
20  * - Backward: d_a = d_y, d_b = d_y (gradient flows through unchanged)
21  * - In-place: a += b
22  */
23 
24 #include "bf16_utils.h"
25 #include "ckernel_engine.h"
26 
27 #include <stdint.h>
28 #include <stddef.h>
29 
30 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
31 #include <immintrin.h>
32 #endif
33 
34 /* =============================================================================
35  * Forward: y = a + b
36  * ============================================================================= */
37 
/**
 * @brief Element-wise addition y[i] = a[i] + b[i] over bf16 vectors.
 *
 * Inputs are widened to fp32 for the arithmetic and rounded back to bf16
 * on store. Silently returns on NULL pointers or n == 0.
 *
 * @param a First operand (bf16), length n.
 * @param b Second operand (bf16), length n.
 * @param y Destination (bf16), length n.
 * @param n Element count.
 */
void add_forward_bf16(const uint16_t *a,
                      const uint16_t *b,
                      uint16_t *y,
                      size_t n)
{
    if (a == NULL || b == NULL || y == NULL || n == 0) {
        return;
    }

    size_t idx = 0;

#if defined(__AVX512F__)
    /* Vector path: 16 bf16 lanes per step, added in fp32. */
    while (idx + 16 <= n) {
        __m512 lhs = bf16_loadu_cvt_fp32(&a[idx]);
        __m512 rhs = bf16_loadu_cvt_fp32(&b[idx]);
        fp32_cvt_storeu_bf16(&y[idx], _mm512_add_ps(lhs, rhs));
        idx += 16;
    }
#endif

    /* Scalar tail (and full path without AVX-512). */
    while (idx < n) {
        y[idx] = float_to_bf16(bf16_to_float(a[idx]) + bf16_to_float(b[idx]));
        ++idx;
    }
}
66 
67 /* =============================================================================
68  * Forward with scale: y = a + alpha * b
69  * Useful for gradient accumulation or weighted residuals
70  * ============================================================================= */
71 
/**
 * @brief Scaled addition y[i] = a[i] + alpha * b[i] over bf16 vectors.
 *
 * Useful for gradient accumulation or weighted residual connections.
 * Arithmetic is performed in fp32; results round back to bf16 on store.
 * Silently returns on NULL pointers or n == 0.
 *
 * @param a     Base operand (bf16), length n.
 * @param b     Scaled operand (bf16), length n.
 * @param y     Destination (bf16), length n.
 * @param alpha Scale factor applied to b.
 * @param n     Element count.
 */
void add_scaled_forward_bf16(const uint16_t *a,
                             const uint16_t *b,
                             uint16_t *y,
                             float alpha,
                             size_t n)
{
    if (a == NULL || b == NULL || y == NULL || n == 0) {
        return;
    }

    size_t pos = 0;

#if defined(__AVX512F__)
    {
        const __m512 scale = _mm512_set1_ps(alpha);
        while (pos + 16 <= n) {
            __m512 base   = bf16_loadu_cvt_fp32(&a[pos]);
            __m512 addend = bf16_loadu_cvt_fp32(&b[pos]);
            /* fused multiply-add: base + scale * addend */
            fp32_cvt_storeu_bf16(&y[pos], _mm512_fmadd_ps(addend, scale, base));
            pos += 16;
        }
    }
#endif

    while (pos < n) {
        y[pos] = float_to_bf16(bf16_to_float(a[pos]) + alpha * bf16_to_float(b[pos]));
        ++pos;
    }
}
100 
101 /* =============================================================================
102  * In-place: a += b
103  * ============================================================================= */
104 
/**
 * @brief In-place addition a[i] += b[i] over bf16 vectors.
 *
 * Arithmetic is performed in fp32 and rounded back to bf16 on store.
 * Silently returns on NULL pointers or n == 0.
 *
 * @param a Accumulator, read and overwritten (bf16), length n.
 * @param b Addend (bf16), length n.
 * @param n Element count.
 */
void add_inplace_bf16(uint16_t *a,
                      const uint16_t *b,
                      size_t n)
{
    if (a == NULL || b == NULL || n == 0) {
        return;
    }

    size_t idx = 0;

#if defined(__AVX512F__)
    /* 16 bf16 lanes per step, accumulated in fp32. */
    while (idx + 16 <= n) {
        __m512 acc = bf16_loadu_cvt_fp32(&a[idx]);
        __m512 add = bf16_loadu_cvt_fp32(&b[idx]);
        fp32_cvt_storeu_bf16(&a[idx], _mm512_add_ps(acc, add));
        idx += 16;
    }
#endif

    while (idx < n) {
        a[idx] = float_to_bf16(bf16_to_float(a[idx]) + bf16_to_float(b[idx]));
        ++idx;
    }
}
130 
131 /* =============================================================================
132  * In-place scaled: a += alpha * b
133  * ============================================================================= */
134 
/**
 * @brief In-place scaled addition a[i] += alpha * b[i] over bf16 vectors.
 *
 * Arithmetic is performed in fp32 and rounded back to bf16 on store.
 * Silently returns on NULL pointers or n == 0.
 *
 * @param a     Accumulator, read and overwritten (bf16), length n.
 * @param b     Scaled addend (bf16), length n.
 * @param alpha Scale factor applied to b.
 * @param n     Element count.
 */
void add_scaled_inplace_bf16(uint16_t *a,
                             const uint16_t *b,
                             float alpha,
                             size_t n)
{
    if (a == NULL || b == NULL || n == 0) {
        return;
    }

    size_t pos = 0;

#if defined(__AVX512F__)
    {
        const __m512 scale = _mm512_set1_ps(alpha);
        while (pos + 16 <= n) {
            __m512 acc    = bf16_loadu_cvt_fp32(&a[pos]);
            __m512 addend = bf16_loadu_cvt_fp32(&b[pos]);
            /* fused multiply-add: acc + scale * addend */
            fp32_cvt_storeu_bf16(&a[pos], _mm512_fmadd_ps(addend, scale, acc));
            pos += 16;
        }
    }
#endif

    while (pos < n) {
        a[pos] = float_to_bf16(bf16_to_float(a[pos]) + alpha * bf16_to_float(b[pos]));
        ++pos;
    }
}
162 
163 /* =============================================================================
164  * Backward: d_a = d_y, d_b = d_y
165  *
166  * For y = a + b, gradients pass through unchanged:
167  * dy/da = 1, dy/db = 1
168  *
169  * This is a simple copy operation, but we provide it for API consistency.
170  * If d_a == d_y or d_b == d_y (in-place), no copy needed.
171  * ============================================================================= */
172 
/**
 * @brief Backward pass for y = a + b: d_a = d_y and d_b = d_y.
 *
 * Addition passes gradients through unchanged (dy/da = dy/db = 1), so this
 * is a raw bf16 copy, provided for API consistency. If d_a == d_y or
 * d_b == d_y (in-place reuse of the upstream gradient buffer) that copy is
 * skipped. Either output pointer may be NULL to request only the other
 * gradient.
 *
 * BUG FIX: the previous AVX-512 loops guarded `i + 32 <= n` but issued TWO
 * 512-bit load/store pairs per iteration (elements i..i+31 and i+32..i+63)
 * while advancing i by only 32. A __m512i already holds 32 bf16 values, so
 * this read and wrote up to 32 elements past the end of the buffers
 * whenever n was not a multiple of 64, and redundantly re-copied
 * overlapping spans otherwise. One load/store per 32-element step is
 * correct.
 *
 * @param d_y Upstream gradient (bf16), length n. Required.
 * @param d_a Gradient w.r.t. a (bf16), length n; NULL or == d_y to skip.
 * @param d_b Gradient w.r.t. b (bf16), length n; NULL or == d_y to skip.
 * @param n   Element count.
 */
void add_backward_bf16(const uint16_t *d_y,
                       uint16_t *d_a,
                       uint16_t *d_b,
                       size_t n)
{
    if (!d_y || n == 0) {
        return;
    }

    size_t i = 0;

    /* Copy to d_a unless it already aliases d_y */
    if (d_a && d_a != d_y) {
#if defined(__AVX512F__)
        /* One 512-bit register covers exactly 32 bf16 elements. */
        for (; i + 32 <= n; i += 32) {
            __m512i v = _mm512_loadu_si512((const void *)&d_y[i]);
            _mm512_storeu_si512((void *)&d_a[i], v);
        }
#endif
        for (; i < n; ++i) {
            d_a[i] = d_y[i];
        }
    }

    /* Copy to d_b unless it already aliases d_y */
    i = 0;
    if (d_b && d_b != d_y) {
#if defined(__AVX512F__)
        for (; i + 32 <= n; i += 32) {
            __m512i v = _mm512_loadu_si512((const void *)&d_y[i]);
            _mm512_storeu_si512((void *)&d_b[i], v);
        }
#endif
        for (; i < n; ++i) {
            d_b[i] = d_y[i];
        }
    }
}
215 
216 /* =============================================================================
217  * 2D tensor version: add_forward_2d_bf16
218  * For [T, D] shaped tensors (common in transformers)
219  * ============================================================================= */
220 
/**
 * @brief Row-wise element add for [tokens, dim] bf16 tensors: y = a + b.
 *
 * Rows are laid out with a stride of aligned_dim elements; only the first
 * dim elements of each row are read and written, so any padding lanes in y
 * are left untouched. Used for transformer residual connections.
 *
 * ROBUSTNESS FIX: aligned_dim is now validated. An aligned_dim smaller
 * than dim would make consecutive rows overlap (and under-allocate the
 * tensors), silently corrupting output; such calls now return without
 * writing anything, matching the other degenerate-argument guards.
 *
 * @param a           Input tensor, tokens rows with stride aligned_dim (bf16).
 * @param b           Input tensor, same layout.
 * @param y           Output tensor, same layout; padding not written.
 * @param tokens      Number of rows; must be > 0.
 * @param dim         Valid elements per row; must be > 0.
 * @param aligned_dim Row stride in elements; must be >= dim.
 */
void add_forward_2d_bf16(const uint16_t *a,
                         const uint16_t *b,
                         uint16_t *y,
                         int tokens,
                         int dim,
                         int aligned_dim)
{
    if (!a || !b || !y || tokens <= 0 || dim <= 0 || aligned_dim < dim) {
        return;
    }

    for (int t = 0; t < tokens; ++t) {
        const uint16_t *a_row = a + (size_t)t * (size_t)aligned_dim;
        const uint16_t *b_row = b + (size_t)t * (size_t)aligned_dim;
        uint16_t *y_row = y + (size_t)t * (size_t)aligned_dim;

        int d = 0;

#if defined(__AVX512F__)
        /* 16 bf16 lanes per step, added in fp32. */
        for (; d + 16 <= dim; d += 16) {
            __m512 av = bf16_loadu_cvt_fp32(&a_row[d]);
            __m512 bv = bf16_loadu_cvt_fp32(&b_row[d]);
            fp32_cvt_storeu_bf16(&y_row[d], _mm512_add_ps(av, bv));
        }
#endif

        /* Scalar tail */
        for (; d < dim; ++d) {
            float af = bf16_to_float(a_row[d]);
            float bf = bf16_to_float(b_row[d]);
            y_row[d] = float_to_bf16(af + bf);
        }
    }
}
255 
256 /* =============================================================================
257  * FP32 versions (for reference/testing)
258  * ============================================================================= */
259 
260 /**
261  * Element-wise add: y = a + b
262  * @test test_add.py::TestAddForward::test_add_forward_f32
263  * @test test_add.py::TestAddForward::test_add_inplace_f32
264  * @test test_multi_layer_parity.py::TestMultiLayerParity::test_residual_add
265  *
266  * Element-wise addition of two vectors.
267  *
268  * After changes: make test
269  */
/**
 * Element-wise add: y = a + b (fp32 reference path).
 * @test test_add.py::TestAddForward::test_add_forward_f32
 * @test test_add.py::TestAddForward::test_add_inplace_f32
 * @test test_multi_layer_parity.py::TestMultiLayerParity::test_residual_add
 *
 * Silently returns on NULL pointers or n == 0. Each SIMD tier drains into
 * the next narrower one, ending with a scalar tail.
 *
 * After changes: make test
 */
void add_forward_f32(const float *a,
                     const float *b,
                     float *y,
                     size_t n)
{
    if (a == NULL || b == NULL || y == NULL || n == 0) {
        return;
    }

    size_t k = 0;

#if defined(__AVX512F__)
    /* 16 fp32 lanes per step */
    while (k + 16 <= n) {
        _mm512_storeu_ps(&y[k],
                         _mm512_add_ps(_mm512_loadu_ps(&a[k]),
                                       _mm512_loadu_ps(&b[k])));
        k += 16;
    }
#endif

#if defined(__AVX2__)
    /* 8 fp32 lanes per step (also drains the AVX-512 tail) */
    while (k + 8 <= n) {
        _mm256_storeu_ps(&y[k],
                         _mm256_add_ps(_mm256_loadu_ps(&a[k]),
                                       _mm256_loadu_ps(&b[k])));
        k += 8;
    }
#endif

    while (k < n) {
        y[k] = a[k] + b[k];
        ++k;
    }
}
303 
/**
 * In-place element-wise add: a += b (fp32 reference path).
 *
 * Silently returns on NULL pointers or n == 0. Each SIMD tier drains into
 * the next narrower one, ending with a scalar tail.
 */
void add_inplace_f32(float *a,
                     const float *b,
                     size_t n)
{
    if (a == NULL || b == NULL || n == 0) {
        return;
    }

    size_t k = 0;

#if defined(__AVX512F__)
    /* 16 fp32 lanes per step */
    while (k + 16 <= n) {
        _mm512_storeu_ps(&a[k],
                         _mm512_add_ps(_mm512_loadu_ps(&a[k]),
                                       _mm512_loadu_ps(&b[k])));
        k += 16;
    }
#endif

#if defined(__AVX2__)
    /* 8 fp32 lanes per step (also drains the AVX-512 tail) */
    while (k + 8 <= n) {
        _mm256_storeu_ps(&a[k],
                         _mm256_add_ps(_mm256_loadu_ps(&a[k]),
                                       _mm256_loadu_ps(&b[k])));
        k += 8;
    }
#endif

    while (k < n) {
        a[k] += b[k];
        ++k;
    }
}
void add_scaled_forward_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, float alpha, size_t n)
void add_inplace_bf16(uint16_t *a, const uint16_t *b, size_t n)
void add_forward_f32(const float *a, const float *b, float *y, size_t n)
void add_inplace_f32(float *a, const float *b, size_t n)
void add_forward_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, size_t n)
void add_forward_2d_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, int tokens, int dim, int aligned_dim)
void add_backward_bf16(const uint16_t *d_y, uint16_t *d_a, uint16_t *d_b, size_t n)
void add_scaled_inplace_bf16(uint16_t *a, const uint16_t *b, float alpha, size_t n)
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38