← Back to C-Kernel-Engine Docs Doxygen Source Documentation
optimizer_kernels.c File Reference

Optimizer kernels for training (AdamW, SGD) More...

#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE2__)
#include <immintrin.h> /* SSE/AVX intrinsics used by the SIMD kernel paths */
#endif

Go to the source code of this file.

Functions

void adamw_update_f32 (const float *grad, float *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
 AdamW optimizer update (fp32 version) More...
 
void gradient_accumulate_f32 (float *dst, const float *src, size_t numel)
 Accumulate gradients: dst += src (fp32) More...
 
float gradient_clip_norm_f32 (float *grad, size_t numel, float max_norm)
 Clip gradient norm (fp32) More...
 
void gradient_scale_f32 (float *grad, size_t numel, float scale)
 Scale gradients by a constant: grad *= scale (fp32) More...
 
void sgd_momentum_update_f32 (const float *grad, float *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay)
 SGD with momentum optimizer update (fp32 version) More...
 
void zero_gradients_f32 (float *grad, size_t numel)
 Zero out gradient buffer (fp32) More...
 

Detailed Description

Optimizer kernels for training (AdamW, SGD)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

AdamW Algorithm:

    m_t   = beta1 * m_{t-1} + (1 - beta1) * g_t
    v_t   = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    m_hat = m_t / (1 - beta1^t)
    v_hat = v_t / (1 - beta2^t)
    w_t   = w_{t-1} - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w_{t-1})

Note: AdamW applies weight decay directly to weights, not to gradients. This is different from L2 regularization (Adam with L2 adds decay to gradient).

Definition in file optimizer_kernels.c.

Function Documentation

◆ adamw_update_f32()

void adamw_update_f32 ( const float *  grad,
float *  weight,
float *  m,
float *  v,
size_t  numel,
float  lr,
float  beta1,
float  beta2,
float  eps,
float  weight_decay,
int  step 
)

AdamW optimizer update (fp32 version)

Updates weights in-place using the AdamW algorithm. Momentum (m) and variance (v) are stored in fp32 for numerical stability.

Parameters
gradGradient tensor (fp32) [numel]
weightWeight tensor to update (fp32, in-place) [numel]
mFirst moment (momentum) buffer (fp32, in-place) [numel]
vSecond moment (variance) buffer (fp32, in-place) [numel]
numelNumber of elements
lrLearning rate
beta1Exponential decay rate for first moment (typically 0.9)
beta2Exponential decay rate for second moment (typically 0.999)
epsSmall constant for numerical stability (typically 1e-8)
weight_decayWeight decay coefficient (typically 0.01)
stepCurrent step number (1-indexed for bias correction)

Definition at line 53 of file optimizer_kernels.c.

