BF16 optimizer kernels for training. More...
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include "bf16_utils.h"
Go to the source code of this file.
Functions | |
| void | adamw_update_bf16 (const uint16_t *grad, uint16_t *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step) |
| AdamW optimizer update (bf16 weights/gradients, fp32 optimizer state) More... | |
| void | adamw_update_f32 (const float *grad, float *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step) |
| AdamW optimizer update (fp32 version) More... | |
| void | gradient_accumulate_bf16 (uint16_t *dst, const uint16_t *src, size_t numel) |
| Accumulate gradients: dst += src (bf16) More... | |
| void | gradient_accumulate_f32 (float *dst, const float *src, size_t numel) |
| Accumulate gradients: dst += src (fp32) More... | |
| float | gradient_clip_norm_bf16 (uint16_t *grad, size_t numel, float max_norm) |
| Clip gradient norm (bf16) More... | |
| void | gradient_scale_bf16 (uint16_t *grad, size_t numel, float scale) |
| Scale gradients: grad *= scale (bf16) More... | |
| void | gradient_scale_f32 (float *grad, size_t numel, float scale) |
| Scale gradients by a constant: grad *= scale (fp32) More... | |
| void | sgd_momentum_update_bf16 (const uint16_t *grad, uint16_t *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay) |
| SGD with momentum (bf16 weights/gradients) More... | |
| void | sgd_momentum_update_f32 (const float *grad, float *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay) |
| SGD with momentum optimizer update (fp32 version) More... | |
| void | zero_gradients_bf16 (uint16_t *grad, size_t numel) |
| Zero out gradient buffer (bf16) More... | |
BF16 optimizer kernels for training.
After changes: make test && make llamacpp-parity-full
Note: Optimizer state (m, v) is always kept in fp32 for numerical stability. Only weights and gradients are in bf16.
Definition in file optimizer_kernels_bf16.c.
| void adamw_update_bf16 | ( | const uint16_t * | grad, |
| uint16_t * | weight, | ||
| float * | m, | ||
| float * | v, | ||
| size_t | numel, | ||
| float | lr, | ||
| float | beta1, | ||
| float | beta2, | ||
| float | eps, | ||
| float | weight_decay, | ||
| int | step | ||
| ) |
AdamW optimizer update (bf16 weights/gradients, fp32 optimizer state)
Weights and gradients are in bf16 for memory efficiency. Momentum (m) and variance (v) are in fp32 for numerical stability.
| grad | Gradient tensor (bf16) [numel] |
| weight | Weight tensor to update (bf16, in-place) [numel] |
| m | First moment buffer (fp32, in-place) [numel] |
| v | Second moment buffer (fp32, in-place) [numel] |
| numel | Number of elements |
| lr | Learning rate |
| beta1 | First moment decay (typically 0.9) |
| beta2 | Second moment decay (typically 0.999) |
| eps | Numerical stability constant (typically 1e-8) |
| weight_decay | Weight decay coefficient |
| step | Current step number (1-indexed) |
Definition at line 57 of file optimizer_kernels_bf16.c.
References bf16_to_float(), and float_to_bf16().
| void adamw_update_f32 | ( | const float * | grad, |
| float * | weight, | ||
| float * | m, | ||
| float * | v, | ||
| size_t | numel, | ||
| float | lr, | ||
| float | beta1, | ||
| float | beta2, | ||
| float | eps, | ||
| float | weight_decay, | ||
| int | step | ||
| ) |
AdamW optimizer update (fp32 version)
Updates weights in-place using the AdamW algorithm. Momentum (m) and variance (v) are stored in fp32 for numerical stability.
| grad | Gradient tensor (fp32) [numel] |
| weight | Weight tensor to update (fp32, in-place) [numel] |
| m | First moment (momentum) buffer (fp32, in-place) [numel] |
| v | Second moment (variance) buffer (fp32, in-place) [numel] |
| numel | Number of elements |
| lr | Learning rate |
| beta1 | Exponential decay rate for first moment (typically 0.9) |
| beta2 | Exponential decay rate for second moment (typically 0.999) |
| eps | Small constant for numerical stability (typically 1e-8) |
| weight_decay | Weight decay coefficient (typically 0.01) |
| step | Current step number (1-indexed for bias correction) |
Definition at line 53 of file optimizer_kernels.c.
| void gradient_accumulate_bf16 | ( | uint16_t * | dst, |
| const uint16_t * | src, | ||
| size_t | numel | ||
| ) |
Accumulate gradients: dst += src (bf16)
Definition at line 229 of file optimizer_kernels_bf16.c.
References bf16_to_float(), and float_to_bf16().
| void gradient_accumulate_f32 | ( | float * | dst, |
| const float * | src, | ||
| size_t | numel | ||
| ) |
Accumulate gradients: dst += src (fp32)
Used for gradient accumulation across micro-batches.
| dst | Destination gradient buffer (in-place) [numel] |
| src | Source gradient buffer [numel] |
| numel | Number of elements |
Definition at line 392 of file optimizer_kernels.c.
| float gradient_clip_norm_bf16 | ( | uint16_t * | grad, |
| size_t | numel, | ||
| float | max_norm | ||
| ) |
Clip gradient norm (bf16)
Definition at line 291 of file optimizer_kernels_bf16.c.
References bf16_to_float(), and gradient_scale_bf16().
| void gradient_scale_bf16 | ( | uint16_t * | grad, |
| size_t | numel, | ||
| float | scale | ||
| ) |
Scale gradients: grad *= scale (bf16)
Definition at line 260 of file optimizer_kernels_bf16.c.
References bf16_to_float(), and float_to_bf16().
Referenced by gradient_clip_norm_bf16().
| void gradient_scale_f32 | ( | float * | grad, |
| size_t | numel, | ||
| float | scale | ||
| ) |
Scale gradients by a constant: grad *= scale (fp32)
Used for averaging gradients after accumulation: grad /= batch_size
| grad | Gradient tensor to scale (in-place) [numel] |
| numel | Number of elements |
| scale | Scale factor (typically 1.0 / batch_size) |
Definition at line 448 of file optimizer_kernels.c.
Referenced by gradient_clip_norm_f32().
| void sgd_momentum_update_bf16 | ( | const uint16_t * | grad, |
| uint16_t * | weight, | ||
| float * | velocity, | ||
| size_t | numel, | ||
| float | lr, | ||
| float | momentum, | ||
| float | weight_decay | ||
| ) |
SGD with momentum (bf16 weights/gradients)
Definition at line 163 of file optimizer_kernels_bf16.c.
References bf16_to_float(), and float_to_bf16().
| void sgd_momentum_update_f32 | ( | const float * | grad, |
| float * | weight, | ||
| float * | velocity, | ||
| size_t | numel, | ||
| float | lr, | ||
| float | momentum, | ||
| float | weight_decay | ||
| ) |
SGD with momentum optimizer update (fp32 version)
Update rules:
  v_t = momentum * v_{t-1} + g_t
  w_t = w_{t-1} - lr * (v_t + weight_decay * w_{t-1})
| grad | Gradient tensor (fp32) [numel] |
| weight | Weight tensor to update (fp32, in-place) [numel] |
| velocity | Velocity buffer (fp32, in-place) [numel] |
| numel | Number of elements |
| lr | Learning rate |
| momentum | Momentum coefficient (typically 0.9) |
| weight_decay | Weight decay coefficient |
Definition at line 267 of file optimizer_kernels.c.
| void zero_gradients_bf16 | ( | uint16_t * | grad, |
| size_t | numel | ||
| ) |
Zero out gradient buffer (bf16)
Definition at line 217 of file optimizer_kernels_bf16.c.