#if defined(__AVX512F__) || defined(__AVX__) || defined(__SSE2__)
#include <immintrin.h>
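
/*
 * adamw_update_f32: AdamW update with decoupled weight decay.
 *   m     = beta1 * m + (1 - beta1) * g
 *   v     = beta2 * v + (1 - beta2) * g^2
 *   m_hat = m / (1 - beta1^step),  v_hat = v / (1 - beta2^step)
 *   w    -= lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * w)
 */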
    if (!grad || !weight || !m || !v || numel == 0) {

    float bias_correction1 = 1.0f - powf(beta1, (float)step);
    float bias_correction2 = 1.0f - powf(beta2, (float)step);

    float one_minus_beta1 = 1.0f - beta1;
    float one_minus_beta2 = 1.0f - beta2;
#if defined(__AVX512F__)

    __m512 v_beta1 = _mm512_set1_ps(beta1);
    __m512 v_beta2 = _mm512_set1_ps(beta2);
    __m512 v_one_minus_beta1 = _mm512_set1_ps(one_minus_beta1);
    __m512 v_one_minus_beta2 = _mm512_set1_ps(one_minus_beta2);
    __m512 v_lr = _mm512_set1_ps(lr);
    __m512 v_eps = _mm512_set1_ps(eps);
    __m512 v_weight_decay = _mm512_set1_ps(weight_decay);
    __m512 v_bc1_inv = _mm512_set1_ps(1.0f / bias_correction1);
    __m512 v_bc2_inv = _mm512_set1_ps(1.0f / bias_correction2);

    for (; i + 16 <= numel; i += 16) {
        __m512 g = _mm512_loadu_ps(&grad[i]);
        __m512 w = _mm512_loadu_ps(&weight[i]);
        __m512 m_val = _mm512_loadu_ps(&m[i]);
        __m512 v_val = _mm512_loadu_ps(&v[i]);

        m_val = _mm512_fmadd_ps(v_beta1, m_val, _mm512_mul_ps(v_one_minus_beta1, g));

        __m512 g_sq = _mm512_mul_ps(g, g);
        v_val = _mm512_fmadd_ps(v_beta2, v_val, _mm512_mul_ps(v_one_minus_beta2, g_sq));

        __m512 m_hat = _mm512_mul_ps(m_val, v_bc1_inv);
        __m512 v_hat = _mm512_mul_ps(v_val, v_bc2_inv);

        __m512 denom = _mm512_add_ps(_mm512_sqrt_ps(v_hat), v_eps);
        __m512 update = _mm512_div_ps(m_hat, denom);
        update = _mm512_fmadd_ps(v_weight_decay, w, update);
        w = _mm512_fnmadd_ps(v_lr, update, w);

        _mm512_storeu_ps(&weight[i], w);
        _mm512_storeu_ps(&m[i], m_val);
        _mm512_storeu_ps(&v[i], v_val);
    for (; i < numel; ++i) {
        float g = grad[i];
        float w = weight[i];
        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;
        weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
#elif defined(__AVX__)

    __m256 v_beta1 = _mm256_set1_ps(beta1);
    __m256 v_beta2 = _mm256_set1_ps(beta2);
    __m256 v_one_minus_beta1 = _mm256_set1_ps(one_minus_beta1);
    __m256 v_one_minus_beta2 = _mm256_set1_ps(one_minus_beta2);
    __m256 v_lr = _mm256_set1_ps(lr);
    __m256 v_eps = _mm256_set1_ps(eps);
    __m256 v_weight_decay = _mm256_set1_ps(weight_decay);
    __m256 v_bc1_inv = _mm256_set1_ps(1.0f / bias_correction1);
    __m256 v_bc2_inv = _mm256_set1_ps(1.0f / bias_correction2);

    for (; i + 8 <= numel; i += 8) {
        __m256 g = _mm256_loadu_ps(&grad[i]);
        __m256 w = _mm256_loadu_ps(&weight[i]);
        __m256 m_val = _mm256_loadu_ps(&m[i]);
        __m256 v_val = _mm256_loadu_ps(&v[i]);

        m_val = _mm256_add_ps(_mm256_mul_ps(v_beta1, m_val),
                              _mm256_mul_ps(v_one_minus_beta1, g));

        __m256 g_sq = _mm256_mul_ps(g, g);
        v_val = _mm256_add_ps(_mm256_mul_ps(v_beta2, v_val),
                              _mm256_mul_ps(v_one_minus_beta2, g_sq));

        __m256 m_hat = _mm256_mul_ps(m_val, v_bc1_inv);
        __m256 v_hat = _mm256_mul_ps(v_val, v_bc2_inv);

        __m256 denom = _mm256_add_ps(_mm256_sqrt_ps(v_hat), v_eps);
        __m256 update = _mm256_div_ps(m_hat, denom);
        update = _mm256_add_ps(update, _mm256_mul_ps(v_weight_decay, w));
        w = _mm256_sub_ps(w, _mm256_mul_ps(v_lr, update));

        _mm256_storeu_ps(&weight[i], w);
        _mm256_storeu_ps(&m[i], m_val);
        _mm256_storeu_ps(&v[i], v_val);
    for (; i < numel; ++i) {
        float g = grad[i];
        float w = weight[i];
        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;
        weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
#elif defined(__SSE2__)

    __m128 v_beta1 = _mm_set1_ps(beta1);
    __m128 v_beta2 = _mm_set1_ps(beta2);
    __m128 v_one_minus_beta1 = _mm_set1_ps(one_minus_beta1);
    __m128 v_one_minus_beta2 = _mm_set1_ps(one_minus_beta2);
    __m128 v_lr = _mm_set1_ps(lr);
    __m128 v_eps = _mm_set1_ps(eps);
    __m128 v_weight_decay = _mm_set1_ps(weight_decay);
    __m128 v_bc1_inv = _mm_set1_ps(1.0f / bias_correction1);
    __m128 v_bc2_inv = _mm_set1_ps(1.0f / bias_correction2);

    for (; i + 4 <= numel; i += 4) {
        __m128 g = _mm_loadu_ps(&grad[i]);
        __m128 w = _mm_loadu_ps(&weight[i]);
        __m128 m_val = _mm_loadu_ps(&m[i]);
        __m128 v_val = _mm_loadu_ps(&v[i]);

        m_val = _mm_add_ps(_mm_mul_ps(v_beta1, m_val),
                           _mm_mul_ps(v_one_minus_beta1, g));

        __m128 g_sq = _mm_mul_ps(g, g);
        v_val = _mm_add_ps(_mm_mul_ps(v_beta2, v_val),
                           _mm_mul_ps(v_one_minus_beta2, g_sq));

        __m128 m_hat = _mm_mul_ps(m_val, v_bc1_inv);
        __m128 v_hat = _mm_mul_ps(v_val, v_bc2_inv);

        __m128 denom = _mm_add_ps(_mm_sqrt_ps(v_hat), v_eps);
        __m128 update = _mm_div_ps(m_hat, denom);
        update = _mm_add_ps(update, _mm_mul_ps(v_weight_decay, w));
        w = _mm_sub_ps(w, _mm_mul_ps(v_lr, update));

        _mm_storeu_ps(&weight[i], w);
        _mm_storeu_ps(&m[i], m_val);
        _mm_storeu_ps(&v[i], v_val);
    for (; i < numel; ++i) {
        float g = grad[i];
        float w = weight[i];
        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;
        weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
    for (size_t i = 0; i < numel; ++i) {
        float g = grad[i];
        float w = weight[i];
        m[i] = beta1 * m[i] + one_minus_beta1 * g;
        v[i] = beta2 * v[i] + one_minus_beta2 * g * g;
        float m_hat = m[i] / bias_correction1;
        float v_hat = v[i] / bias_correction2;
        weight[i] = w - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * w);
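
/*
 * sgd_momentum_update_f32: SGD with momentum and decoupled weight decay.
 *   velocity = momentum * velocity + g
 *   w       -= lr * (velocity + weight_decay * w)
 */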
    if (!grad || !weight || !velocity || numel == 0) {

#if defined(__AVX512F__)

    __m512 v_lr = _mm512_set1_ps(lr);
    __m512 v_momentum = _mm512_set1_ps(momentum);
    __m512 v_weight_decay = _mm512_set1_ps(weight_decay);

    for (; i + 16 <= numel; i += 16) {
        __m512 g = _mm512_loadu_ps(&grad[i]);
        __m512 w = _mm512_loadu_ps(&weight[i]);
        __m512 vel = _mm512_loadu_ps(&velocity[i]);

        vel = _mm512_fmadd_ps(v_momentum, vel, g);
        __m512 update = _mm512_fmadd_ps(v_weight_decay, w, vel);
        w = _mm512_fnmadd_ps(v_lr, update, w);

        _mm512_storeu_ps(&weight[i], w);
        _mm512_storeu_ps(&velocity[i], vel);

    for (; i < numel; ++i) {
        velocity[i] = momentum * velocity[i] + grad[i];
        weight[i] = weight[i] - lr * (velocity[i] + weight_decay * weight[i]);
#elif defined(__AVX__)

    __m256 v_lr = _mm256_set1_ps(lr);
    __m256 v_momentum = _mm256_set1_ps(momentum);
    __m256 v_weight_decay = _mm256_set1_ps(weight_decay);

    for (; i + 8 <= numel; i += 8) {
        __m256 g = _mm256_loadu_ps(&grad[i]);
        __m256 w = _mm256_loadu_ps(&weight[i]);
        __m256 vel = _mm256_loadu_ps(&velocity[i]);

        vel = _mm256_add_ps(_mm256_mul_ps(v_momentum, vel), g);

        __m256 update = _mm256_add_ps(vel, _mm256_mul_ps(v_weight_decay, w));
        w = _mm256_sub_ps(w, _mm256_mul_ps(v_lr, update));

        _mm256_storeu_ps(&weight[i], w);
        _mm256_storeu_ps(&velocity[i], vel);

    for (; i < numel; ++i) {
        velocity[i] = momentum * velocity[i] + grad[i];
        weight[i] = weight[i] - lr * (velocity[i] + weight_decay * weight[i]);
#elif defined(__SSE2__)

    __m128 v_lr = _mm_set1_ps(lr);
    __m128 v_momentum = _mm_set1_ps(momentum);
    __m128 v_weight_decay = _mm_set1_ps(weight_decay);

    for (; i + 4 <= numel; i += 4) {
        __m128 g = _mm_loadu_ps(&grad[i]);
        __m128 w = _mm_loadu_ps(&weight[i]);
        __m128 vel = _mm_loadu_ps(&velocity[i]);

        vel = _mm_add_ps(_mm_mul_ps(v_momentum, vel), g);
        __m128 update = _mm_add_ps(vel, _mm_mul_ps(v_weight_decay, w));
        w = _mm_sub_ps(w, _mm_mul_ps(v_lr, update));

        _mm_storeu_ps(&weight[i], w);
        _mm_storeu_ps(&velocity[i], vel);

    for (; i < numel; ++i) {
        velocity[i] = momentum * velocity[i] + grad[i];
        weight[i] = weight[i] - lr * (velocity[i] + weight_decay * weight[i]);
    for (size_t i = 0; i < numel; ++i) {
        velocity[i] = momentum * velocity[i] + grad[i];
        weight[i] = weight[i] - lr * (velocity[i] + weight_decay * weight[i]);
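
/* zero_gradients_f32: clear the gradient buffer. */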
    if (!grad || numel == 0) {

    memset(grad, 0, numel * sizeof(float));
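
/* gradient_accumulate_f32: element-wise dst += src. */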
    if (!dst || !src || numel == 0) {

#if defined(__AVX512F__)

    for (; i + 16 <= numel; i += 16) {
        __m512 d = _mm512_loadu_ps(&dst[i]);
        __m512 s = _mm512_loadu_ps(&src[i]);
        _mm512_storeu_ps(&dst[i], _mm512_add_ps(d, s));
    for (; i < numel; ++i) {
        dst[i] += src[i];
#elif defined(__AVX__)

    for (; i + 8 <= numel; i += 8) {
        __m256 d = _mm256_loadu_ps(&dst[i]);
        __m256 s = _mm256_loadu_ps(&src[i]);
        _mm256_storeu_ps(&dst[i], _mm256_add_ps(d, s));
    for (; i < numel; ++i) {
        dst[i] += src[i];
#elif defined(__SSE2__)

    for (; i + 4 <= numel; i += 4) {
        __m128 d = _mm_loadu_ps(&dst[i]);
        __m128 s = _mm_loadu_ps(&src[i]);
        _mm_storeu_ps(&dst[i], _mm_add_ps(d, s));
    for (; i < numel; ++i) {
        dst[i] += src[i];
    for (size_t i = 0; i < numel; ++i) {
        dst[i] += src[i];
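
/* gradient_scale_f32: element-wise grad *= scale. */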
    if (!grad || numel == 0) {

#if defined(__AVX512F__)
    __m512 v_scale = _mm512_set1_ps(scale);

    for (; i + 16 <= numel; i += 16) {
        __m512 g = _mm512_loadu_ps(&grad[i]);
        _mm512_storeu_ps(&grad[i], _mm512_mul_ps(g, v_scale));
    for (; i < numel; ++i) {
        grad[i] *= scale;
#elif defined(__AVX__)
    __m256 v_scale = _mm256_set1_ps(scale);

    for (; i + 8 <= numel; i += 8) {
        __m256 g = _mm256_loadu_ps(&grad[i]);
        _mm256_storeu_ps(&grad[i], _mm256_mul_ps(g, v_scale));
    for (; i < numel; ++i) {
        grad[i] *= scale;
#elif defined(__SSE2__)
    __m128 v_scale = _mm_set1_ps(scale);

    for (; i + 4 <= numel; i += 4) {
        __m128 g = _mm_loadu_ps(&grad[i]);
        _mm_storeu_ps(&grad[i], _mm_mul_ps(g, v_scale));
    for (; i < numel; ++i) {
        grad[i] *= scale;
    for (size_t i = 0; i < numel; ++i) {
        grad[i] *= scale;
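
/*
 * gradient_clip_norm_f32: compute the global L2 norm of grad; if it exceeds
 * max_norm, compute scale = max_norm / norm and scale the gradients down.
 * Squared terms in the scalar tails are accumulated in double to limit
 * rounding error.
 */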
    if (!grad || numel == 0 || max_norm <= 0.0f) {

#if defined(__AVX512F__)
    __m512 acc = _mm512_setzero_ps();

    for (; i + 16 <= numel; i += 16) {
        __m512 g = _mm512_loadu_ps(&grad[i]);
        acc = _mm512_fmadd_ps(g, g, acc);

    sum_sq = _mm512_reduce_add_ps(acc);
    for (; i < numel; ++i) {
        sum_sq += (double)grad[i] * (double)grad[i];
#elif defined(__AVX__)
    __m256 acc = _mm256_setzero_ps();

    for (; i + 8 <= numel; i += 8) {
        __m256 g = _mm256_loadu_ps(&grad[i]);
        acc = _mm256_add_ps(acc, _mm256_mul_ps(g, g));

    __m128 hi = _mm256_extractf128_ps(acc, 1);
    __m128 lo = _mm256_castps256_ps128(acc);
    __m128 sum4 = _mm_add_ps(lo, hi);
    __m128 shuf = _mm_movehdup_ps(sum4);
    __m128 sums = _mm_add_ps(sum4, shuf);
    shuf = _mm_movehl_ps(shuf, sums);
    sums = _mm_add_ss(sums, shuf);
    sum_sq = _mm_cvtss_f32(sums);
    for (; i < numel; ++i) {
        sum_sq += (double)grad[i] * (double)grad[i];
#elif defined(__SSE2__)
    __m128 acc = _mm_setzero_ps();

    for (; i + 4 <= numel; i += 4) {
        __m128 g = _mm_loadu_ps(&grad[i]);
        acc = _mm_add_ps(acc, _mm_mul_ps(g, g));

    __m128 shuf = _mm_shuffle_ps(acc, acc, _MM_SHUFFLE(2, 3, 0, 1));
    __m128 sums = _mm_add_ps(acc, shuf);
    shuf = _mm_movehl_ps(shuf, sums);
    sums = _mm_add_ss(sums, shuf);
    sum_sq = _mm_cvtss_f32(sums);
    for (; i < numel; ++i) {
        sum_sq += (double)grad[i] * (double)grad[i];
    for (size_t i = 0; i < numel; ++i) {
        sum_sq += (double)grad[i] * (double)grad[i];
    float norm = sqrtf((float)sum_sq);

    if (norm > max_norm) {
        float scale = max_norm / norm;

/* SGD with momentum optimizer update (fp32 version) */
void sgd_momentum_update_f32(const float *grad, float *weight, float *velocity, size_t numel, float lr, float momentum, float weight_decay);

/* Clip gradient norm (fp32) */
float gradient_clip_norm_f32(float *grad, size_t numel, float max_norm);

/* Scale gradients by a constant: grad *= scale (fp32) */
void gradient_scale_f32(float *grad, size_t numel, float scale);

/* AdamW optimizer update (fp32 version) */
void adamw_update_f32(const float *grad, float *weight, float *m, float *v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step);

/* Accumulate gradients: dst += src (fp32) */
void gradient_accumulate_f32(float *dst, const float *src, size_t numel);

/* Zero out gradient buffer (fp32) */
void zero_gradients_f32(float *grad, size_t numel);
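
A minimal usage sketch of these routines in one training step. The driver name,
micro-batch count, and hyperparameter values below are illustrative assumptions,
not part of the code above:

/* Hypothetical driver: average accumulated gradients, clip, apply AdamW, reset. */
static void training_step(float *grad, float *weight, float *m, float *v,
                          size_t numel, int step, int micro_batches) {
    /* grad is assumed to hold gradients summed over micro_batches
     * via gradient_accumulate_f32. */
    gradient_scale_f32(grad, numel, 1.0f / (float)micro_batches); /* average */
    gradient_clip_norm_f32(grad, numel, 1.0f);                    /* clip global L2 norm to 1.0 */
    adamw_update_f32(grad, weight, m, v, numel,
                     1e-3f, 0.9f, 0.999f, 1e-8f, 0.01f, step);    /* lr, beta1, beta2, eps, wd */
    zero_gradients_f32(grad, numel);                              /* clear for the next step */
}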