/* AdamW optimizer update (fp32 version). */
void adamw_update_f32(const float *grad,
                      float *weight,
                      float *m,
                      float *v,
                      size_t numel,
                      float lr,
                      float beta1,
                      float beta2,
                      float eps,
                      float weight_decay,
                      int step);
/* SGD with momentum optimizer update (fp32 version). */
void sgd_momentum_update_f32(const float *grad,
                             float *weight,
                             float *velocity,
                             size_t numel,
                             float lr,
                             float momentum,
                             float weight_decay);
/*
 * AdamW optimizer update (bf16 weights/gradients, fp32 optimizer state).
 *
 * Performs one decoupled-weight-decay Adam step on `numel` parameters:
 *   m = beta1*m + (1-beta1)*g
 *   v = beta2*v + (1-beta2)*g^2
 *   w = w - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay*w)
 * with bias corrections m_hat = m / (1 - beta1^step) and
 * v_hat = v / (1 - beta2^step).
 *
 * grad/weight are bf16 buffers; m/v are fp32 master state updated in place.
 * `step` is the 1-based optimizer step; step <= 0 would make a bias
 * correction zero or negative (division hazard) — callers pass step >= 1.
 * No-op on NULL pointers or numel == 0.
 *
 * NOTE(review): reconstructed from a garbled extract. The AVX-512 loop and
 * scalar math are verbatim from the fragment; the `size_t i = 0;`
 * declarations and the scalar-path bf16 load/store lines were missing and
 * have been restored — confirm against version control.
 */
void adamw_update_bf16(const uint16_t *grad, uint16_t *weight, float *m,
                       float *v, size_t numel, float lr, float beta1,
                       float beta2, float eps, float weight_decay, int step)
{
    if (!grad || !weight || !m || !v || numel == 0) {
        return;
    }

    float bias_correction1 = 1.0f - powf(beta1, (float)step);
    float bias_correction2 = 1.0f - powf(beta2, (float)step);
    float one_minus_beta1 = 1.0f - beta1;
    float one_minus_beta2 = 1.0f - beta2;

#if defined(__AVX512F__)
    __m512 v_beta1 = _mm512_set1_ps(beta1);
    __m512 v_beta2 = _mm512_set1_ps(beta2);
    __m512 v_one_minus_beta1 = _mm512_set1_ps(one_minus_beta1);
    __m512 v_one_minus_beta2 = _mm512_set1_ps(one_minus_beta2);
    __m512 v_lr = _mm512_set1_ps(lr);
    __m512 v_eps = _mm512_set1_ps(eps);
    __m512 v_weight_decay = _mm512_set1_ps(weight_decay);
    /* Hoist the bias-correction divisions: multiply by reciprocal in-loop. */
    __m512 v_bc1_inv = _mm512_set1_ps(1.0f / bias_correction1);
    __m512 v_bc2_inv = _mm512_set1_ps(1.0f / bias_correction2);

    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        /* Widen 16 bf16 grads/weights to fp32 lanes. */
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        __m512 w = bf16_loadu_cvt_fp32(&weight[i]);

        __m512 m_val = _mm512_loadu_ps(&m[i]);
        __m512 v_val = _mm512_loadu_ps(&v[i]);

        /* First- and second-moment exponential moving averages. */
        m_val = _mm512_fmadd_ps(v_beta1, m_val, _mm512_mul_ps(v_one_minus_beta1, g));
        __m512 g_sq = _mm512_mul_ps(g, g);
        v_val = _mm512_fmadd_ps(v_beta2, v_val, _mm512_mul_ps(v_one_minus_beta2, g_sq));

        /* Bias-corrected moments. */
        __m512 m_hat = _mm512_mul_ps(m_val, v_bc1_inv);
        __m512 v_hat = _mm512_mul_ps(v_val, v_bc2_inv);

        /* w -= lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w) */
        __m512 denom = _mm512_add_ps(_mm512_sqrt_ps(v_hat), v_eps);
        __m512 update = _mm512_div_ps(m_hat, denom);
        update = _mm512_fmadd_ps(v_weight_decay, w, update);
        w = _mm512_fnmadd_ps(v_lr, update, w);

        /* Narrow weights back to bf16; optimizer state stays fp32. */
        fp32_cvt_storeu_bf16(&weight[i], w);
        _mm512_storeu_ps(&m[i], m_val);
        _mm512_storeu_ps(&v[i], v_val);
    }

    /* Scalar tail for the last numel % 16 elements. */
    for (; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);

        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;

        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;

        w = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#else
    /* Portable scalar fallback (no AVX-512). */
    for (size_t i = 0; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);

        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;

        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;

        w = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#endif
}
/*
 * SGD with momentum (bf16 weights/gradients, fp32 velocity).
 *
 * For each element:
 *   velocity = momentum * velocity + g
 *   w        = w - lr * (velocity + weight_decay * w)
 * (decoupled weight decay, matching adamw_update_bf16).
 * No-op on NULL pointers or numel == 0.
 *
 * NOTE(review): reconstructed from a garbled extract; the `size_t i = 0;`
 * declaration and scalar-path bf16 load/store lines were missing and have
 * been restored — confirm against version control.
 */
