37 #if defined(__AVX512F__)
38 #include <immintrin.h>
52 #define MLP_TILE_SIZE 64
55 #define OUTPUT_TILE_SIZE 32
61 #if defined(__AVX512F__)
/**
 * Vectorized SiLU activation: returns x * sigmoid(x) for 16 packed floats.
 *
 * sigmoid(x) is evaluated as 1 / (1 + e^-x), with e^-x computed as 2^t where
 * t = -x * log2(e). -x is clamped to [-88, 88] first so that 2^t stays finite
 * in single precision.
 *
 * With AVX-512ER the hardware 2^t approximation is used. Otherwise t is split
 * into an integer part n (converted with the current rounding mode, which is
 * round-to-nearest by default, so |f| <= 0.5) and a fraction f = t - n; 2^f is
 * approximated by a degree-4 polynomial and scaled by 2^n.
 *
 * @param x  16 input values.
 * @return   x * sigmoid(x), lane by lane.
 */
static inline __m512 silu_avx512(__m512 x) {
    __m512 neg_x = _mm512_sub_ps(_mm512_setzero_ps(), x);
    __m512 log2e = _mm512_set1_ps(1.4426950408889634f);

    /* Keep the exponent argument inside the range where e^-x is finite. */
    neg_x = _mm512_max_ps(neg_x, _mm512_set1_ps(-88.0f));
    neg_x = _mm512_min_ps(neg_x, _mm512_set1_ps(88.0f));

#if defined(__AVX512ER__)
    __m512 exp_neg_x = _mm512_exp2a23_ps(_mm512_mul_ps(neg_x, log2e));
#else
    /* Range reduction: t = n + f. */
    __m512 t = _mm512_mul_ps(neg_x, log2e);
    __m512i ti = _mm512_cvtps_epi32(t);
    __m512 tf = _mm512_sub_ps(t, _mm512_cvtepi32_ps(ti));

    /*
     * Coefficients are ln(2)^k / k!, i.e. the Taylor series of 2^f in f.
     * tf must therefore stay in base-2 units. Multiplying it by ln(2) before
     * the polynomial (as was done previously) evaluates 2^(f*ln2) instead of
     * 2^f and skews the sigmoid by up to ~10%.
     */
    __m512 c0 = _mm512_set1_ps(1.0f);
    __m512 c1 = _mm512_set1_ps(0.6931471805599453f);
    __m512 c2 = _mm512_set1_ps(0.2402265069591007f);
    __m512 c3 = _mm512_set1_ps(0.05550410866482158f);
    __m512 c4 = _mm512_set1_ps(0.009618129107628477f);

    /* Horner evaluation of c0 + c1*f + c2*f^2 + c3*f^3 + c4*f^4. */
    __m512 p = _mm512_fmadd_ps(c4, tf, c3);
    p = _mm512_fmadd_ps(p, tf, c2);
    p = _mm512_fmadd_ps(p, tf, c1);
    p = _mm512_fmadd_ps(p, tf, c0);

    /* 2^t = 2^f * 2^n. */
    __m512 exp_neg_x = _mm512_scalef_ps(p, _mm512_cvtepi32_ps(ti));
#endif

    __m512 one = _mm512_set1_ps(1.0f);
    __m512 sigmoid = _mm512_div_ps(one, _mm512_add_ps(one, exp_neg_x));

    return _mm512_mul_ps(x, sigmoid);
}
/**
 * Horizontal sum: adds the 16 float lanes of v and returns the scalar total.
 *
 * The vector is folded 512 -> 256 -> 128 bits by adding the upper half onto
 * the lower half, then two horizontal adds collapse the last four lanes.
 */
