optimizer_kernels_bf16.c File Reference

BF16 optimizer kernels for training.

#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include "bf16_utils.h"


Functions

void adamw_update_bf16(const uint16_t *grad, uint16_t *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
    AdamW optimizer update (bf16 weights/gradients, fp32 optimizer state)

void adamw_update_f32(const float *grad, float *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
    AdamW optimizer update (fp32 version)

void gradient_accumulate_bf16(uint16_t *dst, const uint16_t *src, size_t numel)
    Accumulate gradients: dst += src (bf16)

void gradient_accumulate_f32(float *dst, const float *src, size_t numel)
    Accumulate gradients: dst += src (fp32)

float gradient_clip_norm_bf16(uint16_t *grad, size_t numel, float max_norm)
    Clip gradient norm (bf16)

void gradient_scale_bf16(uint16_t *grad, size_t numel, float scale)
    Scale gradients: grad *= scale (bf16)

void gradient_scale_f32(float *grad, size_t numel, float scale)
    Scale gradients by a constant: grad *= scale (fp32)

void sgd_momentum_update_bf16(const uint16_t *grad, uint16_t *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay)
    SGD with momentum (bf16 weights/gradients)

void sgd_momentum_update_f32(const float *grad, float *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay)
    SGD with momentum optimizer update (fp32 version)

void zero_gradients_bf16(uint16_t *grad, size_t numel)
    Zero out gradient buffer (bf16)


Detailed Description

BF16 optimizer kernels for training.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Note: Optimizer state (m, v) is always kept in fp32 for numerical stability. Only weights and gradients are in bf16.

Definition in file optimizer_kernels_bf16.c.
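
As a minimal sketch of how these kernels compose into a single training step (buffer names and hyperparameter values here are hypothetical; per the kernel rules, every buffer is allocated by the caller, e.g. from the bump allocator):

    void train_step_sketch(uint16_t *weight, uint16_t *grad, const uint16_t *micro_grad,
                           float *m, float *v, size_t numel, int step /* 1-indexed */)
    {
        zero_gradients_bf16(grad, numel);

        // A backward pass would produce micro_grad; accumulate across micro-batches
        gradient_accumulate_bf16(grad, micro_grad, numel);

        // Clip this tensor's gradient L2 norm to 1.0 (in-place)
        gradient_clip_norm_bf16(grad, numel, 1.0f);

        adamw_update_bf16(grad, weight, m, v, numel,
                          1e-3f,   /* lr */
                          0.9f,    /* beta1 */
                          0.999f,  /* beta2 */
                          1e-8f,   /* eps */
                          0.01f,   /* weight_decay */
                          step);
    }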

Function Documentation

◆ adamw_update_bf16()

void adamw_update_bf16(const uint16_t *grad,
                       uint16_t *weight,
                       float *m,
                       float *v,
                       size_t numel,
                       float lr,
                       float beta1,
                       float beta2,
                       float eps,
                       float weight_decay,
                       int step)

AdamW optimizer update (bf16 weights/gradients, fp32 optimizer state)

Weights and gradients are in bf16 for memory efficiency. Momentum (m) and variance (v) are in fp32 for numerical stability.
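
For reference, the per-element recurrence implemented below is the standard AdamW update (decoupled weight decay):

    m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    m_hat = m_t / (1 - beta1^t)
    v_hat = v_t / (1 - beta2^t)
    w_t = w_{t-1} - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w_{t-1})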

Parameters
    grad          Gradient tensor (bf16) [numel]
    weight        Weight tensor to update (bf16, in-place) [numel]
    m             First moment buffer (fp32, in-place) [numel]
    v             Second moment buffer (fp32, in-place) [numel]
    numel         Number of elements
    lr            Learning rate
    beta1         First moment decay (typically 0.9)
    beta2         Second moment decay (typically 0.999)
    eps           Numerical stability constant (typically 1e-8)
    weight_decay  Weight decay coefficient
    step          Current step number (1-indexed)

Definition at line 57 of file optimizer_kernels_bf16.c.

{
    if (!grad || !weight || !m || !v || numel == 0) {
        return;
    }

    // Bias correction terms
    float bias_correction1 = 1.0f - powf(beta1, (float)step);
    float bias_correction2 = 1.0f - powf(beta2, (float)step);
    float one_minus_beta1 = 1.0f - beta1;
    float one_minus_beta2 = 1.0f - beta2;

#if defined(__AVX512F__)
    // Vectorized path: process 16 elements at a time
    __m512 v_beta1 = _mm512_set1_ps(beta1);
    __m512 v_beta2 = _mm512_set1_ps(beta2);
    __m512 v_one_minus_beta1 = _mm512_set1_ps(one_minus_beta1);
    __m512 v_one_minus_beta2 = _mm512_set1_ps(one_minus_beta2);
    __m512 v_lr = _mm512_set1_ps(lr);
    __m512 v_eps = _mm512_set1_ps(eps);
    __m512 v_weight_decay = _mm512_set1_ps(weight_decay);
    __m512 v_bc1_inv = _mm512_set1_ps(1.0f / bias_correction1);
    __m512 v_bc2_inv = _mm512_set1_ps(1.0f / bias_correction2);

    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        // Load bf16 gradient and weight, convert to fp32
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        __m512 w = bf16_loadu_cvt_fp32(&weight[i]);

        // Load fp32 optimizer state
        __m512 m_val = _mm512_loadu_ps(&m[i]);
        __m512 v_val = _mm512_loadu_ps(&v[i]);

        // Update m: m = beta1 * m + (1 - beta1) * g
        m_val = _mm512_fmadd_ps(v_beta1, m_val, _mm512_mul_ps(v_one_minus_beta1, g));

        // Update v: v = beta2 * v + (1 - beta2) * g^2
        __m512 g_sq = _mm512_mul_ps(g, g);
        v_val = _mm512_fmadd_ps(v_beta2, v_val, _mm512_mul_ps(v_one_minus_beta2, g_sq));

        // Bias-corrected estimates
        __m512 m_hat = _mm512_mul_ps(m_val, v_bc1_inv);
        __m512 v_hat = _mm512_mul_ps(v_val, v_bc2_inv);

        // Update weight: w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
        __m512 denom = _mm512_add_ps(_mm512_sqrt_ps(v_hat), v_eps);
        __m512 update = _mm512_div_ps(m_hat, denom);
        update = _mm512_fmadd_ps(v_weight_decay, w, update);
        w = _mm512_fnmadd_ps(v_lr, update, w);

        // Store updated weight as bf16
        fp32_cvt_storeu_bf16(&weight[i], w);

        // Store updated optimizer state (stays fp32)
        _mm512_storeu_ps(&m[i], m_val);
        _mm512_storeu_ps(&v[i], v_val);
    }

    // Scalar tail
    for (; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);

        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;

        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;

        w = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#else
    // Scalar path
    for (size_t i = 0; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);

        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;

        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;

        w = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#endif
}

References bf16_to_float(), and float_to_bf16().
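
The bf16 <-> fp32 helpers referenced above live in bf16_utils.h and are not reproduced on this page. As a rough sketch of what such helpers commonly look like (truncating widen, round-to-nearest-even narrow; the actual implementation in bf16_utils.h may differ), assuming bf16 is the upper half of an IEEE-754 fp32:

    static inline float bf16_to_float_sketch(uint16_t v) {
        // Widen: bf16 occupies the top 16 bits of an fp32
        uint32_t bits = (uint32_t)v << 16;
        float f;
        memcpy(&f, &bits, sizeof f);
        return f;
    }

    static inline uint16_t float_to_bf16_sketch(float f) {
        uint32_t bits;
        memcpy(&bits, &f, sizeof bits);
        // Narrow with round-to-nearest-even on the discarded low 16 bits
        // (NaN inputs are not special-cased in this sketch)
        bits += 0x7FFFu + ((bits >> 16) & 1u);
        return (uint16_t)(bits >> 16);
    }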

◆ adamw_update_f32()

void adamw_update_f32(const float *grad,
                      float *weight,
                      float *m,
                      float *v,
                      size_t numel,
                      float lr,
                      float beta1,
                      float beta2,
                      float eps,
                      float weight_decay,
                      int step)

AdamW optimizer update (fp32 version)

Updates weights in-place using the AdamW algorithm. Momentum (m) and variance (v) are stored in fp32 for numerical stability.

Parameters
    grad          Gradient tensor (fp32) [numel]
    weight        Weight tensor to update (fp32, in-place) [numel]
    m             First moment (momentum) buffer (fp32, in-place) [numel]
    v             Second moment (variance) buffer (fp32, in-place) [numel]
    numel         Number of elements
    lr            Learning rate
    beta1         Exponential decay rate for first moment (typically 0.9)
    beta2         Exponential decay rate for second moment (typically 0.999)
    eps           Small constant for numerical stability (typically 1e-8)
    weight_decay  Weight decay coefficient (typically 0.01)
    step          Current step number (1-indexed for bias correction)

Definition at line 53 of file optimizer_kernels.c.

{
    if (!grad || !weight || !m || !v || numel == 0) {
        return;
    }

    // Bias correction terms
    float bias_correction1 = 1.0f - powf(beta1, (float)step);
    float bias_correction2 = 1.0f - powf(beta2, (float)step);

    // Precompute constants
    float one_minus_beta1 = 1.0f - beta1;
    float one_minus_beta2 = 1.0f - beta2;

#if defined(__AVX512F__)
    // AVX-512 path: process 16 floats at a time
    __m512 v_beta1 = _mm512_set1_ps(beta1);
    __m512 v_beta2 = _mm512_set1_ps(beta2);
    __m512 v_one_minus_beta1 = _mm512_set1_ps(one_minus_beta1);
    __m512 v_one_minus_beta2 = _mm512_set1_ps(one_minus_beta2);
    __m512 v_lr = _mm512_set1_ps(lr);
    __m512 v_eps = _mm512_set1_ps(eps);
    __m512 v_weight_decay = _mm512_set1_ps(weight_decay);
    __m512 v_bc1_inv = _mm512_set1_ps(1.0f / bias_correction1);
    __m512 v_bc2_inv = _mm512_set1_ps(1.0f / bias_correction2);

    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = _mm512_loadu_ps(&grad[i]);
        __m512 w = _mm512_loadu_ps(&weight[i]);
        __m512 m_val = _mm512_loadu_ps(&m[i]);
        __m512 v_val = _mm512_loadu_ps(&v[i]);

        // m = beta1 * m + (1 - beta1) * g
        m_val = _mm512_fmadd_ps(v_beta1, m_val, _mm512_mul_ps(v_one_minus_beta1, g));

        // v = beta2 * v + (1 - beta2) * g^2
        __m512 g_sq = _mm512_mul_ps(g, g);
        v_val = _mm512_fmadd_ps(v_beta2, v_val, _mm512_mul_ps(v_one_minus_beta2, g_sq));

        // Bias-corrected estimates
        __m512 m_hat = _mm512_mul_ps(m_val, v_bc1_inv);
        __m512 v_hat = _mm512_mul_ps(v_val, v_bc2_inv);

        // w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
        __m512 denom = _mm512_add_ps(_mm512_sqrt_ps(v_hat), v_eps);
        __m512 update = _mm512_div_ps(m_hat, denom);
        update = _mm512_fmadd_ps(v_weight_decay, w, update);
        w = _mm512_fnmadd_ps(v_lr, update, w);

        _mm512_storeu_ps(&weight[i], w);
        _mm512_storeu_ps(&m[i], m_val);
        _mm512_storeu_ps(&v[i], v_val);
    }

    // Scalar tail
    for (; i < numel; ++i) {
        float g = grad[i];
        float w = weight[i];
        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;
        weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
    }

#elif defined(__AVX__)
    // AVX path: process 8 floats at a time (no FMA on older CPUs like Ivy Bridge)
    __m256 v_beta1 = _mm256_set1_ps(beta1);
    __m256 v_beta2 = _mm256_set1_ps(beta2);
    __m256 v_one_minus_beta1 = _mm256_set1_ps(one_minus_beta1);
    __m256 v_one_minus_beta2 = _mm256_set1_ps(one_minus_beta2);
    __m256 v_lr = _mm256_set1_ps(lr);
    __m256 v_eps = _mm256_set1_ps(eps);
    __m256 v_weight_decay = _mm256_set1_ps(weight_decay);
    __m256 v_bc1_inv = _mm256_set1_ps(1.0f / bias_correction1);
    __m256 v_bc2_inv = _mm256_set1_ps(1.0f / bias_correction2);

    size_t i = 0;
    for (; i + 8 <= numel; i += 8) {
        __m256 g = _mm256_loadu_ps(&grad[i]);
        __m256 w = _mm256_loadu_ps(&weight[i]);
        __m256 m_val = _mm256_loadu_ps(&m[i]);
        __m256 v_val = _mm256_loadu_ps(&v[i]);

        // m = beta1 * m + (1 - beta1) * g
        m_val = _mm256_add_ps(_mm256_mul_ps(v_beta1, m_val),
                              _mm256_mul_ps(v_one_minus_beta1, g));

        // v = beta2 * v + (1 - beta2) * g^2
        __m256 g_sq = _mm256_mul_ps(g, g);
        v_val = _mm256_add_ps(_mm256_mul_ps(v_beta2, v_val),
                              _mm256_mul_ps(v_one_minus_beta2, g_sq));

        // Bias-corrected estimates
        __m256 m_hat = _mm256_mul_ps(m_val, v_bc1_inv);
        __m256 v_hat = _mm256_mul_ps(v_val, v_bc2_inv);

        // w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
        __m256 denom = _mm256_add_ps(_mm256_sqrt_ps(v_hat), v_eps);
        __m256 update = _mm256_div_ps(m_hat, denom);
        update = _mm256_add_ps(update, _mm256_mul_ps(v_weight_decay, w));
        w = _mm256_sub_ps(w, _mm256_mul_ps(v_lr, update));

        _mm256_storeu_ps(&weight[i], w);
        _mm256_storeu_ps(&m[i], m_val);
        _mm256_storeu_ps(&v[i], v_val);
    }

    // Scalar tail
    for (; i < numel; ++i) {
        float g = grad[i];
        float w = weight[i];
        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;
        weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
    }

#elif defined(__SSE2__)
    // SSE2 path: process 4 floats at a time
    __m128 v_beta1 = _mm_set1_ps(beta1);
    __m128 v_beta2 = _mm_set1_ps(beta2);
    __m128 v_one_minus_beta1 = _mm_set1_ps(one_minus_beta1);
    __m128 v_one_minus_beta2 = _mm_set1_ps(one_minus_beta2);
    __m128 v_lr = _mm_set1_ps(lr);
    __m128 v_eps = _mm_set1_ps(eps);
    __m128 v_weight_decay = _mm_set1_ps(weight_decay);
    __m128 v_bc1_inv = _mm_set1_ps(1.0f / bias_correction1);
    __m128 v_bc2_inv = _mm_set1_ps(1.0f / bias_correction2);

    size_t i = 0;
    for (; i + 4 <= numel; i += 4) {
        __m128 g = _mm_loadu_ps(&grad[i]);
        __m128 w = _mm_loadu_ps(&weight[i]);
        __m128 m_val = _mm_loadu_ps(&m[i]);
        __m128 v_val = _mm_loadu_ps(&v[i]);

        // m = beta1 * m + (1 - beta1) * g
        m_val = _mm_add_ps(_mm_mul_ps(v_beta1, m_val),
                           _mm_mul_ps(v_one_minus_beta1, g));

        // v = beta2 * v + (1 - beta2) * g^2
        __m128 g_sq = _mm_mul_ps(g, g);
        v_val = _mm_add_ps(_mm_mul_ps(v_beta2, v_val),
                           _mm_mul_ps(v_one_minus_beta2, g_sq));

        // Bias-corrected estimates
        __m128 m_hat = _mm_mul_ps(m_val, v_bc1_inv);
        __m128 v_hat = _mm_mul_ps(v_val, v_bc2_inv);

        // w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
        __m128 denom = _mm_add_ps(_mm_sqrt_ps(v_hat), v_eps);
        __m128 update = _mm_div_ps(m_hat, denom);
        update = _mm_add_ps(update, _mm_mul_ps(v_weight_decay, w));
        w = _mm_sub_ps(w, _mm_mul_ps(v_lr, update));

        _mm_storeu_ps(&weight[i], w);
        _mm_storeu_ps(&m[i], m_val);
        _mm_storeu_ps(&v[i], v_val);
    }

    // Scalar tail
    for (; i < numel; ++i) {
        float g = grad[i];
        float w = weight[i];
        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;
        weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
    }

#else
    // Scalar path
    for (size_t i = 0; i < numel; ++i) {
        float g = grad[i];
        float w = weight[i];
        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;
        weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
    }
#endif
}
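
The SIMD path is selected at compile time by predefined feature macros, not at runtime. Assuming a GCC/Clang-style toolchain (an assumption; the build flags are not part of this file):

    cc -O2 -mavx512f -c optimizer_kernels.c   # defines __AVX512F__: AVX-512 path
    cc -O2 -mavx     -c optimizer_kernels.c   # defines __AVX__: AVX path
    cc -O2           -c optimizer_kernels.c   # x86-64 default: __SSE2__ path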

◆ gradient_accumulate_bf16()

void gradient_accumulate_bf16(uint16_t *dst, const uint16_t *src, size_t numel)

Accumulate gradients: dst += src (bf16)

Definition at line 229 of file optimizer_kernels_bf16.c.

{
    if (!dst || !src || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 d = bf16_loadu_cvt_fp32(&dst[i]);
        __m512 s = bf16_loadu_cvt_fp32(&src[i]);
        fp32_cvt_storeu_bf16(&dst[i], _mm512_add_ps(d, s));
    }
    for (; i < numel; ++i) {
        float d = bf16_to_float(dst[i]);
        float s = bf16_to_float(src[i]);
        dst[i] = float_to_bf16(d + s);
    }
#else
    for (size_t i = 0; i < numel; ++i) {
        float d = bf16_to_float(dst[i]);
        float s = bf16_to_float(src[i]);
        dst[i] = float_to_bf16(d + s);
    }
#endif
}

References bf16_to_float(), and float_to_bf16().

◆ gradient_accumulate_f32()

void gradient_accumulate_f32(float *dst, const float *src, size_t numel)

Accumulate gradients: dst += src (fp32)

Used for gradient accumulation across micro-batches.

Parameters
    dst    Destination gradient buffer (in-place) [numel]
    src    Source gradient buffer [numel]
    numel  Number of elements

Definition at line 392 of file optimizer_kernels.c.

{
    if (!dst || !src || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 d = _mm512_loadu_ps(&dst[i]);
        __m512 s = _mm512_loadu_ps(&src[i]);
        _mm512_storeu_ps(&dst[i], _mm512_add_ps(d, s));
    }
    for (; i < numel; ++i) {
        dst[i] += src[i];
    }

#elif defined(__AVX__)
    size_t i = 0;
    for (; i + 8 <= numel; i += 8) {
        __m256 d = _mm256_loadu_ps(&dst[i]);
        __m256 s = _mm256_loadu_ps(&src[i]);
        _mm256_storeu_ps(&dst[i], _mm256_add_ps(d, s));
    }
    for (; i < numel; ++i) {
        dst[i] += src[i];
    }

#elif defined(__SSE2__)
    size_t i = 0;
    for (; i + 4 <= numel; i += 4) {
        __m128 d = _mm_loadu_ps(&dst[i]);
        __m128 s = _mm_loadu_ps(&src[i]);
        _mm_storeu_ps(&dst[i], _mm_add_ps(d, s));
    }
    for (; i < numel; ++i) {
        dst[i] += src[i];
    }

#else
    for (size_t i = 0; i < numel; ++i) {
        dst[i] += src[i];
    }
#endif
}
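
A typical micro-batch loop pairs this kernel with gradient_scale_f32 to average the accumulated gradient (a sketch; grad_accum, micro_grad, and num_micro_batches are hypothetical names):

    memset(grad_accum, 0, numel * sizeof(float));
    for (int mb = 0; mb < num_micro_batches; ++mb) {
        /* ... backward pass for micro-batch mb writes micro_grad ... */
        gradient_accumulate_f32(grad_accum, micro_grad, numel);
    }
    // Average: grad_accum /= num_micro_batches
    gradient_scale_f32(grad_accum, numel, 1.0f / (float)num_micro_batches);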

◆ gradient_clip_norm_bf16()

float gradient_clip_norm_bf16(uint16_t *grad, size_t numel, float max_norm)

Clip gradient norm (bf16)

Returns
The original L2 norm before clipping

Definition at line 291 of file optimizer_kernels_bf16.c.

{
    if (!grad || numel == 0 || max_norm <= 0.0f) {
        return 0.0f;
    }

    // Compute the L2 norm: bf16 values are widened to fp32 and the sum of
    // squares is accumulated in fp32 (SIMD) / double (scalar) for accuracy
    double sum_sq = 0.0;
#if defined(__AVX512F__)
    __m512 acc = _mm512_setzero_ps();
    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        acc = _mm512_fmadd_ps(g, g, acc);
    }
    sum_sq = _mm512_reduce_add_ps(acc);
    for (; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        sum_sq += (double)g * (double)g;
    }
#else
    for (size_t i = 0; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        sum_sq += (double)g * (double)g;
    }
#endif

    float norm = sqrtf((float)sum_sq);

    if (norm > max_norm) {
        float scale = max_norm / norm;
        gradient_scale_bf16(grad, numel, scale);
    }

    return norm;
}

References bf16_to_float(), and gradient_scale_bf16().
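
Because the pre-clip norm is returned, a caller can clip and monitor gradient health in one call (a sketch; the max_norm of 1.0f is illustrative):

    float grad_norm = gradient_clip_norm_bf16(grad, numel, 1.0f);
    if (grad_norm > 1.0f) {
        /* gradients were rescaled in-place by max_norm / grad_norm */
    }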

◆ gradient_scale_bf16()

void gradient_scale_bf16(uint16_t *grad, size_t numel, float scale)

Scale gradients: grad *= scale (bf16)

Definition at line 260 of file optimizer_kernels_bf16.c.

{
    if (!grad || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    __m512 v_scale = _mm512_set1_ps(scale);
    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        fp32_cvt_storeu_bf16(&grad[i], _mm512_mul_ps(g, v_scale));
    }
    for (; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        grad[i] = float_to_bf16(g * scale);
    }
#else
    for (size_t i = 0; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        grad[i] = float_to_bf16(g * scale);
    }
#endif
}

References bf16_to_float(), and float_to_bf16().

Referenced by gradient_clip_norm_bf16().

◆ gradient_scale_f32()

void gradient_scale_f32(float *grad, size_t numel, float scale)

Scale gradients by a constant: grad *= scale (fp32)

Used for averaging gradients after accumulation: grad /= batch_size

Parameters
    grad   Gradient tensor to scale (in-place) [numel]
    numel  Number of elements
    scale  Scale factor (typically 1.0 / batch_size)

Definition at line 448 of file optimizer_kernels.c.

{
    if (!grad || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    __m512 v_scale = _mm512_set1_ps(scale);
    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = _mm512_loadu_ps(&grad[i]);
        _mm512_storeu_ps(&grad[i], _mm512_mul_ps(g, v_scale));
    }
    for (; i < numel; ++i) {
        grad[i] *= scale;
    }

#elif defined(__AVX__)
    __m256 v_scale = _mm256_set1_ps(scale);
    size_t i = 0;
    for (; i + 8 <= numel; i += 8) {
        __m256 g = _mm256_loadu_ps(&grad[i]);
        _mm256_storeu_ps(&grad[i], _mm256_mul_ps(g, v_scale));
    }
    for (; i < numel; ++i) {
        grad[i] *= scale;
    }

#elif defined(__SSE2__)
    __m128 v_scale = _mm_set1_ps(scale);
    size_t i = 0;
    for (; i + 4 <= numel; i += 4) {
        __m128 g = _mm_loadu_ps(&grad[i]);
        _mm_storeu_ps(&grad[i], _mm_mul_ps(g, v_scale));
    }
    for (; i < numel; ++i) {
        grad[i] *= scale;
    }

#else
    for (size_t i = 0; i < numel; ++i) {
        grad[i] *= scale;
    }
#endif
}

Referenced by gradient_clip_norm_f32().

◆ sgd_momentum_update_bf16()

void sgd_momentum_update_bf16(const uint16_t *grad,
                              uint16_t *weight,
                              float *velocity,
                              size_t numel,
                              float lr,
                              float momentum,
                              float weight_decay)

SGD with momentum (bf16 weights/gradients)

Definition at line 163 of file optimizer_kernels_bf16.c.

{
    if (!grad || !weight || !velocity || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    __m512 v_lr = _mm512_set1_ps(lr);
    __m512 v_momentum = _mm512_set1_ps(momentum);
    __m512 v_weight_decay = _mm512_set1_ps(weight_decay);

    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        __m512 w = bf16_loadu_cvt_fp32(&weight[i]);
        __m512 vel = _mm512_loadu_ps(&velocity[i]);

        // v = momentum * v + g
        vel = _mm512_fmadd_ps(v_momentum, vel, g);
        // w = w - lr * (v + weight_decay * w)
        __m512 update = _mm512_fmadd_ps(v_weight_decay, w, vel);
        w = _mm512_fnmadd_ps(v_lr, update, w);

        fp32_cvt_storeu_bf16(&weight[i], w);
        _mm512_storeu_ps(&velocity[i], vel);
    }

    // Scalar tail
    for (; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);
        velocity[i] = momentum * velocity[i] + g;
        w = w - lr * (velocity[i] + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#else
    // Scalar path
    for (size_t i = 0; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);
        velocity[i] = momentum * velocity[i] + g;
        w = w - lr * (velocity[i] + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#endif
}

References bf16_to_float(), and float_to_bf16().

◆ sgd_momentum_update_f32()

void sgd_momentum_update_f32(const float *grad,
                             float *weight,
                             float *velocity,
                             size_t numel,
                             float lr,
                             float momentum,
                             float weight_decay)

SGD with momentum optimizer update (fp32 version)

v_t = momentum * v_{t-1} + g_t
w_t = w_{t-1} - lr * (v_t + weight_decay * w_{t-1})

Parameters
    grad          Gradient tensor (fp32) [numel]
    weight        Weight tensor to update (fp32, in-place) [numel]
    velocity      Velocity buffer (fp32, in-place) [numel]
    numel         Number of elements
    lr            Learning rate
    momentum      Momentum coefficient (typically 0.9)
    weight_decay  Weight decay coefficient

Definition at line 267 of file optimizer_kernels.c.

{
    if (!grad || !weight || !velocity || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    // AVX-512 path: process 16 floats at a time
    __m512 v_lr = _mm512_set1_ps(lr);
    __m512 v_momentum = _mm512_set1_ps(momentum);
    __m512 v_weight_decay = _mm512_set1_ps(weight_decay);

    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = _mm512_loadu_ps(&grad[i]);
        __m512 w = _mm512_loadu_ps(&weight[i]);
        __m512 vel = _mm512_loadu_ps(&velocity[i]);

        vel = _mm512_fmadd_ps(v_momentum, vel, g);
        __m512 update = _mm512_fmadd_ps(v_weight_decay, w, vel);
        w = _mm512_fnmadd_ps(v_lr, update, w);

        _mm512_storeu_ps(&weight[i], w);
        _mm512_storeu_ps(&velocity[i], vel);
    }

    for (; i < numel; ++i) {
        velocity[i] = momentum * velocity[i] + grad[i];
        weight[i] = weight[i] - lr * (velocity[i] + weight_decay * weight[i]);
    }

#elif defined(__AVX__)
    // AVX path: process 8 floats at a time
    __m256 v_lr = _mm256_set1_ps(lr);
    __m256 v_momentum = _mm256_set1_ps(momentum);
    __m256 v_weight_decay = _mm256_set1_ps(weight_decay);

    size_t i = 0;
    for (; i + 8 <= numel; i += 8) {
        __m256 g = _mm256_loadu_ps(&grad[i]);
        __m256 w = _mm256_loadu_ps(&weight[i]);
        __m256 vel = _mm256_loadu_ps(&velocity[i]);

        // v = momentum * v + g
        vel = _mm256_add_ps(_mm256_mul_ps(v_momentum, vel), g);

        // w = w - lr * (v + weight_decay * w)
        __m256 update = _mm256_add_ps(vel, _mm256_mul_ps(v_weight_decay, w));
        w = _mm256_sub_ps(w, _mm256_mul_ps(v_lr, update));

        _mm256_storeu_ps(&weight[i], w);
        _mm256_storeu_ps(&velocity[i], vel);
    }

    for (; i < numel; ++i) {
        velocity[i] = momentum * velocity[i] + grad[i];
        weight[i] = weight[i] - lr * (velocity[i] + weight_decay * weight[i]);
    }

#elif defined(__SSE2__)
    // SSE2 path: process 4 floats at a time
    __m128 v_lr = _mm_set1_ps(lr);
    __m128 v_momentum = _mm_set1_ps(momentum);
    __m128 v_weight_decay = _mm_set1_ps(weight_decay);

    size_t i = 0;
    for (; i + 4 <= numel; i += 4) {
        __m128 g = _mm_loadu_ps(&grad[i]);
        __m128 w = _mm_loadu_ps(&weight[i]);
        __m128 vel = _mm_loadu_ps(&velocity[i]);

        vel = _mm_add_ps(_mm_mul_ps(v_momentum, vel), g);
        __m128 update = _mm_add_ps(vel, _mm_mul_ps(v_weight_decay, w));
        w = _mm_sub_ps(w, _mm_mul_ps(v_lr, update));

        _mm_storeu_ps(&weight[i], w);
        _mm_storeu_ps(&velocity[i], vel);
    }

    for (; i < numel; ++i) {
        velocity[i] = momentum * velocity[i] + grad[i];
        weight[i] = weight[i] - lr * (velocity[i] + weight_decay * weight[i]);
    }

#else
    // Scalar path
    for (size_t i = 0; i < numel; ++i) {
        velocity[i] = momentum * velocity[i] + grad[i];
        weight[i] = weight[i] - lr * (velocity[i] + weight_decay * weight[i]);
    }
#endif
}
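
The recurrence above needs an initial velocity; in standard SGD with momentum, v_0 = 0, which the caller establishes before the first step (a sketch; per the kernel rules, kernels never allocate or initialize memory themselves):

    memset(velocity, 0, numel * sizeof(float));   /* v_0 = 0 */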

◆ zero_gradients_bf16()

void zero_gradients_bf16(uint16_t *grad, size_t numel)

Zero out gradient buffer (bf16)

Definition at line 217 of file optimizer_kernels_bf16.c.

{
    if (!grad || numel == 0) {
        return;
    }
    // bf16 zero is the all-zero bit pattern, so memset is exact
    memset(grad, 0, numel * sizeof(uint16_t));
}