void sgd_momentum_update_bf16(const uint16_t *grad, uint16_t *weight,
                              float *velocity, size_t numel, float lr,
                              float momentum, float weight_decay)
{
    if (!grad || !weight || !velocity || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    __m512 v_lr = _mm512_set1_ps(lr);
    __m512 v_momentum = _mm512_set1_ps(momentum);
    __m512 v_weight_decay = _mm512_set1_ps(weight_decay);

    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        __m512 w = bf16_loadu_cvt_fp32(&weight[i]);
        __m512 vel = _mm512_loadu_ps(&velocity[i]);

        /* vel = momentum * vel + g */
        vel = _mm512_fmadd_ps(v_momentum, vel, g);
        /* w -= lr * (vel + weight_decay * w) */
        __m512 update = _mm512_fmadd_ps(v_weight_decay, w, vel);
        w = _mm512_fnmadd_ps(v_lr, update, w);

        fp32_cvt_storeu_bf16(&weight[i], w);
        _mm512_storeu_ps(&velocity[i], vel);
    }

    /* Scalar tail. */
    for (; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);
        velocity[i] = momentum * velocity[i] + g;
        w = w - lr * (velocity[i] + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#else
    /* Portable scalar fallback. */
    for (size_t i = 0; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        float w = bf16_to_float(weight[i]);
        velocity[i] = momentum * velocity[i] + g;
        w = w - lr * (velocity[i] + weight_decay * w);
        weight[i] = float_to_bf16(w);
    }
#endif
}
/*
 * Zero out a bf16 gradient buffer.
 *
 * bf16 all-zero bits encode +0.0f, so a raw memset is a valid float clear.
 * No-op on NULL or numel == 0.
 */
void zero_gradients_bf16(uint16_t *grad, size_t numel)
{
    if (!grad || numel == 0) {
        return;
    }
    memset(grad, 0, numel * sizeof(uint16_t));
}
/*
 * Accumulate gradients: dst += src (bf16).
 *
 * Each element is widened to fp32, summed, and rounded back to bf16, so the
 * accumulation loses precision like any bf16 add. No-op on NULL or numel == 0.
 *
 * NOTE(review): reconstructed from a garbled extract; the scalar loop bodies
 * were missing and have been restored — confirm against version control.
 */
void gradient_accumulate_bf16(uint16_t *dst, const uint16_t *src, size_t numel)
{
    if (!dst || !src || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 d = bf16_loadu_cvt_fp32(&dst[i]);
        __m512 s = bf16_loadu_cvt_fp32(&src[i]);
        fp32_cvt_storeu_bf16(&dst[i], _mm512_add_ps(d, s));
    }
    /* Scalar tail. */
    for (; i < numel; ++i) {
        dst[i] = float_to_bf16(bf16_to_float(dst[i]) + bf16_to_float(src[i]));
    }
#else
    /* Portable scalar fallback. */
    for (size_t i = 0; i < numel; ++i) {
        dst[i] = float_to_bf16(bf16_to_float(dst[i]) + bf16_to_float(src[i]));
    }
#endif
}
/*
 * Scale gradients in place: grad *= scale (bf16).
 *
 * Each element is widened to fp32, multiplied, and rounded back to bf16.
 * No-op on NULL or numel == 0.
 *
 * NOTE(review): reconstructed from a garbled extract; the scalar loop bodies
 * were missing and have been restored — confirm against version control.
 */
void gradient_scale_bf16(uint16_t *grad, size_t numel, float scale)
{
    if (!grad || numel == 0) {
        return;
    }

#if defined(__AVX512F__)
    __m512 v_scale = _mm512_set1_ps(scale);

    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        fp32_cvt_storeu_bf16(&grad[i], _mm512_mul_ps(g, v_scale));
    }
    /* Scalar tail. */
    for (; i < numel; ++i) {
        grad[i] = float_to_bf16(bf16_to_float(grad[i]) * scale);
    }
#else
    /* Portable scalar fallback. */
    for (size_t i = 0; i < numel; ++i) {
        grad[i] = float_to_bf16(bf16_to_float(grad[i]) * scale);
    }
#endif
}
/*
 * Clip gradient norm (bf16): if ||grad||_2 > max_norm, rescale grad by
 * max_norm / ||grad||_2.
 *
 * Returns the pre-clip L2 norm; returns 0.0f (and does nothing) on NULL,
 * numel == 0, or max_norm <= 0.
 *
 * The sum of squares accumulates in double on the scalar paths; the AVX-512
 * path accumulates in fp32 lanes and reduces once, so very large buffers may
 * see slightly different rounding between the two paths.
 *
 * NOTE(review): reconstructed from a garbled extract; the clip-application
 * tail and return were not visible — reconstructed as an in-place rescale
 * via gradient_scale_bf16 returning the unclipped norm. Confirm against
 * version control.
 */
