← Back to C-Kernel-Engine Docs Doxygen Source Documentation
optimizer_kernels_bf16.c
Go to the documentation of this file.
1 /**
2  * @file optimizer_kernels_bf16.c
3  * @brief BF16 optimizer kernels for training
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Note: Optimizer state (m, v) is always kept in fp32 for numerical stability.
15  * Only weights and gradients are in bf16.
16  */
17 
18 #include <math.h>
19 #include <stddef.h>
20 #include <stdint.h>
21 #include <stdlib.h>
22 #include <string.h>
23 
24 #include "bf16_utils.h"
25 
26 /* Forward declarations of fp32 kernels */
27 extern void adamw_update_f32(
28  const float *grad, float *weight, float *m, float *v, size_t numel,
29  float lr, float beta1, float beta2, float eps, float weight_decay, int step);
30 
31 extern void sgd_momentum_update_f32(
32  const float *grad, float *weight, float *velocity, size_t numel,
33  float lr, float momentum, float weight_decay);
34 
35 extern void gradient_accumulate_f32(float *dst, const float *src, size_t numel);
36 extern void gradient_scale_f32(float *grad, size_t numel, float scale);
37 
38 
39 /**
40  * @brief AdamW optimizer update (bf16 weights/gradients, fp32 optimizer state)
41  *
42  * Weights and gradients are in bf16 for memory efficiency.
43  * Momentum (m) and variance (v) are in fp32 for numerical stability.
44  *
45  * @param grad Gradient tensor (bf16) [numel]
46  * @param weight Weight tensor to update (bf16, in-place) [numel]
47  * @param m First moment buffer (fp32, in-place) [numel]
48  * @param v Second moment buffer (fp32, in-place) [numel]
49  * @param numel Number of elements
50  * @param lr Learning rate
51  * @param beta1 First moment decay (typically 0.9)
52  * @param beta2 Second moment decay (typically 0.999)
53  * @param eps Numerical stability constant (typically 1e-8)
54  * @param weight_decay Weight decay coefficient
55  * @param step Current step number (1-indexed)
56  */
void adamw_update_bf16(
    const uint16_t *grad,
    uint16_t *weight,
    float *m,
    float *v,
    size_t numel,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float weight_decay,
    int step)
{
    if (!grad || !weight || !m || !v || numel == 0) {
        return;
    }

    // Bias correction terms; step is 1-indexed so both lie in (0, 1].
    float bias_correction1 = 1.0f - powf(beta1, (float)step);
    float bias_correction2 = 1.0f - powf(beta2, (float)step);
    float one_minus_beta1 = 1.0f - beta1;
    float one_minus_beta2 = 1.0f - beta2;

#if defined(__AVX512F__)
    // Vectorized path: process 16 elements at a time
    __m512 v_beta1 = _mm512_set1_ps(beta1);
    __m512 v_beta2 = _mm512_set1_ps(beta2);
    __m512 v_one_minus_beta1 = _mm512_set1_ps(one_minus_beta1);
    __m512 v_one_minus_beta2 = _mm512_set1_ps(one_minus_beta2);
    __m512 v_lr = _mm512_set1_ps(lr);
    __m512 v_eps = _mm512_set1_ps(eps);
    __m512 v_weight_decay = _mm512_set1_ps(weight_decay);
    __m512 v_bc1_inv = _mm512_set1_ps(1.0f / bias_correction1);
    __m512 v_bc2_inv = _mm512_set1_ps(1.0f / bias_correction2);

    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        // Load bf16 gradient and weight, convert to fp32
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        __m512 w = bf16_loadu_cvt_fp32(&weight[i]);

        // Load fp32 optimizer state
        __m512 m_val = _mm512_loadu_ps(&m[i]);
        __m512 v_val = _mm512_loadu_ps(&v[i]);

        // Update m: m = beta1 * m + (1 - beta1) * g
        m_val = _mm512_fmadd_ps(v_beta1, m_val, _mm512_mul_ps(v_one_minus_beta1, g));

        // Update v: v = beta2 * v + (1 - beta2) * g^2
        __m512 g_sq = _mm512_mul_ps(g, g);
        v_val = _mm512_fmadd_ps(v_beta2, v_val, _mm512_mul_ps(v_one_minus_beta2, g_sq));

        // Bias-corrected estimates
        __m512 m_hat = _mm512_mul_ps(m_val, v_bc1_inv);
        __m512 v_hat = _mm512_mul_ps(v_val, v_bc2_inv);

        // Decoupled weight decay (AdamW):
        // w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
        __m512 denom = _mm512_add_ps(_mm512_sqrt_ps(v_hat), v_eps);
        __m512 update = _mm512_div_ps(m_hat, denom);
        update = _mm512_fmadd_ps(v_weight_decay, w, update);
        w = _mm512_fnmadd_ps(v_lr, update, w);

        // Store updated weight as bf16
        fp32_cvt_storeu_bf16(&weight[i], w);

        // Store updated optimizer state (stays fp32)
        _mm512_storeu_ps(&m[i], m_val);
        _mm512_storeu_ps(&v[i], v_val);
    }

    // Scalar tail
    for (; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);

        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;

        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;

        w = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#else
    // Scalar path
    for (size_t i = 0; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);

        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;

        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;

        w = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#endif
}
158 
159 
160 /**
161  * @brief SGD with momentum (bf16 weights/gradients)
162  */
void sgd_momentum_update_bf16(
    const uint16_t *grad,
    uint16_t *weight,
    float *velocity,
    size_t numel,
    float lr,
    float momentum,
    float weight_decay)
{
    if (!grad || !weight || !velocity || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    __m512 v_lr = _mm512_set1_ps(lr);
    __m512 v_momentum = _mm512_set1_ps(momentum);
    __m512 v_weight_decay = _mm512_set1_ps(weight_decay);

    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        __m512 w = bf16_loadu_cvt_fp32(&weight[i]);
        __m512 vel = _mm512_loadu_ps(&velocity[i]);

        // vel = momentum * vel + g
        vel = _mm512_fmadd_ps(v_momentum, vel, g);
        // w = w - lr * (vel + weight_decay * w)
        __m512 update = _mm512_fmadd_ps(v_weight_decay, w, vel);
        w = _mm512_fnmadd_ps(v_lr, update, w);

        fp32_cvt_storeu_bf16(&weight[i], w);
        _mm512_storeu_ps(&velocity[i], vel);
    }

    // Scalar tail
    for (; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);
        velocity[i] = momentum * velocity[i] + g;
        w = w - lr * (velocity[i] + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#else
    // Scalar path
    for (size_t i = 0; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);
        velocity[i] = momentum * velocity[i] + g;
        w = w - lr * (velocity[i] + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#endif
}
212 
213 
/**
 * @brief Zero out gradient buffer (bf16)
 *
 * Sets all numel elements of grad to 0x0000, which is the bf16
 * encoding of +0.0f. No-op on NULL or empty input.
 */
void zero_gradients_bf16(uint16_t *grad, size_t numel)
{
    if (grad != NULL && numel > 0) {
        memset(grad, 0, numel * sizeof *grad);
    }
}
224 
225 
/**
 * @brief Accumulate gradients: dst += src (bf16)
 *
 * Each element pair is widened to fp32, summed, and rounded back to
 * bf16 in place in dst. No-op on NULL pointers or empty input.
 */
void gradient_accumulate_bf16(uint16_t *dst, const uint16_t *src, size_t numel)
{
    if (dst == NULL || src == NULL || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    // Vector body: 16 lanes per iteration.
    size_t idx = 0;
    for (; idx + 16 <= numel; idx += 16) {
        __m512 sum = _mm512_add_ps(bf16_loadu_cvt_fp32(&dst[idx]),
                                   bf16_loadu_cvt_fp32(&src[idx]));
        fp32_cvt_storeu_bf16(&dst[idx], sum);
    }
    // Element-wise remainder.
    for (; idx < numel; ++idx) {
        dst[idx] = float_to_bf16(bf16_to_float(dst[idx]) + bf16_to_float(src[idx]));
    }
#else
    for (size_t idx = 0; idx < numel; ++idx) {
        dst[idx] = float_to_bf16(bf16_to_float(dst[idx]) + bf16_to_float(src[idx]));
    }
#endif
}
255 
256 
/**
 * @brief Scale gradients: grad *= scale (bf16)
 *
 * Each element is widened to fp32, multiplied by scale, and rounded
 * back to bf16 in place. No-op on NULL or empty input.
 */
void gradient_scale_bf16(uint16_t *grad, size_t numel, float scale)
{
    if (grad == NULL || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    const __m512 factor = _mm512_set1_ps(scale);
    size_t idx = 0;
    // Vector body: 16 lanes per iteration.
    for (; idx + 16 <= numel; idx += 16) {
        __m512 scaled = _mm512_mul_ps(bf16_loadu_cvt_fp32(&grad[idx]), factor);
        fp32_cvt_storeu_bf16(&grad[idx], scaled);
    }
    // Element-wise remainder.
    for (; idx < numel; ++idx) {
        grad[idx] = float_to_bf16(bf16_to_float(grad[idx]) * scale);
    }
#else
    for (size_t idx = 0; idx < numel; ++idx) {
        grad[idx] = float_to_bf16(bf16_to_float(grad[idx]) * scale);
    }
#endif
}
284 
285 
/**
 * @brief Clip gradient norm (bf16)
 *
 * Computes the global L2 norm of the bf16 gradient buffer and, if it
 * exceeds max_norm, rescales every element in place by max_norm / norm
 * (via gradient_scale_bf16) so the clipped norm equals max_norm.
 *
 * @param grad bf16 gradient buffer [numel]; modified in place only when clipping occurs
 * @param numel Number of elements
 * @param max_norm Maximum allowed L2 norm; must be > 0
 * @return The original L2 norm before clipping (0.0f if arguments are invalid)
 */
float gradient_clip_norm_bf16(uint16_t *grad, size_t numel, float max_norm)
{
    if (!grad || numel == 0 || max_norm <= 0.0f) {
        return 0.0f;
    }

    // Accumulate the sum of squares; the scalar paths below use double.
    double sum_sq = 0.0;
#if defined(__AVX512F__)
    // NOTE(review): this vector path accumulates squares in a single fp32
    // register while the scalar paths accumulate in double, so results can
    // diverge in precision for very large numel — confirm this is intended.
    __m512 acc = _mm512_setzero_ps();
    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        acc = _mm512_fmadd_ps(g, g, acc);  // acc += g * g, 16 lanes at once
    }
    sum_sq = _mm512_reduce_add_ps(acc);    // horizontal sum across all 16 lanes
    for (; i < numel; ++i) {               // scalar tail for the remainder
        float g = bf16_to_float(grad[i]);
        sum_sq += (double)g * (double)g;
    }
#else
    for (size_t i = 0; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        sum_sq += (double)g * (double)g;
    }
#endif

    float norm = sqrtf((float)sum_sq);

    if (norm > max_norm) {
        // Uniform rescale brings the L2 norm down to exactly max_norm
        // (up to bf16 rounding in gradient_scale_bf16).
        float scale = max_norm / norm;
        gradient_scale_bf16(grad, numel, scale);
    }

    return norm;
}
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38
void adamw_update_bf16(const uint16_t *grad, uint16_t *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
AdamW optimizer update (bf16 weights/gradients, fp32 optimizer state)
void zero_gradients_bf16(uint16_t *grad, size_t numel)
Zero out gradient buffer (bf16)
void sgd_momentum_update_f32(const float *grad, float *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay)
SGD with momentum optimizer update (fp32 version)
float gradient_clip_norm_bf16(uint16_t *grad, size_t numel, float max_norm)
Clip gradient norm (bf16)
void gradient_accumulate_bf16(uint16_t *dst, const uint16_t *src, size_t numel)
Accumulate gradients: dst += src (bf16)
void gradient_scale_bf16(uint16_t *grad, size_t numel, float scale)
Scale gradients: grad *= scale (bf16)
void sgd_momentum_update_bf16(const uint16_t *grad, uint16_t *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay)
SGD with momentum (bf16 weights/gradients)
void gradient_scale_f32(float *grad, size_t numel, float scale)
Scale gradients by a constant: grad *= scale (fp32)
void adamw_update_f32(const float *grad, float *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
AdamW optimizer update (fp32 version)
void gradient_accumulate_f32(float *dst, const float *src, size_t numel)
Accumulate gradients: dst += src (fp32)