← Back to C-Kernel-Engine Docs Doxygen Source Documentation
optimizer_kernels.c
Go to the documentation of this file.
1 /**
2  * @file optimizer_kernels.c
3  * @brief Optimizer kernels for training (AdamW, SGD)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * AdamW Algorithm:
15  * m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
16  * v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
17  * m_hat = m_t / (1 - beta1^t)
18  * v_hat = v_t / (1 - beta2^t)
19  * w_t = w_{t-1} - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w_{t-1})
20  *
21  * Note: AdamW applies weight decay directly to weights, not to gradients.
22  * This is different from L2 regularization (Adam with L2 adds decay to gradient).
23  */
24 
25 #include <math.h>
26 #include <stddef.h>
27 #include <stdint.h>
28 #include <string.h>
29 
30 /* Include SIMD headers based on available instruction sets */
31 #if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE2__)
32 #include <immintrin.h>
33 #endif
34 
35 /**
36  * @brief AdamW optimizer update (fp32 version)
37  *
38  * Updates weights in-place using the AdamW algorithm.
39  * Momentum (m) and variance (v) are stored in fp32 for numerical stability.
40  *
41  * @param grad Gradient tensor (fp32) [numel]
42  * @param weight Weight tensor to update (fp32, in-place) [numel]
43  * @param m First moment (momentum) buffer (fp32, in-place) [numel]
44  * @param v Second moment (variance) buffer (fp32, in-place) [numel]
45  * @param numel Number of elements
46  * @param lr Learning rate
47  * @param beta1 Exponential decay rate for first moment (typically 0.9)
48  * @param beta2 Exponential decay rate for second moment (typically 0.999)
49  * @param eps Small constant for numerical stability (typically 1e-8)
50  * @param weight_decay Weight decay coefficient (typically 0.01)
51  * @param step Current step number (1-indexed for bias correction)
52  */
54  const float *grad,
55  float *weight,
56  float *m,
57  float *v,
58  size_t numel,
59  float lr,
60  float beta1,
61  float beta2,
62  float eps,
63  float weight_decay,
64  int step)
65 {
66  if (!grad || !weight || !m || !v || numel == 0) {
67  return;
68  }
69 
70  // Bias correction terms
71  float bias_correction1 = 1.0f - powf(beta1, (float)step);
72  float bias_correction2 = 1.0f - powf(beta2, (float)step);
73 
74  // Precompute constants
75  float one_minus_beta1 = 1.0f - beta1;
76  float one_minus_beta2 = 1.0f - beta2;
77 
78 #if defined(__AVX512F__)
79  // AVX-512 path: process 16 floats at a time
80  __m512 v_beta1 = _mm512_set1_ps(beta1);
81  __m512 v_beta2 = _mm512_set1_ps(beta2);
82  __m512 v_one_minus_beta1 = _mm512_set1_ps(one_minus_beta1);
83  __m512 v_one_minus_beta2 = _mm512_set1_ps(one_minus_beta2);
84  __m512 v_lr = _mm512_set1_ps(lr);
85  __m512 v_eps = _mm512_set1_ps(eps);
86  __m512 v_weight_decay = _mm512_set1_ps(weight_decay);
87  __m512 v_bc1_inv = _mm512_set1_ps(1.0f / bias_correction1);
88  __m512 v_bc2_inv = _mm512_set1_ps(1.0f / bias_correction2);
89 
90  size_t i = 0;
91  for (; i + 16 <= numel; i += 16) {
92  __m512 g = _mm512_loadu_ps(&grad[i]);
93  __m512 w = _mm512_loadu_ps(&weight[i]);
94  __m512 m_val = _mm512_loadu_ps(&m[i]);
95  __m512 v_val = _mm512_loadu_ps(&v[i]);
96 
97  // m = beta1 * m + (1 - beta1) * g
98  m_val = _mm512_fmadd_ps(v_beta1, m_val, _mm512_mul_ps(v_one_minus_beta1, g));
99 
100  // v = beta2 * v + (1 - beta2) * g^2
101  __m512 g_sq = _mm512_mul_ps(g, g);
102  v_val = _mm512_fmadd_ps(v_beta2, v_val, _mm512_mul_ps(v_one_minus_beta2, g_sq));
103 
104  // Bias-corrected estimates
105  __m512 m_hat = _mm512_mul_ps(m_val, v_bc1_inv);
106  __m512 v_hat = _mm512_mul_ps(v_val, v_bc2_inv);
107 
108  // w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
109  __m512 denom = _mm512_add_ps(_mm512_sqrt_ps(v_hat), v_eps);
110  __m512 update = _mm512_div_ps(m_hat, denom);
111  update = _mm512_fmadd_ps(v_weight_decay, w, update);
112  w = _mm512_fnmadd_ps(v_lr, update, w);
113 
114  _mm512_storeu_ps(&weight[i], w);
115  _mm512_storeu_ps(&m[i], m_val);
116  _mm512_storeu_ps(&v[i], v_val);
117  }
118 
119  // Scalar tail
120  for (; i < numel; ++i) {
121  float g = grad[i];
122  float w = weight[i];
123  m[i] = beta1 * m[i] + one_minus_beta1 * g;
124  v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
125  float m_hat = m[i] / bias_correction1;
126  float v_hat = v[i] / bias_correction2;
127  weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
128  }
129 
130 #elif defined(__AVX__)
131  // AVX path: process 8 floats at a time (no FMA on older CPUs like Ivy Bridge)
132  __m256 v_beta1 = _mm256_set1_ps(beta1);
133  __m256 v_beta2 = _mm256_set1_ps(beta2);
134  __m256 v_one_minus_beta1 = _mm256_set1_ps(one_minus_beta1);
135  __m256 v_one_minus_beta2 = _mm256_set1_ps(one_minus_beta2);
136  __m256 v_lr = _mm256_set1_ps(lr);
137  __m256 v_eps = _mm256_set1_ps(eps);
138  __m256 v_weight_decay = _mm256_set1_ps(weight_decay);
139  __m256 v_bc1_inv = _mm256_set1_ps(1.0f / bias_correction1);
140  __m256 v_bc2_inv = _mm256_set1_ps(1.0f / bias_correction2);
141 
142  size_t i = 0;
143  for (; i + 8 <= numel; i += 8) {
144  __m256 g = _mm256_loadu_ps(&grad[i]);
145  __m256 w = _mm256_loadu_ps(&weight[i]);
146  __m256 m_val = _mm256_loadu_ps(&m[i]);
147  __m256 v_val = _mm256_loadu_ps(&v[i]);
148 
149  // m = beta1 * m + (1 - beta1) * g
150  m_val = _mm256_add_ps(_mm256_mul_ps(v_beta1, m_val),
151  _mm256_mul_ps(v_one_minus_beta1, g));
152 
153  // v = beta2 * v + (1 - beta2) * g^2
154  __m256 g_sq = _mm256_mul_ps(g, g);
155  v_val = _mm256_add_ps(_mm256_mul_ps(v_beta2, v_val),
156  _mm256_mul_ps(v_one_minus_beta2, g_sq));
157 
158  // Bias-corrected estimates
159  __m256 m_hat = _mm256_mul_ps(m_val, v_bc1_inv);
160  __m256 v_hat = _mm256_mul_ps(v_val, v_bc2_inv);
161 
162  // w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
163  __m256 denom = _mm256_add_ps(_mm256_sqrt_ps(v_hat), v_eps);
164  __m256 update = _mm256_div_ps(m_hat, denom);
165  update = _mm256_add_ps(update, _mm256_mul_ps(v_weight_decay, w));
166  w = _mm256_sub_ps(w, _mm256_mul_ps(v_lr, update));
167 
168  _mm256_storeu_ps(&weight[i], w);
169  _mm256_storeu_ps(&m[i], m_val);
170  _mm256_storeu_ps(&v[i], v_val);
171  }
172 
173  // Scalar tail
174  for (; i < numel; ++i) {
175  float g = grad[i];
176  float w = weight[i];
177  m[i] = beta1 * m[i] + one_minus_beta1 * g;
178  v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
179  float m_hat = m[i] / bias_correction1;
180  float v_hat = v[i] / bias_correction2;
181  weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
182  }
183 
184 #elif defined(__SSE2__)
185  // SSE2 path: process 4 floats at a time
186  __m128 v_beta1 = _mm_set1_ps(beta1);
187  __m128 v_beta2 = _mm_set1_ps(beta2);
188  __m128 v_one_minus_beta1 = _mm_set1_ps(one_minus_beta1);
189  __m128 v_one_minus_beta2 = _mm_set1_ps(one_minus_beta2);
190  __m128 v_lr = _mm_set1_ps(lr);
191  __m128 v_eps = _mm_set1_ps(eps);
192  __m128 v_weight_decay = _mm_set1_ps(weight_decay);
193  __m128 v_bc1_inv = _mm_set1_ps(1.0f / bias_correction1);
194  __m128 v_bc2_inv = _mm_set1_ps(1.0f / bias_correction2);
195 
196  size_t i = 0;
197  for (; i + 4 <= numel; i += 4) {
198  __m128 g = _mm_loadu_ps(&grad[i]);
199  __m128 w = _mm_loadu_ps(&weight[i]);
200  __m128 m_val = _mm_loadu_ps(&m[i]);
201  __m128 v_val = _mm_loadu_ps(&v[i]);
202 
203  // m = beta1 * m + (1 - beta1) * g
204  m_val = _mm_add_ps(_mm_mul_ps(v_beta1, m_val),
205  _mm_mul_ps(v_one_minus_beta1, g));
206 
207  // v = beta2 * v + (1 - beta2) * g^2
208  __m128 g_sq = _mm_mul_ps(g, g);
209  v_val = _mm_add_ps(_mm_mul_ps(v_beta2, v_val),
210  _mm_mul_ps(v_one_minus_beta2, g_sq));
211 
212  // Bias-corrected estimates
213  __m128 m_hat = _mm_mul_ps(m_val, v_bc1_inv);
214  __m128 v_hat = _mm_mul_ps(v_val, v_bc2_inv);
215 
216  // w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
217  __m128 denom = _mm_add_ps(_mm_sqrt_ps(v_hat), v_eps);
218  __m128 update = _mm_div_ps(m_hat, denom);
219  update = _mm_add_ps(update, _mm_mul_ps(v_weight_decay, w));
220  w = _mm_sub_ps(w, _mm_mul_ps(v_lr, update));
221 
222  _mm_storeu_ps(&weight[i], w);
223  _mm_storeu_ps(&m[i], m_val);
224  _mm_storeu_ps(&v[i], v_val);
225  }
226 
227  // Scalar tail
228  for (; i < numel; ++i) {
229  float g = grad[i];
230  float w = weight[i];
231  m[i] = beta1 * m[i] + one_minus_beta1 * g;
232  v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
233  float m_hat = m[i] / bias_correction1;
234  float v_hat = v[i] / bias_correction2;
235  weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
236  }
237 
238 #else
239  // Scalar path
240  for (size_t i = 0; i < numel; ++i) {
241  float g = grad[i];
242  float w = weight[i];
243  m[i] = beta1 * m[i] + one_minus_beta1 * g;
244  v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
245  float m_hat = m[i] / bias_correction1;
246  float v_hat = v[i] / bias_correction2;
247  weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
248  }
249 #endif
250 }
251 
252 
/**
 * @brief SGD with momentum optimizer update (fp32 version)
 *
 * v_t = momentum * v_{t-1} + g_t
 * w_t = w_{t-1} - lr * (v_t + weight_decay * w_{t-1})
 *
 * @param grad Gradient tensor (fp32) [numel]
 * @param weight Weight tensor to update (fp32, in-place) [numel]
 * @param velocity Velocity buffer (fp32, in-place) [numel]
 * @param numel Number of elements
 * @param lr Learning rate
 * @param momentum Momentum coefficient (typically 0.9)
 * @param weight_decay Weight decay coefficient
 */