float gradient_clip_norm_bf16(uint16_t *grad, size_t numel, float max_norm)
{
    if (!grad || numel == 0 || max_norm <= 0.0f) {
        return 0.0f;
    }

    double sum_sq = 0.0;

#if defined(__AVX512F__)
    __m512 acc = _mm512_setzero_ps();

    size_t i = 0;
    for (; i + 16 <= numel; i += 16) {
        __m512 g = bf16_loadu_cvt_fp32(&grad[i]);
        acc = _mm512_fmadd_ps(g, g, acc);
    }
    sum_sq = _mm512_reduce_add_ps(acc);
    /* Scalar tail. */
    for (; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        sum_sq += (double)g * (double)g;
    }
#else
    /* Portable scalar fallback. */
    for (size_t i = 0; i < numel; ++i) {
        float g = bf16_to_float(grad[i]);
        sum_sq += (double)g * (double)g;
    }
#endif

    float norm = sqrtf((float)sum_sq);

    if (norm > max_norm) {
        float scale = max_norm / norm;
        gradient_scale_bf16(grad, numel, scale);
    }
    return norm;
}
static uint16_t float_to_bf16(float f)
static float bf16_to_float(uint16_t v)
void adamw_update_bf16(const uint16_t *grad, uint16_t *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
AdamW optimizer update (bf16 weights/gradients, fp32 optimizer state)
void zero_gradients_bf16(uint16_t *grad, size_t numel)
Zero out gradient buffer (bf16)
void sgd_momentum_update_f32(const float *grad, float *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay)
SGD with momentum optimizer update (fp32 version)
float gradient_clip_norm_bf16(uint16_t *grad, size_t numel, float max_norm)
Clip gradient norm (bf16)
void gradient_accumulate_bf16(uint16_t *dst, const uint16_t *src, size_t numel)
Accumulate gradients: dst += src (bf16)
void gradient_scale_bf16(uint16_t *grad, size_t numel, float scale)
Scale gradients: grad *= scale (bf16)
void sgd_momentum_update_bf16(const uint16_t *grad, uint16_t *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay)
SGD with momentum (bf16 weights/gradients)
void gradient_scale_f32(float *grad, size_t numel, float scale)
Scale gradients by a constant: grad *= scale (fp32)
void adamw_update_f32(const float *grad, float *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
AdamW optimizer update (fp32 version)
void gradient_accumulate_f32(float *dst, const float *src, size_t numel)
Accumulate gradients: dst += src (fp32)