65 {
66  if (!grad || !weight || !m || !v || numel == 0) {
67  return;
68  }
69 
70  // Bias correction terms
71  float bias_correction1 = 1.0f - powf(beta1, (float)step);
72  float bias_correction2 = 1.0f - powf(beta2, (float)step);
73 
74  // Precompute constants
75  float one_minus_beta1 = 1.0f - beta1;
76  float one_minus_beta2 = 1.0f - beta2;
77 
78 #if defined(__AVX512F__)
79  // AVX-512 path: process 16 floats at a time
80  __m512 v_beta1 = _mm512_set1_ps(beta1);
81  __m512 v_beta2 = _mm512_set1_ps(beta2);
82  __m512 v_one_minus_beta1 = _mm512_set1_ps(one_minus_beta1);
83  __m512 v_one_minus_beta2 = _mm512_set1_ps(one_minus_beta2);
84  __m512 v_lr = _mm512_set1_ps(lr);
85  __m512 v_eps = _mm512_set1_ps(eps);
86  __m512 v_weight_decay = _mm512_set1_ps(weight_decay);
87  __m512 v_bc1_inv = _mm512_set1_ps(1.0f / bias_correction1);
88  __m512 v_bc2_inv = _mm512_set1_ps(1.0f / bias_correction2);
89 
90  size_t i = 0;
91  for (; i + 16 <= numel; i += 16) {
92  __m512 g = _mm512_loadu_ps(&grad[i]);
93  __m512 w = _mm512_loadu_ps(&weight[i]);
94  __m512 m_val = _mm512_loadu_ps(&m[i]);
95  __m512 v_val = _mm512_loadu_ps(&v[i]);
96 
97  // m = beta1 * m + (1 - beta1) * g
98  m_val = _mm512_fmadd_ps(v_beta1, m_val, _mm512_mul_ps(v_one_minus_beta1, g));
99 
100  // v = beta2 * v + (1 - beta2) * g^2
101  __m512 g_sq = _mm512_mul_ps(g, g);
102  v_val = _mm512_fmadd_ps(v_beta2, v_val, _mm512_mul_ps(v_one_minus_beta2, g_sq));
103 
104  // Bias-corrected estimates
105  __m512 m_hat = _mm512_mul_ps(m_val, v_bc1_inv);
106  __m512 v_hat = _mm512_mul_ps(v_val, v_bc2_inv);
107 
108  // w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
109  __m512 denom = _mm512_add_ps(_mm512_sqrt_ps(v_hat), v_eps);
110  __m512 update = _mm512_div_ps(m_hat, denom);
111  update = _mm512_fmadd_ps(v_weight_decay, w, update);
112  w = _mm512_fnmadd_ps(v_lr, update, w);
113 
114  _mm512_storeu_ps(&weight[i], w);
115  _mm512_storeu_ps(&m[i], m_val);
116  _mm512_storeu_ps(&v[i], v_val);
117  }
118 
119  // Scalar tail
120  for (; i < numel; ++i) {
121  float g = grad[i];
122  float w = weight[i];
123  m[i] = beta1 * m[i] + one_minus_beta1 * g;
124  v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
125  float m_hat = m[i] / bias_correction1;
126  float v_hat = v[i] / bias_correction2;
127  weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
128  }
129 
130 #elif defined(__AVX__)
131  // AVX path: process 8 floats at a time (no FMA on older CPUs like Ivy Bridge)
132  __m256 v_beta1 = _mm256_set1_ps(beta1);
133  __m256 v_beta2 = _mm256_set1_ps(beta2);
134  __m256 v_one_minus_beta1 = _mm256_set1_ps(one_minus_beta1);
135  __m256 v_one_minus_beta2 = _mm256_set1_ps(one_minus_beta2);
136  __m256 v_lr = _mm256_set1_ps(lr);
137  __m256 v_eps = _mm256_set1_ps(eps);
138  __m256 v_weight_decay = _mm256_set1_ps(weight_decay);
139  __m256 v_bc1_inv = _mm256_set1_ps(1.0f / bias_correction1);
140  __m256 v_bc2_inv = _mm256_set1_ps(1.0f / bias_correction2);
141 
142  size_t i = 0;
143  for (; i + 8 <= numel; i += 8) {
144  __m256 g = _mm256_loadu_ps(&grad[i]);
145  __m256 w = _mm256_loadu_ps(&weight[i]);
146  __m256 m_val = _mm256_loadu_ps(&m[i]);
147  __m256 v_val = _mm256_loadu_ps(&v[i]);
148 
149  // m = beta1 * m + (1 - beta1) * g
150  m_val = _mm256_add_ps(_mm256_mul_ps(v_beta1, m_val),
151  _mm256_mul_ps(v_one_minus_beta1, g));
152 
153  // v = beta2 * v + (1 - beta2) * g^2
154  __m256 g_sq = _mm256_mul_ps(g, g);
155  v_val = _mm256_add_ps(_mm256_mul_ps(v_beta2, v_val),
156  _mm256_mul_ps(v_one_minus_beta2, g_sq));
157 
158  // Bias-corrected estimates
159  __m256 m_hat = _mm256_mul_ps(m_val, v_bc1_inv);
160  __m256 v_hat = _mm256_mul_ps(v_val, v_bc2_inv);
161 
162  // w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
163  __m256 denom = _mm256_add_ps(_mm256_sqrt_ps(v_hat), v_eps);
164  __m256 update = _mm256_div_ps(m_hat, denom);
165  update = _mm256_add_ps(update, _mm256_mul_ps(v_weight_decay, w));
166  w = _mm256_sub_ps(w, _mm256_mul_ps(v_lr, update));
167 
168  _mm256_storeu_ps(&weight[i], w);
169  _mm256_storeu_ps(&m[i], m_val);
170  _mm256_storeu_ps(&v[i], v_val);
171  }
172 
173  // Scalar tail
174  for (; i < numel; ++i) {
175  float g = grad[i];
176  float w = weight[i];
177  m[i] = beta1 * m[i] + one_minus_beta1 * g;
178  v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
179  float m_hat = m[i] / bias_correction1;
180  float v_hat = v[i] / bias_correction2;
181  weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
182  }
183 
184 #elif defined(__SSE2__)
185  // SSE2 path: process 4 floats at a time
186  __m128 v_beta1 = _mm_set1_ps(beta1);
187  __m128 v_beta2 = _mm_set1_ps(beta2);
188  __m128 v_one_minus_beta1 = _mm_set1_ps(one_minus_beta1);
189  __m128 v_one_minus_beta2 = _mm_set1_ps(one_minus_beta2);
190  __m128 v_lr = _mm_set1_ps(lr);
191  __m128 v_eps = _mm_set1_ps(eps);
192  __m128 v_weight_decay = _mm_set1_ps(weight_decay);
193  __m128 v_bc1_inv = _mm_set1_ps(1.0f / bias_correction1);
194  __m128 v_bc2_inv = _mm_set1_ps(1.0f / bias_correction2);
195 
196  size_t i = 0;
197  for (; i + 4 <= numel; i += 4) {
198  __m128 g = _mm_loadu_ps(&grad[i]);
199  __m128 w = _mm_loadu_ps(&weight[i]);
200  __m128 m_val = _mm_loadu_ps(&m[i]);
201  __m128 v_val = _mm_loadu_ps(&v[i]);
202 
203  // m = beta1 * m + (1 - beta1) * g
204  m_val = _mm_add_ps(_mm_mul_ps(v_beta1, m_val),
205  _mm_mul_ps(v_one_minus_beta1, g));
206 
207  // v = beta2 * v + (1 - beta2) * g^2
208  __m128 g_sq = _mm_mul_ps(g, g);
209  v_val = _mm_add_ps(_mm_mul_ps(v_beta2, v_val),
210  _mm_mul_ps(v_one_minus_beta2, g_sq));
211 
212  // Bias-corrected estimates
213  __m128 m_hat = _mm_mul_ps(m_val, v_bc1_inv);
214  __m128 v_hat = _mm_mul_ps(v_val, v_bc2_inv);
215 
216  // w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
217  __m128 denom = _mm_add_ps(_mm_sqrt_ps(v_hat), v_eps);
218  __m128 update = _mm_div_ps(m_hat, denom);
219  update = _mm_add_ps(update, _mm_mul_ps(v_weight_decay, w));
220  w = _mm_sub_ps(w, _mm_mul_ps(v_lr, update));
221 
222  _mm_storeu_ps(&weight[i], w);
223  _mm_storeu_ps(&m[i], m_val);
224  _mm_storeu_ps(&v[i], v_val);
225  }
226 
227  // Scalar tail
228  for (; i < numel; ++i) {
229  float g = grad[i];
230  float w = weight[i];
231  m[i] = beta1 * m[i] + one_minus_beta1 * g;
232  v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
233  float m_hat = m[i] / bias_correction1;
234  float v_hat = v[i] / bias_correction2;
235  weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
236  }
237 
238 #else
239  // Scalar path
240  for (size_t i = 0; i < numel; ++i) {
241  float g = grad[i];
242  float w = weight[i];
243  m[i] = beta1 * m[i] + one_minus_beta1 * g;
244  v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
245  float m_hat = m[i] / bias_correction1;
246  float v_hat = v[i] / bias_correction2;
247  weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
248  }
249 #endif
250 }