static inline float hsum512_ps(__m512 v) {
    /* Fold the upper 256 bits onto the lower 256 bits. */
    const __m256 upper256 =
        _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(v), 1));
    const __m256 folded256 = _mm256_add_ps(_mm512_castps512_ps256(v), upper256);

    /* Fold the upper 128 bits onto the lower 128 bits. */
    const __m128 folded128 = _mm_add_ps(_mm256_castps256_ps128(folded256),
                                        _mm256_extractf128_ps(folded256, 1));

    /* Two pairwise adds leave the full sum in lane 0. */
    const __m128 pairs = _mm_hadd_ps(folded128, folded128);
    return _mm_cvtss_f32(_mm_hadd_ps(pairs, pairs));
}
/* Scalar SiLU: x * sigmoid(x), written as x / (1 + e^-x). */
135 return x / (1.0f + expf(-x));
/*
 * Tiled fused MLP body. Visible indexing implies row-major layouts:
 * W_gate and W_up are read as [j * D + k] (one row of D floats per hidden
 * unit j), W_down as [i * Hff + j] (one row of Hff floats per output i).
 * b_gate and b_up are optional (NULL-checked before use).
 * NOTE(review): row offsets such as j * D and i * Hff are computed in int;
 * confirm D * Hff cannot exceed INT_MAX for supported model sizes.
 */
/* AVX-512 implementation. */
166 #if defined(__AVX512F__)
/* Seed output with the down-projection bias, or with zeros. The condition
 * choosing between the two calls is not visible here (presumably a NULL
 * check on b_down; confirm). */
169 memcpy(output, b_down, D *
sizeof(
float));
171 memset(output, 0, D *
sizeof(
float));
/* NOTE(review): silent early exit; output is left holding only the seed
 * values written above and the caller gets no error indication. If this
 * statement sits inside an OpenMP parallel region, returning from it also
 * branches out of a structured block, which OpenMP does not permit; confirm
 * against the surrounding lines. The 4096 limit presumably matches the
 * capacity of local_output; verify. */
179 if (D > 4096)
return;
/* Clear the accumulator that collects this worker's partial outputs. */
185 memset(local_output, 0, D *
sizeof(
float));
/* Tiles are distributed statically across the threads of the enclosing team. */
187 #pragma omp for schedule(static)
190 int tile_size = tile_end - t;
/* Stage 1: for every hidden unit j in the tile, compute the gate and up
 * dot products of x with row j of W_gate / W_up. */
