Optimizer kernels for training (AdamW, SGD)

#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
Functions

void adamw_update_f32(const float *grad, float *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
    AdamW optimizer update (fp32 version)

void gradient_accumulate_f32(float *dst, const float *src, size_t numel)
    Accumulate gradients: dst += src (fp32)

float gradient_clip_norm_f32(float *grad, size_t numel, float max_norm)
    Clip gradient norm (fp32)

void gradient_scale_f32(float *grad, size_t numel, float scale)
    Scale gradients by a constant: grad *= scale (fp32)

void sgd_momentum_update_f32(const float *grad, float *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay)
    SGD with momentum optimizer update (fp32 version)

void zero_gradients_f32(float *grad, size_t numel)
    Zero out gradient buffer (fp32)
Optimizer kernels for training (AdamW, SGD)
After changes: make test && make llamacpp-parity-full
AdamW Algorithm:

    m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    m_hat = m_t / (1 - beta1^t)
    v_hat = v_t / (1 - beta2^t)
    w_t = w_{t-1} - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w_{t-1})
Note: AdamW applies weight decay directly to the weights, not to the gradients. This differs from L2 regularization (Adam with L2 adds the decay term to the gradient).
Definition in file optimizer_kernels.c.
void adamw_update_f32(const float *grad,
                      float *weight,
                      float *m,
                      float *v,
                      size_t numel,
                      float lr,
                      float beta1,
                      float beta2,
                      float eps,
                      float weight_decay,
                      int step)
AdamW optimizer update (fp32 version)
Updates weights in-place using the AdamW algorithm. Momentum (m) and variance (v) are stored in fp32 for numerical stability.
Parameters
    grad          Gradient tensor (fp32) [numel]
    weight        Weight tensor to update (fp32, in-place) [numel]
    m             First moment (momentum) buffer (fp32, in-place) [numel]
    v             Second moment (variance) buffer (fp32, in-place) [numel]
    numel         Number of elements
    lr            Learning rate
    beta1         Exponential decay rate for first moment (typically 0.9)
    beta2         Exponential decay rate for second moment (typically 0.999)
    eps           Small constant for numerical stability (typically 1e-8)
    weight_decay  Weight decay coefficient (typically 0.01)
    step          Current step number (1-indexed for bias correction)
Definition at line 53 of file optimizer_kernels.c.
void gradient_accumulate_f32(float *dst, const float *src, size_t numel)
Accumulate gradients: dst += src (fp32)
Used for gradient accumulation across micro-batches.
Parameters
    dst    Destination gradient buffer (in-place) [numel]
    src    Source gradient buffer [numel]
    numel  Number of elements
Definition at line 392 of file optimizer_kernels.c.
float gradient_clip_norm_f32(float *grad, size_t numel, float max_norm)
Clip gradient norm (fp32)
If ||grad||_2 > max_norm, grad is scaled so that ||grad||_2 = max_norm.
Parameters
    grad      Gradient tensor to clip (in-place) [numel]
    numel     Number of elements
    max_norm  Maximum allowed L2 norm
Definition at line 505 of file optimizer_kernels.c.
References gradient_scale_f32().
void gradient_scale_f32(float *grad, size_t numel, float scale)
Scale gradients by a constant: grad *= scale (fp32)
Used for averaging gradients after accumulation: grad /= batch_size
Parameters
    grad   Gradient tensor to scale (in-place) [numel]
    numel  Number of elements
    scale  Scale factor (typically 1.0 / batch_size)
Definition at line 448 of file optimizer_kernels.c.
Referenced by gradient_clip_norm_f32().
void sgd_momentum_update_f32(const float *grad,
                             float *weight,
                             float *velocity,
                             size_t numel,
                             float lr,
                             float momentum,
                             float weight_decay)
SGD with momentum optimizer update (fp32 version)
    v_t = momentum * v_{t-1} + g_t
    w_t = w_{t-1} - lr * (v_t + weight_decay * w_{t-1})
Parameters
    grad          Gradient tensor (fp32) [numel]
    weight        Weight tensor to update (fp32, in-place) [numel]
    velocity      Velocity buffer (fp32, in-place) [numel]
    numel         Number of elements
    lr            Learning rate
    momentum      Momentum coefficient (typically 0.9)
    weight_decay  Weight decay coefficient
Definition at line 267 of file optimizer_kernels.c.
void zero_gradients_f32(float *grad, size_t numel)
Zero out gradient buffer (fp32)
Parameters
    grad   Gradient tensor to zero [numel]
    numel  Number of elements
Definition at line 374 of file optimizer_kernels.c.