◆ gradient_accumulate_f32()

void gradient_accumulate_f32 ( float *  dst,
const float *  src,
size_t  numel 
)

Accumulate gradients: dst += src (fp32)

Used for gradient accumulation across micro-batches.

Parameters
dstDestination gradient buffer (in-place) [numel]
srcSource gradient buffer [numel]
numelNumber of elements

Definition at line 392 of file optimizer_kernels.c.

393 {
394  if (!dst || !src || numel == 0) {
395  return;
396  }
397 
398 #if defined(__AVX512F__)
399  size_t i = 0;
400  for (; i + 16 <= numel; i += 16) {
401  __m512 d = _mm512_loadu_ps(&dst[i]);
402  __m512 s = _mm512_loadu_ps(&src[i]);
403  _mm512_storeu_ps(&dst[i], _mm512_add_ps(d, s));
404  }
405  for (; i < numel; ++i) {
406  dst[i] += src[i];
407  }
408 
409 #elif defined(__AVX__)
410  size_t i = 0;
411  for (; i + 8 <= numel; i += 8) {
412  __m256 d = _mm256_loadu_ps(&dst[i]);
413  __m256 s = _mm256_loadu_ps(&src[i]);
414  _mm256_storeu_ps(&dst[i], _mm256_add_ps(d, s));
415  }
416  for (; i < numel; ++i) {
417  dst[i] += src[i];
418  }
419 
420 #elif defined(__SSE2__)
421  size_t i = 0;
422  for (; i + 4 <= numel; i += 4) {
423  __m128 d = _mm_loadu_ps(&dst[i]);
424  __m128 s = _mm_loadu_ps(&src[i]);
425  _mm_storeu_ps(&dst[i], _mm_add_ps(d, s));
426  }
427  for (; i < numel; ++i) {
428  dst[i] += src[i];
429  }
430 
431 #else
432  for (size_t i = 0; i < numel; ++i) {
433  dst[i] += src[i];
434  }
435 #endif
436 }