195 for (
int j = t; j < tile_end; j++) {
196 const float *wg_row = &W_gate[j * D];
197 const float *wu_row = &W_up[j * D];
200 __m512 gate_acc = _mm512_setzero_ps();
201 __m512 up_acc = _mm512_setzero_ps();
/* 16 floats per iteration; x is loaded once and shared by both FMAs. */
204 for (; k <= D - 16; k += 16) {
205 __m512 x_vec = _mm512_loadu_ps(&x[k]);
206 __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
207 __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
209 gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
210 up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
/* Collapse the 16 partial sums of each accumulator into a scalar. */
213 float gate = hsum512_ps(gate_acc);
214 float up = hsum512_ps(up_acc);
/* Scalar handling of the elements left after the 16-wide loop (the loop
 * header is not visible here). */
218 gate += x[k] * wg_row[k];
219 up += x[k] * wu_row[k];
/* Biases are optional. */
223 if (b_gate) gate += b_gate[j];
224 if (b_up) up += b_up[j];
/* Stage 2: down projection of this tile. For each output i, take the dot
 * product of swiglu_tile with the tile_size columns of W_down row i that
 * start at column t. The store into swiglu_tile is on lines not shown. */
232 for (
int i = 0; i < D; i++) {
233 const float *wd_row = &W_down[i * Hff + t];
235 __m512 acc = _mm512_setzero_ps();
237 for (; j <= tile_size - 16; j += 16) {
238 __m512 sw_vec = _mm512_loadu_ps(&swiglu_tile[j]);
239 __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
240 acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
243 float sum = hsum512_ps(acc);
/* Scalar tail of the tile. */
244 for (; j < tile_size; j++) {
245 sum += swiglu_tile[j] * wd_row[j];
248 local_output[i] += sum;
/* Merge the partial results into the shared output.
 * NOTE(review): if several threads execute this loop, the += on output
 * needs to be serialized (e.g. omp critical); the guarding construct is
 * not visible here; confirm. */
255 for (
int i = 0; i < D; i++) {
256 output[i] += local_output[i];
/* Scalar implementation (presumably the branch used without AVX-512; the
 * preprocessor line is not visible here). Same seeding of output. */
265 memcpy(output, b_down, D *
sizeof(
float));
267 memset(output, 0, D *
sizeof(
float));
272 int tile_size = tile_end - t;
/* Gate / up dot products for each hidden unit of the tile. */
276 for (
int j = t; j < tile_end; j++) {
280 for (
int k = 0; k < D; k++) {
281 gate += x[k] * W_gate[j * D + k];
282 up += x[k] * W_up[j * D + k];
285 if (b_gate) gate += b_gate[j];
286 if (b_up) up += b_up[j];
/* Accumulate this tile's contribution to every output element. */
291 for (
int i = 0; i < D; i++) {
292 for (
int j = 0; j < tile_size; j++) {
293 output[i] += swiglu_tile[j] * W_down[i * Hff + t + j];
302 const float *x,
const float *W_gate,
const float *W_up,
const float *W_down,
303 const float *b_gate,
const float *b_up,
const float *b_down,
304 float *output,
int D,
int Hff);
316 #define MAX_SWIGLU_STACK 8192
/* Tail of a call that forwards this function's arguments unchanged. The
 * callee name and the condition guarding the call are not visible here
 * (presumably the tiled variant is used when Hff exceeds MAX_SWIGLU_STACK;
 * confirm). */
333 b_gate, b_up, b_down, output, D, Hff);
/* AVX-512 implementation. */
337 #if defined(__AVX512F__)
/* Pass 1: one hidden unit j per iteration, parallelised across threads.
 * Computes the gate and up dot products of x with row j of W_gate / W_up. */
342 #pragma omp parallel for schedule(static)
343 for (
int j = 0; j < Hff; j++) {
344 const float *wg_row = &W_gate[j * D];
345 const float *wu_row = &W_up[j * D];
347 __m512 gate_acc = _mm512_setzero_ps();
348 __m512 up_acc = _mm512_setzero_ps();
/* 16 floats per iteration; x is loaded once and shared by both FMAs. */
351 for (; k <= D - 16; k += 16) {
352 __m512 x_vec = _mm512_loadu_ps(&x[k]);
353 __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
354 __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
356 gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
357 up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
360 float gate = hsum512_ps(gate_acc);
361 float up = hsum512_ps(up_acc);
/* Scalar handling of the elements left after the 16-wide loop (the loop
 * header is not visible here). */
364 gate += x[k] * wg_row[k];
365 up += x[k] * wu_row[k];
/* Biases are optional. */
368 if (b_gate) gate += b_gate[j];
369 if (b_up) up += b_up[j];
/* Pass 2: one output element i per iteration, parallelised across threads.
 * output[i] is the dot product of the swiglu buffer with row i of W_down,
 * plus the optional down-projection bias. */
375 #pragma omp parallel for schedule(static)
376 for (
int i = 0; i < D; i++) {
377 const float *wd_row = &W_down[i * Hff];
379 __m512 acc = _mm512_setzero_ps();
381 for (; j <= Hff - 16; j += 16) {
382 __m512 sw_vec = _mm512_loadu_ps(&swiglu[j]);
383 __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
384 acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
387 float sum = hsum512_ps(acc);
/* Scalar tail for Hff not divisible by 16. */
388 for (; j < Hff; j++) {
389 sum += swiglu[j] * wd_row[j];
392 output[i] = sum + (b_down ? b_down[i] : 0.0f);
/* Scalar implementation of the same two passes (presumably the branch used
 * without AVX-512; the preprocessor line is not visible here). */
399 for (
int j = 0; j < Hff; j++) {
400 float gate = 0.0f, up = 0.0f;
401 for (
int k = 0; k < D; k++) {
402 gate += x[k] * W_gate[j * D + k];
403 up += x[k] * W_up[j * D + k];
405 if (b_gate) gate += b_gate[j];
406 if (b_up) up += b_up[j];
/* Down projection with optional bias. */
410 for (
int i = 0; i < D; i++) {
412 for (
int j = 0; j < Hff; j++) {
413 sum += swiglu[j] * W_down[i * Hff + j];
415 output[i] = sum + (b_down ? b_down[i] : 0.0f);
/* Number of hidden units processed per tile. */
445 const int TILE = 256;
/* AVX-512 implementation. */
447 #if defined(__AVX512F__)
/* Initialise output with the optional down-projection bias (0 if absent). */
449 #pragma omp parallel for schedule(static)
450 for (
int i = 0; i < D; i++) {
451 output[i] = b_down ? b_down[i] : 0.0f;
/* Walk the hidden dimension in tiles; the last tile may be shorter. */
455 for (
int t = 0; t < Hff; t += TILE) {
456 int tile_end = (t + TILE < Hff) ? t + TILE : Hff;
457 int tile_size = tile_end - t;
/* Stage 1: gate / up dot products for the hidden units of this tile,
 * parallelised over the tile-local index jj. The definition of j is not
 * visible here (presumably j = t + jj; confirm). */
462 #pragma omp parallel for schedule(static)
463 for (
int jj = 0; jj < tile_size; jj++) {
465 const float *wg_row = &W_gate[j * D];
466 const float *wu_row = &W_up[j * D];
468 __m512 gate_acc = _mm512_setzero_ps();
469 __m512 up_acc = _mm512_setzero_ps();
/* 16 floats per iteration; x is loaded once and shared by both FMAs. */
472 for (; k <= D - 16; k += 16) {
473 __m512 x_vec = _mm512_loadu_ps(&x[k]);
474 __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
475 __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
477 gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
478 up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
481 float gate = hsum512_ps(gate_acc);
482 float up = hsum512_ps(up_acc);
/* Scalar handling of the elements left after the 16-wide loop (the loop
 * header is not visible here). */
485 gate += x[k] * wg_row[k];
486 up += x[k] * wu_row[k];
/* Biases are optional. */
489 if (b_gate) gate += b_gate[j];
490 if (b_up) up += b_up[j];
/* Stage 2: down projection of this tile, parallelised over outputs. Each i
 * takes the dot product of swiglu_tile with the tile_size columns of W_down
 * row i starting at column t. The statement that consumes sum is not
 * visible here (presumably output[i] += sum; confirm). */
496 #pragma omp parallel for schedule(static)
497 for (
int i = 0; i < D; i++) {
498 const float *wd_row = &W_down[i * Hff + t];
500 __m512 acc = _mm512_setzero_ps();
502 for (; j <= tile_size - 16; j += 16) {
503 __m512 sw_vec = _mm512_loadu_ps(&swiglu_tile[j]);
504 __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
505 acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
508 float sum = hsum512_ps(acc);
/* Scalar tail of the tile. */
509 for (; j < tile_size; j++) {
510 sum += swiglu_tile[j] * wd_row[j];
/* Scalar implementation (presumably the branch used without AVX-512; the
 * preprocessor line is not visible here). Same bias initialisation. */
521 for (
int i = 0; i < D; i++) {
522 output[i] = b_down ? b_down[i] : 0.0f;
525 for (
int t = 0; t < Hff; t += TILE) {
526 int tile_end = (t + TILE < Hff) ? t + TILE : Hff;
/* NOTE(review): the literal 256 duplicates TILE; the two must be kept in
 * sync or indexing with j - t below can run past the array. */
528 float swiglu_tile[256];
/* Gate / up dot products for each hidden unit of the tile. */
530 for (
int j = t; j < tile_end; j++) {
531 float gate = 0.0f, up = 0.0f;
532 for (
int k = 0; k < D; k++) {
533 gate += x[k] * W_gate[j * D + k];
534 up += x[k] * W_up[j * D + k];
536 if (b_gate) gate += b_gate[j];
537 if (b_up) up += b_up[j];
/* Accumulate this tile's contribution to every output element;
 * swiglu_tile is indexed relative to the tile start. */
541 for (
int i = 0; i < D; i++) {
542 for (
int j = t; j < tile_end; j++) {
543 output[i] += swiglu_tile[j - t] * W_down[i * Hff + j];
void fused_mlp_swiglu_decode_tiled(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
void fused_mlp_swiglu_decode(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
static float silu_scalar(float x)
void fused_mlp_swiglu_decode_v2(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
__attribute__((visibility("default"))) CKTokenizer *ck_tokenizer_create(CKTokenizerType type)