#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#include <immintrin.h>
#endif

#include <math.h> /* expf, fmaxf, fminf, rintf, INFINITY */
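/*
 * A scalar model of the vector exp approximations below; an illustrative
 * sketch, not part of the original file. The scheme: clamp, split
 * x = n*ln(2) + r with |r| <= ln(2)/2 (the two-constant Cody-Waite
 * subtraction keeps r accurate), approximate e^r with a short Horner
 * polynomial, then multiply by 2^n built directly from the exponent bits.
 */
#include <stdint.h>

static inline float exp_scalar_model(float x) {
    x = fmaxf(-88.0f, fminf(88.0f, x));       /* keep 2^n representable */
    float n = rintf(x * 1.4426950408889634f); /* n = round(x / ln 2)    */
    float r = x - n * 0.693359375f;           /* Cody-Waite: high part  */
    r -= n * -2.12194440e-4f;                 /* Cody-Waite: low part   */
    float p = 0.041666668f;                   /* Taylor e^r, degree 4   */
    p = p * r + 0.16666667f;
    p = p * r + 0.5f;
    p = p * r + 1.0f;
    p = p * r + 1.0f;
    union { uint32_t u; float f; } s;         /* 2^n via exponent bits  */
    s.u = (uint32_t)((int32_t)n + 127) << 23;
    return p * s.f;
}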
#if defined(__AVX512F__)
/* Vectorized e^x approximation: clamp, range-reduce, evaluate a degree-4
 * polynomial for e^r, and rescale by 2^n built from exponent bits. */
static inline __m512 exp512_approx(__m512 x) {
    /* Clamp so that 2^n stays inside the normal float range. */
    x = _mm512_max_ps(x, _mm512_set1_ps(-88.0f));
    x = _mm512_min_ps(x, _mm512_set1_ps(88.0f));

    /* Split x = n*ln(2) + r, |r| <= ln(2)/2. c1 + c2 is a Cody-Waite
     * split of ln(2): subtracting in two steps keeps r accurate. */
    const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
    const __m512 c1 = _mm512_set1_ps(0.693359375f);
    const __m512 c2 = _mm512_set1_ps(-2.12194440e-4f);

    __m512 t = _mm512_mul_ps(x, log2e);
    __m512 ti = _mm512_roundscale_ps(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

    __m512 rx = _mm512_sub_ps(x, _mm512_mul_ps(ti, c1));
    rx = _mm512_sub_ps(rx, _mm512_mul_ps(ti, c2));

    /* Degree-4 Taylor polynomial for e^r on the reduced range. */
    const __m512 p0 = _mm512_set1_ps(1.0f);
    const __m512 p1 = _mm512_set1_ps(1.0f);
    const __m512 p2 = _mm512_set1_ps(0.5f);
    const __m512 p3 = _mm512_set1_ps(0.16666667f);
    const __m512 p4 = _mm512_set1_ps(0.041666668f);

    __m512 poly = _mm512_fmadd_ps(p4, rx, p3);
    poly = _mm512_fmadd_ps(poly, rx, p2);
    poly = _mm512_fmadd_ps(poly, rx, p1);
    poly = _mm512_fmadd_ps(poly, rx, p0);

    /* Build 2^n by placing (n + 127) in the float exponent field. */
    __m512i ti_int = _mm512_cvtps_epi32(ti);
    ti_int = _mm512_add_epi32(ti_int, _mm512_set1_epi32(127));
    ti_int = _mm512_slli_epi32(ti_int, 23);
    __m512 scale = _mm512_castsi512_ps(ti_int);

    return _mm512_mul_ps(poly, scale);
}
#endif
#if defined(__AVX2__)
/* 8-lane variant of the same approximation (FMA assumed alongside AVX2). */
static inline __m256 exp256_approx(__m256 x) {
    x = _mm256_max_ps(x, _mm256_set1_ps(-88.0f));
    x = _mm256_min_ps(x, _mm256_set1_ps(88.0f));

    const __m256 log2e = _mm256_set1_ps(1.4426950408889634f);
    const __m256 c1 = _mm256_set1_ps(0.693359375f);
    const __m256 c2 = _mm256_set1_ps(-2.12194440e-4f);

    __m256 t = _mm256_mul_ps(x, log2e);
    __m256 ti = _mm256_round_ps(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

    __m256 rx = _mm256_sub_ps(x, _mm256_mul_ps(ti, c1));
    rx = _mm256_sub_ps(rx, _mm256_mul_ps(ti, c2));

    /* Degree-4 Taylor polynomial for e^r on the reduced range. */
    const __m256 p0 = _mm256_set1_ps(1.0f);
    const __m256 p1 = _mm256_set1_ps(1.0f);
    const __m256 p2 = _mm256_set1_ps(0.5f);
    const __m256 p3 = _mm256_set1_ps(0.16666667f);
    const __m256 p4 = _mm256_set1_ps(0.041666668f);

    __m256 poly = _mm256_fmadd_ps(p4, rx, p3);
    poly = _mm256_fmadd_ps(poly, rx, p2);
    poly = _mm256_fmadd_ps(poly, rx, p1);
    poly = _mm256_fmadd_ps(poly, rx, p0);

    __m256i ti_int = _mm256_cvtps_epi32(ti);
    ti_int = _mm256_add_epi32(ti_int, _mm256_set1_epi32(127));
    ti_int = _mm256_slli_epi32(ti_int, 23);
    __m256 scale = _mm256_castsi256_ps(ti_int);

    return _mm256_mul_ps(poly, scale);
}
#endif
#if defined(__AVX__) || defined(__AVX2__)
/* Horizontal max of 8 floats. */
static inline float hmax256_ps(__m256 v) {
    __m128 hi = _mm256_extractf128_ps(v, 1);
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 max128 = _mm_max_ps(lo, hi);
    max128 = _mm_max_ps(max128, _mm_shuffle_ps(max128, max128, _MM_SHUFFLE(2, 3, 0, 1)));
    max128 = _mm_max_ps(max128, _mm_shuffle_ps(max128, max128, _MM_SHUFFLE(1, 0, 3, 2)));
    return _mm_cvtss_f32(max128);
}

/* Horizontal sum of 8 floats. */
static inline float hsum256_ps_softmax(__m256 v) {
    __m128 hi = _mm256_extractf128_ps(v, 1);
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 sum128 = _mm_add_ps(lo, hi);
    sum128 = _mm_hadd_ps(sum128, sum128);
    sum128 = _mm_hadd_ps(sum128, sum128);
    return _mm_cvtss_f32(sum128);
}
#endif
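
/*
 * Causal softmax over attention scores stored head-major: row (h, i)
 * lives at scores[h * A * A + i * A], where A = aligned_context_window
 * is the padded row stride. Row i attends to tokens 0..i (len = i + 1);
 * each row gets the usual three passes (max, exponentiate-and-sum,
 * normalize), and the masked tail of the row is zeroed.
 */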
void causal_softmax_head_major(float *scores, int num_heads, int num_tokens,
                               int aligned_context_window)
{
    for (int h = 0; h < num_heads; ++h) {
        for (int i = 0; i < num_tokens; ++i) {
            int base = h * aligned_context_window * aligned_context_window
                     + i * aligned_context_window;
            float *row = &scores[base];
            int len = i + 1; /* causal: row i attends to tokens 0..i */
#if defined(__AVX512F__)
            /* Pass 1: row maximum. */
            __m512 max_vec = _mm512_set1_ps(-INFINITY);
            int j = 0;
            for (; j + 16 <= len; j += 16) {
                __m512 v = _mm512_loadu_ps(&row[j]);
                max_vec = _mm512_max_ps(max_vec, v);
            }
            float max_val = _mm512_reduce_max_ps(max_vec);
            for (; j < len; ++j) {
                if (row[j] > max_val) max_val = row[j];
            }

            /* Pass 2: exponentiate in place and accumulate the sum. */
            __m512 max_broadcast = _mm512_set1_ps(max_val);
            __m512 sum_vec = _mm512_setzero_ps();
            j = 0;
            for (; j + 16 <= len; j += 16) {
                __m512 v = _mm512_loadu_ps(&row[j]);
                __m512 e = exp512_approx(_mm512_sub_ps(v, max_broadcast));
                _mm512_storeu_ps(&row[j], e);
                sum_vec = _mm512_add_ps(sum_vec, e);
            }
            float sum = _mm512_reduce_add_ps(sum_vec);
            for (; j < len; ++j) {
                float e = expf(row[j] - max_val);
                row[j] = e;
                sum += e;
            }

            /* Pass 3: normalize. */
            float inv_sum = 1.0f / sum;
            __m512 inv_sum_vec = _mm512_set1_ps(inv_sum);
            j = 0;
            for (; j + 16 <= len; j += 16) {
                __m512 v = _mm512_loadu_ps(&row[j]);
                _mm512_storeu_ps(&row[j], _mm512_mul_ps(v, inv_sum_vec));
            }
            for (; j < len; ++j) {
                row[j] *= inv_sum;
            }

            /* Zero the masked tail of the row. */
            __m512 zero = _mm512_setzero_ps();
            for (; j + 16 <= num_tokens; j += 16) {
                _mm512_storeu_ps(&row[j], zero);
            }
            for (; j < num_tokens; ++j) {
                row[j] = 0.0f;
            }
#elif defined(__AVX2__)
            /* Pass 1: row maximum. */
            __m256 max_vec = _mm256_set1_ps(-INFINITY);
            int j = 0;
            for (; j + 8 <= len; j += 8) {
                __m256 v = _mm256_loadu_ps(&row[j]);
                max_vec = _mm256_max_ps(max_vec, v);
            }
            float max_val = hmax256_ps(max_vec);
            for (; j < len; ++j) {
                if (row[j] > max_val) max_val = row[j];
            }

            /* Pass 2: exponentiate in place and accumulate the sum. */
            __m256 max_broadcast = _mm256_set1_ps(max_val);
            __m256 sum_vec = _mm256_setzero_ps();
            j = 0;
            for (; j + 8 <= len; j += 8) {
                __m256 v = _mm256_loadu_ps(&row[j]);
                __m256 e = exp256_approx(_mm256_sub_ps(v, max_broadcast));
                _mm256_storeu_ps(&row[j], e);
                sum_vec = _mm256_add_ps(sum_vec, e);
            }
            float sum = hsum256_ps_softmax(sum_vec);
            for (; j < len; ++j) {
                float e = expf(row[j] - max_val);
                row[j] = e;
                sum += e;
            }

            /* Pass 3: normalize. */
            float inv_sum = 1.0f / sum;
            __m256 inv_sum_vec = _mm256_set1_ps(inv_sum);
            j = 0;
            for (; j + 8 <= len; j += 8) {
                __m256 v = _mm256_loadu_ps(&row[j]);
                _mm256_storeu_ps(&row[j], _mm256_mul_ps(v, inv_sum_vec));
            }
            for (; j < len; ++j) {
                row[j] *= inv_sum;
            }

            /* Zero the masked tail of the row. */
            __m256 zero = _mm256_setzero_ps();
            for (; j + 8 <= num_tokens; j += 8) {
                _mm256_storeu_ps(&row[j], zero);
            }
            for (; j < num_tokens; ++j) {
                row[j] = 0.0f;
            }
#elif defined(__AVX__)
            /* Plain AVX (no FMA): vectorize max and normalize, exp stays scalar. */
            __m256 max_vec = _mm256_set1_ps(-INFINITY);
            int j = 0;
            for (; j + 8 <= len; j += 8) {
                __m256 v = _mm256_loadu_ps(&row[j]);
                max_vec = _mm256_max_ps(max_vec, v);
            }
            float max_val = hmax256_ps(max_vec);
            for (; j < len; ++j) {
                if (row[j] > max_val) max_val = row[j];
            }

            float sum = 0.0f;
            for (j = 0; j < len; ++j) {
                float e = expf(row[j] - max_val);
                row[j] = e;
                sum += e;
            }

            float inv_sum = 1.0f / sum;
            __m256 inv_sum_vec = _mm256_set1_ps(inv_sum);
            j = 0;
            for (; j + 8 <= len; j += 8) {
                __m256 v = _mm256_loadu_ps(&row[j]);
                _mm256_storeu_ps(&row[j], _mm256_mul_ps(v, inv_sum_vec));
            }
            for (; j < len; ++j) {
                row[j] *= inv_sum;
            }

            /* Zero the masked tail of the row. */
            __m256 zero = _mm256_setzero_ps();
            for (; j + 8 <= num_tokens; j += 8) {
                _mm256_storeu_ps(&row[j], zero);
            }
            for (; j < num_tokens; ++j) {
                row[j] = 0.0f;
            }
#else
            /* Scalar fallback. */
            float max_val = row[0];
            for (int j = 1; j < len; ++j) {
                if (row[j] > max_val) max_val = row[j];
            }

            float sum = 0.0f;
            for (int j = 0; j < len; ++j) {
                float e = expf(row[j] - max_val);
                row[j] = e;
                sum += e;
            }

            float inv_sum = 1.0f / sum;
            for (int j = 0; j < len; ++j) {
                row[j] *= inv_sum;
            }

            for (int j = len; j < num_tokens; ++j) {
                row[j] = 0.0f;
            }
#endif
        }
    }
}
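
/*
 * Reference variant: identical layout and masking, but every element goes
 * through libm expf(), which makes it a convenient baseline for checking
 * the approximate SIMD paths above.
 */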
void causal_softmax_head_major_exact(float *scores, int num_heads, int num_tokens,
                                     int aligned_context_window)
{
    for (int h = 0; h < num_heads; ++h) {
        for (int i = 0; i < num_tokens; ++i) {
            int base = h * aligned_context_window * aligned_context_window
                     + i * aligned_context_window;
            float *row = &scores[base];
            int len = i + 1;

            float max_val = -INFINITY;
            for (int j = 0; j < len; ++j) {
                if (row[j] > max_val) max_val = row[j];
            }

            float sum = 0.0f;
            for (int j = 0; j < len; ++j) {
                float e = expf(row[j] - max_val);
                row[j] = e;
                sum += e;
            }

            float inv_sum = 1.0f / sum;
            for (int j = 0; j < len; ++j) {
                row[j] *= inv_sum;
            }

            for (int j = len; j < num_tokens; ++j) {
                row[j] = 0.0f;
            }
        }
    }
}
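
/*
 * Backward pass. For a softmax row w with upstream gradient g, the
 * Jacobian-vector product is
 *     d_score[j] = w[j] * (g[j] - dot(w, g)),
 * so each row needs one dot product followed by one elementwise update.
 * d_scores holds g on entry and d_score on exit.
 */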
void backward_causal_softmax_head_major(float *d_scores,
                                        const float *weights,
                                        int num_heads,
                                        int num_tokens,
                                        int aligned_context_window)
{
    const int H = num_heads;
    const int T = num_tokens;

    for (int h = 0; h < H; ++h) {
        for (int i = 0; i < T; ++i) {
            int base = h * aligned_context_window * aligned_context_window
                     + i * aligned_context_window;
            float *drow = &d_scores[base];
            const float *wrow = &weights[base];
            int len = i + 1;
#if defined(__AVX512F__)
            /* dot = sum_j w[j] * g[j] over the causal prefix. */
            __m512 dot_vec = _mm512_setzero_ps();
            int j = 0;
            for (; j + 16 <= len; j += 16) {
                __m512 w = _mm512_loadu_ps(&wrow[j]);
                __m512 dw = _mm512_loadu_ps(&drow[j]);
                dot_vec = _mm512_fmadd_ps(w, dw, dot_vec);
            }
            float dot_product = _mm512_reduce_add_ps(dot_vec);
            for (; j < len; ++j) {
                dot_product += wrow[j] * drow[j];
            }

            /* d_score[j] = w[j] * (g[j] - dot). */
            __m512 dot_broadcast = _mm512_set1_ps(dot_product);
            j = 0;
            for (; j + 16 <= len; j += 16) {
                __m512 w = _mm512_loadu_ps(&wrow[j]);
                __m512 dw = _mm512_loadu_ps(&drow[j]);
                __m512 diff = _mm512_sub_ps(dw, dot_broadcast);
                __m512 result = _mm512_mul_ps(w, diff);
                _mm512_storeu_ps(&drow[j], result);
            }
            for (; j < len; ++j) {
                drow[j] = wrow[j] * (drow[j] - dot_product);
            }

            /* Zero the masked tail of the row. */
            __m512 zero = _mm512_setzero_ps();
            for (; j + 16 <= T; j += 16) {
                _mm512_storeu_ps(&drow[j], zero);
            }
            for (; j < T; ++j) {
                drow[j] = 0.0f;
            }
#elif defined(__AVX__)
            __m256 dot_vec = _mm256_setzero_ps();
            int j = 0;
            for (; j + 8 <= len; j += 8) {
                __m256 w = _mm256_loadu_ps(&wrow[j]);
                __m256 dw = _mm256_loadu_ps(&drow[j]);
                /* No FMA guaranteed under plain AVX: multiply, then add. */
                __m256 prod = _mm256_mul_ps(w, dw);
                dot_vec = _mm256_add_ps(dot_vec, prod);
            }
            float dot_product = hsum256_ps_softmax(dot_vec);
            for (; j < len; ++j) {
                dot_product += wrow[j] * drow[j];
            }

            __m256 dot_broadcast = _mm256_set1_ps(dot_product);
            j = 0;
            for (; j + 8 <= len; j += 8) {
                __m256 w = _mm256_loadu_ps(&wrow[j]);
                __m256 dw = _mm256_loadu_ps(&drow[j]);
                __m256 diff = _mm256_sub_ps(dw, dot_broadcast);
                __m256 result = _mm256_mul_ps(w, diff);
                _mm256_storeu_ps(&drow[j], result);
            }
            for (; j < len; ++j) {
                drow[j] = wrow[j] * (drow[j] - dot_product);
            }

            /* Zero the masked tail of the row. */
            __m256 zero = _mm256_setzero_ps();
            for (; j + 8 <= T; j += 8) {
                _mm256_storeu_ps(&drow[j], zero);
            }
            for (; j < T; ++j) {
                drow[j] = 0.0f;
            }
#else
            float dot_product = 0.0f;
            for (int j = 0; j < len; ++j) {
                dot_product += wrow[j] * drow[j];
            }

            for (int j = 0; j < len; ++j) {
                drow[j] = wrow[j] * (drow[j] - dot_product);
            }

            for (int j = len; j < T; ++j) {
                drow[j] = 0.0f;
            }
#endif
        }
    }
}
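
/*
 * Illustrative usage sketch (assumption: not part of the original file).
 * Runs the forward softmax on one small head-major buffer and prints the
 * causal row sums, which should each come out as ~1. Build with
 * -DSOFTMAX_EXAMPLE_MAIN to enable.
 */
#ifdef SOFTMAX_EXAMPLE_MAIN
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    const int H = 2, T = 5, A = 8; /* heads, tokens, padded row stride */
    float *scores = malloc((size_t)H * A * A * sizeof *scores);
    for (int k = 0; k < H * A * A; ++k)
        scores[k] = (float)rand() / (float)RAND_MAX - 0.5f;

    causal_softmax_head_major(scores, H, T, A);

    for (int h = 0; h < H; ++h) {
        for (int i = 0; i < T; ++i) {
            float s = 0.0f;
            for (int j = 0; j <= i; ++j)
                s += scores[h * A * A + i * A + j];
            printf("head %d, row %d: sum = %.6f\n", h, i, s);
        }
    }
    free(scores);
    return 0;
}
#endif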