◆ gradient_clip_norm_f32()

float gradient_clip_norm_f32 ( float *  grad,
size_t  numel,
float  max_norm 
)

Clip gradient norm (fp32)

If ||grad||_2 > max_norm, scale grad so that ||grad||_2 = max_norm

Parameters
gradGradient tensor to clip (in-place) [numel]
numelNumber of elements
max_normMaximum allowed L2 norm
Returns
The original L2 norm before clipping

Definition at line 505 of file optimizer_kernels.c.

/**
 * Clip gradient norm (fp32).
 *
 * If ||grad||_2 > max_norm, scales grad in place so ||grad||_2 == max_norm.
 *
 * @param grad     Gradient tensor, clipped in place [numel]
 * @param numel    Number of elements
 * @param max_norm Maximum allowed L2 norm (must be > 0)
 * @return Original L2 norm before clipping; 0.0f on invalid input
 */
float gradient_clip_norm_f32(float *grad, size_t numel, float max_norm)
{
    if (grad == NULL || numel == 0 || max_norm <= 0.0f) {
        return 0.0f;
    }

    /* Sum of squares; tail elements accumulate in double */
    double sum_sq = 0.0;
    size_t idx = 0;

#if defined(__AVX512F__)
    __m512 acc16 = _mm512_setzero_ps();
    for (; idx + 16 <= numel; idx += 16) {
        __m512 g = _mm512_loadu_ps(grad + idx);
        acc16 = _mm512_fmadd_ps(g, g, acc16);
    }
    sum_sq = _mm512_reduce_add_ps(acc16);
#elif defined(__AVX__)
    __m256 acc8 = _mm256_setzero_ps();
    for (; idx + 8 <= numel; idx += 8) {
        __m256 g = _mm256_loadu_ps(grad + idx);
        acc8 = _mm256_add_ps(acc8, _mm256_mul_ps(g, g));
    }
    /* Horizontal reduction: 8 -> 4 -> 2 -> 1 */
    __m128 quad = _mm_add_ps(_mm256_castps256_ps128(acc8),
                             _mm256_extractf128_ps(acc8, 1));
    __m128 shuf = _mm_movehdup_ps(quad);
    __m128 pair = _mm_add_ps(quad, shuf);
    shuf = _mm_movehl_ps(shuf, pair);
    pair = _mm_add_ss(pair, shuf);
    sum_sq = _mm_cvtss_f32(pair);
#elif defined(__SSE2__)
    __m128 acc4 = _mm_setzero_ps();
    for (; idx + 4 <= numel; idx += 4) {
        __m128 g = _mm_loadu_ps(grad + idx);
        acc4 = _mm_add_ps(acc4, _mm_mul_ps(g, g));
    }
    /* Horizontal reduction: 4 -> 2 -> 1 */
    __m128 shuf = _mm_shuffle_ps(acc4, acc4, _MM_SHUFFLE(2, 3, 0, 1));
    __m128 pair = _mm_add_ps(acc4, shuf);
    shuf = _mm_movehl_ps(shuf, pair);
    pair = _mm_add_ss(pair, shuf);
    sum_sq = _mm_cvtss_f32(pair);
#endif

    /* Scalar path / vector-remainder tail */
    for (; idx < numel; ++idx) {
        sum_sq += (double)grad[idx] * (double)grad[idx];
    }

    float norm = sqrtf((float)sum_sq);

    /* Rescale only when the norm exceeds the budget */
    if (norm > max_norm) {
        gradient_scale_f32(grad, numel, max_norm / norm);
    }

    return norm;
}
void gradient_scale_f32(float *grad, size_t numel, float scale)
Scale gradients by a constant: grad *= scale (fp32)