void sgd_momentum_update_f32(
    const float *grad,
    float *weight,
    float *velocity,
    size_t numel,
    float lr,
    float momentum,
    float weight_decay)
{
    if (grad == NULL || weight == NULL || velocity == NULL || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    /* AVX-512: 16 lanes per iteration. */
    const __m512 lr_v  = _mm512_set1_ps(lr);
    const __m512 mom_v = _mm512_set1_ps(momentum);
    const __m512 wd_v  = _mm512_set1_ps(weight_decay);

    size_t idx = 0;
    for (; idx + 16 <= numel; idx += 16) {
        __m512 gv = _mm512_loadu_ps(grad + idx);
        __m512 vv = _mm512_loadu_ps(velocity + idx);
        __m512 wv = _mm512_loadu_ps(weight + idx);

        /* v = momentum * v + g */
        vv = _mm512_fmadd_ps(mom_v, vv, gv);
        /* w = w - lr * (v + weight_decay * w) */
        __m512 step_v = _mm512_fmadd_ps(wd_v, wv, vv);
        wv = _mm512_fnmadd_ps(lr_v, step_v, wv);

        _mm512_storeu_ps(weight + idx, wv);
        _mm512_storeu_ps(velocity + idx, vv);
    }
    /* Remainder in scalar code. */
    while (idx < numel) {
        float vel_new = momentum * velocity[idx] + grad[idx];
        velocity[idx] = vel_new;
        weight[idx] = weight[idx] - lr * (vel_new + weight_decay * weight[idx]);
        ++idx;
    }

#elif defined(__AVX__)
    /* AVX: 8 lanes per iteration, mul+add form (no FMA dependency). */
    const __m256 lr_v  = _mm256_set1_ps(lr);
    const __m256 mom_v = _mm256_set1_ps(momentum);
    const __m256 wd_v  = _mm256_set1_ps(weight_decay);

    size_t idx = 0;
    for (; idx + 8 <= numel; idx += 8) {
        __m256 gv = _mm256_loadu_ps(grad + idx);
        __m256 vv = _mm256_loadu_ps(velocity + idx);
        __m256 wv = _mm256_loadu_ps(weight + idx);

        /* v = momentum * v + g */
        vv = _mm256_add_ps(_mm256_mul_ps(mom_v, vv), gv);
        /* w = w - lr * (v + weight_decay * w) */
        __m256 step_v = _mm256_add_ps(vv, _mm256_mul_ps(wd_v, wv));
        wv = _mm256_sub_ps(wv, _mm256_mul_ps(lr_v, step_v));

        _mm256_storeu_ps(weight + idx, wv);
        _mm256_storeu_ps(velocity + idx, vv);
    }
    while (idx < numel) {
        float vel_new = momentum * velocity[idx] + grad[idx];
        velocity[idx] = vel_new;
        weight[idx] = weight[idx] - lr * (vel_new + weight_decay * weight[idx]);
        ++idx;
    }

#elif defined(__SSE2__)
    /* SSE2: 4 lanes per iteration. */
    const __m128 lr_v  = _mm_set1_ps(lr);
    const __m128 mom_v = _mm_set1_ps(momentum);
    const __m128 wd_v  = _mm_set1_ps(weight_decay);

    size_t idx = 0;
    for (; idx + 4 <= numel; idx += 4) {
        __m128 gv = _mm_loadu_ps(grad + idx);
        __m128 vv = _mm_loadu_ps(velocity + idx);
        __m128 wv = _mm_loadu_ps(weight + idx);

        /* v = momentum * v + g */
        vv = _mm_add_ps(_mm_mul_ps(mom_v, vv), gv);
        /* w = w - lr * (v + weight_decay * w) */
        __m128 step_v = _mm_add_ps(vv, _mm_mul_ps(wd_v, wv));
        wv = _mm_sub_ps(wv, _mm_mul_ps(lr_v, step_v));

        _mm_storeu_ps(weight + idx, wv);
        _mm_storeu_ps(velocity + idx, vv);
    }
    while (idx < numel) {
        float vel_new = momentum * velocity[idx] + grad[idx];
        velocity[idx] = vel_new;
        weight[idx] = weight[idx] - lr * (vel_new + weight_decay * weight[idx]);
        ++idx;
    }

#else
    /* Portable scalar fallback. */
    for (size_t idx = 0; idx < numel; ++idx) {
        float vel_new = momentum * velocity[idx] + grad[idx];
        velocity[idx] = vel_new;
        weight[idx] = weight[idx] - lr * (vel_new + weight_decay * weight[idx]);
    }
#endif
}
366 
367 
/**
 * @brief Zero out gradient buffer (fp32)
 *
 * Clears the buffer before a fresh accumulation pass.
 *
 * @param grad Gradient tensor to zero [numel]
 * @param numel Number of elements
 */
void zero_gradients_f32(float *grad, size_t numel)
{
    if (grad == NULL || numel == 0) {
        return;
    }
    memset(grad, 0, numel * sizeof *grad);
}
381 
382 
/**
 * @brief Accumulate gradients: dst += src (fp32)
 *
 * Element-wise addition, used to sum gradients across micro-batches
 * before an optimizer step.
 *
 * @param dst Destination gradient buffer (in-place) [numel]
 * @param src Source gradient buffer [numel]
 * @param numel Number of elements
 */
void gradient_accumulate_f32(float *dst, const float *src, size_t numel)
{
    if (dst == NULL || src == NULL || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    size_t idx = 0;
    for (; idx + 16 <= numel; idx += 16) {
        __m512 sum = _mm512_add_ps(_mm512_loadu_ps(dst + idx),
                                   _mm512_loadu_ps(src + idx));
        _mm512_storeu_ps(dst + idx, sum);
    }
    while (idx < numel) {
        dst[idx] += src[idx];
        ++idx;
    }

#elif defined(__AVX__)
    size_t idx = 0;
    for (; idx + 8 <= numel; idx += 8) {
        __m256 sum = _mm256_add_ps(_mm256_loadu_ps(dst + idx),
                                   _mm256_loadu_ps(src + idx));
        _mm256_storeu_ps(dst + idx, sum);
    }
    while (idx < numel) {
        dst[idx] += src[idx];
        ++idx;
    }

#elif defined(__SSE2__)
    size_t idx = 0;
    for (; idx + 4 <= numel; idx += 4) {
        __m128 sum = _mm_add_ps(_mm_loadu_ps(dst + idx),
                                _mm_loadu_ps(src + idx));
        _mm_storeu_ps(dst + idx, sum);
    }
    while (idx < numel) {
        dst[idx] += src[idx];
        ++idx;
    }

#else
    for (size_t idx = 0; idx < numel; ++idx) {
        dst[idx] += src[idx];
    }
#endif
}
437 
438 
/**
 * @brief Scale gradients by a constant: grad *= scale (fp32)
 *
 * Used for averaging gradients after accumulation: grad /= batch_size
 *
 * @param grad Gradient tensor to scale (in-place) [numel]
 * @param numel Number of elements
 * @param scale Scale factor (typically 1.0 / batch_size)
 */
void gradient_scale_f32(float *grad, size_t numel, float scale)
{
    if (grad == NULL || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    const __m512 scale_v = _mm512_set1_ps(scale);
    size_t idx = 0;
    for (; idx + 16 <= numel; idx += 16) {
        __m512 gv = _mm512_loadu_ps(grad + idx);
        _mm512_storeu_ps(grad + idx, _mm512_mul_ps(gv, scale_v));
    }
    while (idx < numel) {
        grad[idx] *= scale;
        ++idx;
    }

#elif defined(__AVX__)
    const __m256 scale_v = _mm256_set1_ps(scale);
    size_t idx = 0;
    for (; idx + 8 <= numel; idx += 8) {
        __m256 gv = _mm256_loadu_ps(grad + idx);
        _mm256_storeu_ps(grad + idx, _mm256_mul_ps(gv, scale_v));
    }
    while (idx < numel) {
        grad[idx] *= scale;
        ++idx;
    }

#elif defined(__SSE2__)
    const __m128 scale_v = _mm_set1_ps(scale);
    size_t idx = 0;
    for (; idx + 4 <= numel; idx += 4) {
        __m128 gv = _mm_loadu_ps(grad + idx);
        _mm_storeu_ps(grad + idx, _mm_mul_ps(gv, scale_v));
    }
    while (idx < numel) {
        grad[idx] *= scale;
        ++idx;
    }

#else
    for (size_t idx = 0; idx < numel; ++idx) {
        grad[idx] *= scale;
    }
#endif
}
493 
494 
/**
 * @brief Clip gradient norm (fp32)
 *
 * If ||grad||_2 > max_norm, scale grad so that ||grad||_2 = max_norm
 *
 * @param grad Gradient tensor to clip (in-place) [numel]
 * @param numel Number of elements
 * @param max_norm Maximum allowed L2 norm (must be > 0)
 * @return The original L2 norm before clipping (0.0f on invalid input)
 */
float gradient_clip_norm_f32(float *grad, size_t numel, float max_norm)
{
    if (!grad || numel == 0 || max_norm <= 0.0f) {
        return 0.0f;
    }

    /* Accumulate the sum of squares in double precision on EVERY path.
     * The previous SIMD paths accumulated in a single fp32 register, which
     * loses precision (and can overflow to inf) on large tensors, making
     * vectorized builds disagree with the scalar/tail reference path that
     * already used double. Squares are computed in fp32, then widened. */
    double sum_sq = 0.0;
#if defined(__AVX512F__)
    __m512d acc = _mm512_setzero_pd();
    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = _mm512_loadu_ps(&grad[i]);
        __m512 sq = _mm512_mul_ps(g, g);
        /* Widen each 8-float half to double before accumulating.
         * extractf64x4_pd + cast is the AVX512F-only way to take the
         * upper 256-bit half of a 512-bit float vector. */
        __m256 lo = _mm512_castps512_ps256(sq);
        __m256 hi = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(sq), 1));
        acc = _mm512_add_pd(acc, _mm512_cvtps_pd(lo));
        acc = _mm512_add_pd(acc, _mm512_cvtps_pd(hi));
    }
    sum_sq = _mm512_reduce_add_pd(acc);
    for (; i < numel; ++i) {
        sum_sq += (double)grad[i] * (double)grad[i];
    }

#elif defined(__AVX__)
    __m256d acc = _mm256_setzero_pd();
    size_t i = 0;
    for (; i + 8 <= numel; i += 8) {
        __m256 g = _mm256_loadu_ps(&grad[i]);
        __m256 sq = _mm256_mul_ps(g, g);
        /* Widen each 4-float half to double before accumulating. */
        __m128 lo = _mm256_castps256_ps128(sq);
        __m128 hi = _mm256_extractf128_ps(sq, 1);
        acc = _mm256_add_pd(acc, _mm256_cvtps_pd(lo));
        acc = _mm256_add_pd(acc, _mm256_cvtps_pd(hi));
    }
    /* Horizontal sum of 4 doubles in acc */
    __m128d lo2 = _mm256_castpd256_pd128(acc);
    __m128d hi2 = _mm256_extractf128_pd(acc, 1);
    __m128d sum2 = _mm_add_pd(lo2, hi2);
    __m128d swap = _mm_unpackhi_pd(sum2, sum2);
    sum_sq = _mm_cvtsd_f64(_mm_add_sd(sum2, swap));
    for (; i < numel; ++i) {
        sum_sq += (double)grad[i] * (double)grad[i];
    }

#elif defined(__SSE2__)
    __m128d acc = _mm_setzero_pd();
    size_t i = 0;
    for (; i + 4 <= numel; i += 4) {
        __m128 g = _mm_loadu_ps(&grad[i]);
        __m128 sq = _mm_mul_ps(g, g);
        /* cvtps_pd widens the low 2 floats; movehl brings the high pair down. */
        acc = _mm_add_pd(acc, _mm_cvtps_pd(sq));
        acc = _mm_add_pd(acc, _mm_cvtps_pd(_mm_movehl_ps(sq, sq)));
    }
    /* Horizontal sum of 2 doubles in acc */
    __m128d swap = _mm_unpackhi_pd(acc, acc);
    sum_sq = _mm_cvtsd_f64(_mm_add_sd(acc, swap));
    for (; i < numel; ++i) {
        sum_sq += (double)grad[i] * (double)grad[i];
    }

#else
    for (size_t i = 0; i < numel; ++i) {
        sum_sq += (double)grad[i] * (double)grad[i];
    }
#endif

    float norm = sqrtf((float)sum_sq);

    // Clip if necessary; norm > max_norm > 0 guarantees a safe division.
    if (norm > max_norm) {
        float scale = max_norm / norm;
        gradient_scale_f32(grad, numel, scale);
    }

    return norm;
}
void sgd_momentum_update_f32(const float *grad, float *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay)
SGD with momentum optimizer update (fp32 version)
float gradient_clip_norm_f32(float *grad, size_t numel, float max_norm)
Clip gradient norm (fp32)
void gradient_scale_f32(float *grad, size_t numel, float scale)
Scale gradients by a constant: grad *= scale (fp32)
void adamw_update_f32(const float *grad, float *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
AdamW optimizer update (fp32 version)
void gradient_accumulate_f32(float *dst, const float *src, size_t numel)
Accumulate gradients: dst += src (fp32)
void zero_gradients_f32(float *grad, size_t numel)
Zero out gradient buffer (fp32)