References gradient_scale_f32().

◆ gradient_scale_f32()

void gradient_scale_f32 ( float *  grad,
size_t  numel,
float  scale 
)

Scale gradients by a constant: grad *= scale (fp32)

Used for averaging gradients after accumulation: grad /= batch_size

Parameters
gradGradient tensor to scale (in-place) [numel]
numelNumber of elements
scaleScale factor (typically 1.0 / batch_size)

Definition at line 448 of file optimizer_kernels.c.

/**
 * Scale gradients by a constant: grad *= scale (fp32).
 *
 * Typical use: averaging after accumulation (scale = 1.0 / batch_size).
 * NULL pointer or numel == 0 makes this a no-op.
 *
 * @param grad  Gradient tensor, scaled in place [numel]
 * @param numel Number of elements
 * @param scale Multiplicative factor
 */
void gradient_scale_f32(float *grad, size_t numel, float scale)
{
    if (grad == NULL || numel == 0) {
        return;
    }

    size_t idx = 0;

#if defined(__AVX512F__)
    const __m512 factor16 = _mm512_set1_ps(scale);
    for (; idx + 16 <= numel; idx += 16) {
        __m512 g = _mm512_loadu_ps(grad + idx);
        _mm512_storeu_ps(grad + idx, _mm512_mul_ps(g, factor16));
    }
#elif defined(__AVX__)
    const __m256 factor8 = _mm256_set1_ps(scale);
    for (; idx + 8 <= numel; idx += 8) {
        __m256 g = _mm256_loadu_ps(grad + idx);
        _mm256_storeu_ps(grad + idx, _mm256_mul_ps(g, factor8));
    }
#elif defined(__SSE2__)
    const __m128 factor4 = _mm_set1_ps(scale);
    for (; idx + 4 <= numel; idx += 4) {
        __m128 g = _mm_loadu_ps(grad + idx);
        _mm_storeu_ps(grad + idx, _mm_mul_ps(g, factor4));
    }
#endif

    /* Scalar path / vector-remainder tail */
    for (; idx < numel; ++idx) {
        grad[idx] *= scale;
    }
}

Referenced by gradient_clip_norm_f32().

◆ sgd_momentum_update_f32()

void sgd_momentum_update_f32 ( const float *  grad,
float *  weight,
float *  velocity,
size_t  numel,
float  lr,
float  momentum,
float  weight_decay 
)

SGD with momentum optimizer update (fp32 version)

    v_t = momentum * v_{t-1} + g_t
    w_t = w_{t-1} - lr * (v_t + weight_decay * w_{t-1})

Parameters
gradGradient tensor (fp32) [numel]
weightWeight tensor to update (fp32, in-place) [numel]
velocityVelocity buffer (fp32, in-place) [numel]
numelNumber of elements
lrLearning rate
momentumMomentum coefficient (typically 0.9)
weight_decayWeight decay coefficient

Definition at line 267 of file optimizer_kernels.c.

/**
 * SGD with momentum optimizer update (fp32 version).
 *
 *   v_t = momentum * v_{t-1} + g_t
 *   w_t = w_{t-1} - lr * (v_t + weight_decay * w_{t-1})
 *
 * @param grad         Gradient tensor (fp32) [numel]
 * @param weight       Weight tensor, updated in place [numel]
 * @param velocity     Velocity buffer, updated in place [numel]
 * @param numel        Number of elements
 * @param lr           Learning rate
 * @param momentum     Momentum coefficient (typically 0.9)
 * @param weight_decay Weight decay coefficient
 */
void sgd_momentum_update_f32(const float *grad, float *weight, float *velocity,
                             size_t numel, float lr, float momentum,
                             float weight_decay)
{
    if (grad == NULL || weight == NULL || velocity == NULL || numel == 0) {
        return;
    }

    size_t idx = 0;

#if defined(__AVX512F__)
    /* AVX-512: 16 floats per iteration, FMA forms */
    const __m512 lr16 = _mm512_set1_ps(lr);
    const __m512 mom16 = _mm512_set1_ps(momentum);
    const __m512 wd16 = _mm512_set1_ps(weight_decay);

    for (; idx + 16 <= numel; idx += 16) {
        __m512 g = _mm512_loadu_ps(grad + idx);
        __m512 w = _mm512_loadu_ps(weight + idx);
        __m512 vel = _mm512_loadu_ps(velocity + idx);

        vel = _mm512_fmadd_ps(mom16, vel, g);       /* v = momentum*v + g  */
        __m512 upd = _mm512_fmadd_ps(wd16, w, vel); /* v + wd*w            */
        w = _mm512_fnmadd_ps(lr16, upd, w);         /* w - lr*(v + wd*w)   */

        _mm512_storeu_ps(weight + idx, w);
        _mm512_storeu_ps(velocity + idx, vel);
    }
#elif defined(__AVX__)
    /* AVX: 8 floats per iteration (mul/add only, no FMA assumed) */
    const __m256 lr8 = _mm256_set1_ps(lr);
    const __m256 mom8 = _mm256_set1_ps(momentum);
    const __m256 wd8 = _mm256_set1_ps(weight_decay);

    for (; idx + 8 <= numel; idx += 8) {
        __m256 g = _mm256_loadu_ps(grad + idx);
        __m256 w = _mm256_loadu_ps(weight + idx);
        __m256 vel = _mm256_loadu_ps(velocity + idx);

        vel = _mm256_add_ps(_mm256_mul_ps(mom8, vel), g);
        __m256 upd = _mm256_add_ps(vel, _mm256_mul_ps(wd8, w));
        w = _mm256_sub_ps(w, _mm256_mul_ps(lr8, upd));

        _mm256_storeu_ps(weight + idx, w);
        _mm256_storeu_ps(velocity + idx, vel);
    }
#elif defined(__SSE2__)
    /* SSE2: 4 floats per iteration */
    const __m128 lr4 = _mm_set1_ps(lr);
    const __m128 mom4 = _mm_set1_ps(momentum);
    const __m128 wd4 = _mm_set1_ps(weight_decay);

    for (; idx + 4 <= numel; idx += 4) {
        __m128 g = _mm_loadu_ps(grad + idx);
        __m128 w = _mm_loadu_ps(weight + idx);
        __m128 vel = _mm_loadu_ps(velocity + idx);

        vel = _mm_add_ps(_mm_mul_ps(mom4, vel), g);
        __m128 upd = _mm_add_ps(vel, _mm_mul_ps(wd4, w));
        w = _mm_sub_ps(w, _mm_mul_ps(lr4, upd));

        _mm_storeu_ps(weight + idx, w);
        _mm_storeu_ps(velocity + idx, vel);
    }
#endif

    /* Scalar path / vector-remainder tail */
    for (; idx < numel; ++idx) {
        velocity[idx] = momentum * velocity[idx] + grad[idx];
        weight[idx] = weight[idx] - lr * (velocity[idx] + weight_decay * weight[idx]);
    }
}

◆ zero_gradients_f32()

void zero_gradients_f32 ( float *  grad,
size_t  numel 
)

Zero out gradient buffer (fp32)

Parameters
gradGradient tensor to zero [numel]
numelNumber of elements

Definition at line 374 of file optimizer_kernels.c.

/**
 * Zero out a gradient buffer (fp32).
 *
 * @param grad  Gradient tensor to clear [numel]
 * @param numel Number of elements
 */
void zero_gradients_f32(float *grad, size_t numel)
{
    /* memset with byte 0 yields IEEE-754 +0.0f in every element */
    if (grad != NULL && numel > 0) {
        memset(grad, 0, numel * sizeof *grad);
    }
}