/* mlp_fused_decode.c — source listing recovered from the C-Kernel-Engine
 * Doxygen documentation page ("Go to the documentation of this file"). */
1 /**
2  * @file mlp_fused_decode.c
3  * @brief Fully fused MLP decode kernel (T=1 token generation)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
 * LEGACY: This file is from v6/v6.5 and kept for backward compatibility.
 * It predates rules 1-2 above: it still includes <stdlib.h> and uses OpenMP
 * pragmas. Do not copy these patterns into new kernels.
15  *
16  * This kernel fuses the ENTIRE MLP block into a single pass:
17  * output = Down(SwiGLU(Gate(x), Up(x))) + residual
18  *
19  * Key optimization: The intermediate SwiGLU values (~4864 floats = 19KB for Qwen2)
20  * NEVER touch DRAM. They stay in L1/L2 cache through tiling.
21  *
22  * Target: Intel Xeon 5th Gen (Emerald Rapids) with AVX-512 and AMX
23  *
24  * Memory traffic comparison (Qwen2-0.5B, D=896, Hff=4864):
25  * Unfused: 76 KB activation traffic (38KB write + 38KB read)
26  * Fused: 0 KB activation traffic (tiles stay in L1)
27  *
28  * Weight layout expected: Row-major, transposed for matvec
29  * W_gate[Hff, D], W_up[Hff, D], W_down[D, Hff]
30  */
31 
32 #include "ckernel_engine.h"
33 #include <math.h>
34 #include <stdlib.h> // for aligned_alloc in v1
35 #include <string.h>
36 
37 #if defined(__AVX512F__)
38 #include <immintrin.h>
39 #endif
40 
41 #ifdef _OPENMP
42 #include <omp.h>
43 #endif
44 
// =============================================================================
// Configuration for Xeon 5th Gen
// =============================================================================

// L1 data cache: 48 KB per core on Sapphire/Emerald Rapids
// L2 cache: 2 MB per core
// Tile of the intermediate (Hff) dimension: 64 floats = 256 bytes, small
// enough to stay resident in L1 alongside the weight rows being streamed.
#define MLP_TILE_SIZE 64 // 64 intermediate values = 256 bytes

// For down projection accumulation, we tile the output dimension.
// NOTE(review): OUTPUT_TILE_SIZE is not referenced anywhere in this file's
// visible code — confirm whether it is used by generated callers or is dead.
#define OUTPUT_TILE_SIZE 32 // 32 output values accumulated at once
60 
61 #if defined(__AVX512F__)
62 // Fast SiLU (x * sigmoid(x)) using AVX-512
63 // sigmoid(x) = 1 / (1 + exp(-x))
64 static inline __m512 silu_avx512(__m512 x) {
65  // Compute -x
66  __m512 neg_x = _mm512_sub_ps(_mm512_setzero_ps(), x);
67 
68  // exp(-x) approximation using polynomial (faster than _mm512_exp_ps)
69  // We use the identity: exp(-x) = 2^(-x/ln2)
70  // For better accuracy with larger ranges, we clamp
71  __m512 ln2 = _mm512_set1_ps(0.6931471805599453f);
72  __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
73 
74  // Clamp to avoid overflow/underflow
75  neg_x = _mm512_max_ps(neg_x, _mm512_set1_ps(-88.0f));
76  neg_x = _mm512_min_ps(neg_x, _mm512_set1_ps(88.0f));
77 
78  // Use the built-in exp if available (Xeon has fast transcendentals)
79  // Otherwise fall back to polynomial approximation
80 #if defined(__AVX512ER__) // Knights Landing/Mill have fast exp
81  __m512 exp_neg_x = _mm512_exp2a23_ps(_mm512_mul_ps(neg_x, log2e));
82 #else
83  // Polynomial approximation for exp(-x)
84  // exp(x) ≈ 1 + x + x²/2 + x³/6 + x⁴/24 (good for |x| < 4)
85  // For larger x, we use range reduction
86  __m512 t = _mm512_mul_ps(neg_x, log2e);
87  __m512i ti = _mm512_cvtps_epi32(t);
88  __m512 tf = _mm512_sub_ps(t, _mm512_cvtepi32_ps(ti));
89  tf = _mm512_mul_ps(tf, ln2);
90 
91  // Polynomial for 2^frac
92  __m512 c0 = _mm512_set1_ps(1.0f);
93  __m512 c1 = _mm512_set1_ps(0.6931471805599453f);
94  __m512 c2 = _mm512_set1_ps(0.2402265069591007f);
95  __m512 c3 = _mm512_set1_ps(0.05550410866482158f);
96  __m512 c4 = _mm512_set1_ps(0.009618129107628477f);
97 
98  __m512 p = _mm512_fmadd_ps(c4, tf, c3);
99  p = _mm512_fmadd_ps(p, tf, c2);
100  p = _mm512_fmadd_ps(p, tf, c1);
101  p = _mm512_fmadd_ps(p, tf, c0);
102 
103  // Scale by 2^int
104  __m512 exp_neg_x = _mm512_scalef_ps(p, _mm512_cvtepi32_ps(ti));
105 #endif
106 
107  // sigmoid = 1 / (1 + exp(-x))
108  __m512 one = _mm512_set1_ps(1.0f);
109  __m512 sigmoid = _mm512_div_ps(one, _mm512_add_ps(one, exp_neg_x));
110 
111  // silu = x * sigmoid(x)
112  return _mm512_mul_ps(x, sigmoid);
113 }
114 
115 // Horizontal sum of __m512 (AVX-512F only, no DQ required)
116 static inline float hsum512_ps(__m512 v) {
117  // Use shuffle-based reduction (AVX-512F compatible)
118  // Reduce 16 -> 8
119  __m256 lo = _mm512_castps512_ps256(v);
120  __m256 hi = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(v), 1));
121  __m256 sum256 = _mm256_add_ps(lo, hi);
122  // Reduce 8 -> 4
123  __m128 lo128 = _mm256_castps256_ps128(sum256);
124  __m128 hi128 = _mm256_extractf128_ps(sum256, 1);
125  __m128 sum128 = _mm_add_ps(lo128, hi128);
126  // Reduce 4 -> 2 -> 1
127  sum128 = _mm_hadd_ps(sum128, sum128);
128  sum128 = _mm_hadd_ps(sum128, sum128);
129  return _mm_cvtss_f32(sum128);
130 }
131 #endif
132 
133 // Scalar SiLU for fallback and remainder
134 static inline float silu_scalar(float x) {
135  return x / (1.0f + expf(-x));
136 }
137 
138 // =============================================================================
139 // Fully Fused MLP Decode (Main Kernel)
140 // =============================================================================
141 //
142 // Computes: output[D] = SwiGLU_MLP(x[D]) where
143 // gate = x @ W_gate^T + b_gate
144 // up = x @ W_up^T + b_up
145 // swiglu = SiLU(gate) * up
146 // output = swiglu @ W_down^T + b_down
147 //
148 // Tiling strategy:
149 // - Process intermediate dimension in tiles of MLP_TILE_SIZE
150 // - For each tile: compute gate, up, swiglu (stays in registers/L1)
151 // - Immediately accumulate into output via W_down
152 // - Swiglu tile NEVER written to DRAM
153 //
155  const float *x, // [D] input (after RMSNorm)
156  const float *W_gate, // [Hff, D] gate projection weights
157  const float *W_up, // [Hff, D] up projection weights
158  const float *W_down, // [D, Hff] down projection weights
159  const float *b_gate, // [Hff] gate bias (can be NULL)
160  const float *b_up, // [Hff] up bias (can be NULL)
161  const float *b_down, // [D] down bias (can be NULL)
162  float *output, // [D] output
163  int D, // hidden dimension (e.g., 896)
164  int Hff) // intermediate dimension (e.g., 4864)
165 {
166 #if defined(__AVX512F__)
167  // Initialize output with bias or zero
168  if (b_down) {
169  memcpy(output, b_down, D * sizeof(float));
170  } else {
171  memset(output, 0, D * sizeof(float));
172  }
173 
174  // Process intermediate dimension in tiles
175  // Each tile computes MLP_TILE_SIZE swiglu values and immediately
176  // accumulates them into the output
177 
178  /* Bounds check for stack allocation */
179  if (D > 4096) return;
180 
181  #pragma omp parallel
182  {
183  /* Thread-local accumulator on stack (no malloc!) */
184  float local_output[4096] __attribute__((aligned(64)));
185  memset(local_output, 0, D * sizeof(float));
186 
187  #pragma omp for schedule(static)
188  for (int t = 0; t < Hff; t += MLP_TILE_SIZE) {
189  int tile_end = (t + MLP_TILE_SIZE < Hff) ? t + MLP_TILE_SIZE : Hff;
190  int tile_size = tile_end - t;
191 
192  // Compute SwiGLU for this tile (stays in L1 cache)
193  float swiglu_tile[MLP_TILE_SIZE] __attribute__((aligned(64)));
194 
195  for (int j = t; j < tile_end; j++) {
196  const float *wg_row = &W_gate[j * D];
197  const float *wu_row = &W_up[j * D];
198 
199  // Compute gate = x @ W_gate[j] using AVX-512
200  __m512 gate_acc = _mm512_setzero_ps();
201  __m512 up_acc = _mm512_setzero_ps();
202 
203  int k = 0;
204  for (; k <= D - 16; k += 16) {
205  __m512 x_vec = _mm512_loadu_ps(&x[k]);
206  __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
207  __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
208 
209  gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
210  up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
211  }
212 
213  float gate = hsum512_ps(gate_acc);
214  float up = hsum512_ps(up_acc);
215 
216  // Scalar remainder
217  for (; k < D; k++) {
218  gate += x[k] * wg_row[k];
219  up += x[k] * wu_row[k];
220  }
221 
222  // Add biases
223  if (b_gate) gate += b_gate[j];
224  if (b_up) up += b_up[j];
225 
226  // SwiGLU: SiLU(gate) * up
227  swiglu_tile[j - t] = silu_scalar(gate) * up;
228  }
229 
230  // Accumulate into output via W_down
231  // output[i] += sum_j(swiglu_tile[j] * W_down[i, t+j])
232  for (int i = 0; i < D; i++) {
233  const float *wd_row = &W_down[i * Hff + t];
234 
235  __m512 acc = _mm512_setzero_ps();
236  int j = 0;
237  for (; j <= tile_size - 16; j += 16) {
238  __m512 sw_vec = _mm512_loadu_ps(&swiglu_tile[j]);
239  __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
240  acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
241  }
242 
243  float sum = hsum512_ps(acc);
244  for (; j < tile_size; j++) {
245  sum += swiglu_tile[j] * wd_row[j];
246  }
247 
248  local_output[i] += sum;
249  }
250  }
251 
252  // Reduce thread-local outputs
253  #pragma omp critical
254  {
255  for (int i = 0; i < D; i++) {
256  output[i] += local_output[i];
257  }
258  }
259  /* No free - stack buffer auto-deallocates */
260  }
261 
262 #else
263  // Scalar fallback (same algorithm, no SIMD)
264  if (b_down) {
265  memcpy(output, b_down, D * sizeof(float));
266  } else {
267  memset(output, 0, D * sizeof(float));
268  }
269 
270  for (int t = 0; t < Hff; t += MLP_TILE_SIZE) {
271  int tile_end = (t + MLP_TILE_SIZE < Hff) ? t + MLP_TILE_SIZE : Hff;
272  int tile_size = tile_end - t;
273 
274  float swiglu_tile[MLP_TILE_SIZE];
275 
276  for (int j = t; j < tile_end; j++) {
277  float gate = 0.0f;
278  float up = 0.0f;
279 
280  for (int k = 0; k < D; k++) {
281  gate += x[k] * W_gate[j * D + k];
282  up += x[k] * W_up[j * D + k];
283  }
284 
285  if (b_gate) gate += b_gate[j];
286  if (b_up) up += b_up[j];
287 
288  swiglu_tile[j - t] = silu_scalar(gate) * up;
289  }
290 
291  for (int i = 0; i < D; i++) {
292  for (int j = 0; j < tile_size; j++) {
293  output[i] += swiglu_tile[j] * W_down[i * Hff + t + j];
294  }
295  }
296  }
297 #endif
298 }
299 
// Forward declaration so later variants can fall back to the tiled kernel.
// NOTE(review): the first line of this prototype was missing from the
// extracted listing; reconstructed from the Doxygen symbol index.
void fused_mlp_swiglu_decode_tiled(
    const float *x, const float *W_gate, const float *W_up, const float *W_down,
    const float *b_gate, const float *b_up, const float *b_down,
    float *output, int D, int Hff);
305 
// =============================================================================
// Optimized Version: Two-Phase with Stack Buffer (Best for 24+ cores)
// =============================================================================
//
// Phase 1: All threads compute swiglu values in parallel (no reduction needed)
// Phase 2: All threads compute output values in parallel (no reduction needed)
//
// Uses a stack-allocated buffer that fits in L2 cache.
// For Hff > MAX_SWIGLU_STACK, falls back to the tiled version above.
//
#define MAX_SWIGLU_STACK 8192 // 8192 floats = 32KB buffer, fits in L2
317 
319  const float *x, // [D]
320  const float *W_gate, // [Hff, D]
321  const float *W_up, // [Hff, D]
322  const float *W_down, // [D, Hff]
323  const float *b_gate, // [Hff] or NULL
324  const float *b_up, // [Hff] or NULL
325  const float *b_down, // [D] or NULL
326  float *output, // [D]
327  int D,
328  int Hff)
329 {
330  // For large Hff, use tiled version to avoid stack overflow
331  if (Hff > MAX_SWIGLU_STACK) {
332  fused_mlp_swiglu_decode_tiled(x, W_gate, W_up, W_down,
333  b_gate, b_up, b_down, output, D, Hff);
334  return;
335  }
336 
337 #if defined(__AVX512F__)
338  // Stack-allocated swiglu buffer (max 32KB)
339  float swiglu[MAX_SWIGLU_STACK] __attribute__((aligned(64)));
340 
341  // Phase 1: Compute all swiglu values (parallelize over Hff)
342  #pragma omp parallel for schedule(static)
343  for (int j = 0; j < Hff; j++) {
344  const float *wg_row = &W_gate[j * D];
345  const float *wu_row = &W_up[j * D];
346 
347  __m512 gate_acc = _mm512_setzero_ps();
348  __m512 up_acc = _mm512_setzero_ps();
349 
350  int k = 0;
351  for (; k <= D - 16; k += 16) {
352  __m512 x_vec = _mm512_loadu_ps(&x[k]);
353  __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
354  __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
355 
356  gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
357  up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
358  }
359 
360  float gate = hsum512_ps(gate_acc);
361  float up = hsum512_ps(up_acc);
362 
363  for (; k < D; k++) {
364  gate += x[k] * wg_row[k];
365  up += x[k] * wu_row[k];
366  }
367 
368  if (b_gate) gate += b_gate[j];
369  if (b_up) up += b_up[j];
370 
371  swiglu[j] = silu_scalar(gate) * up;
372  }
373 
374  // Phase 2: Down projection (parallelize over D)
375  #pragma omp parallel for schedule(static)
376  for (int i = 0; i < D; i++) {
377  const float *wd_row = &W_down[i * Hff];
378 
379  __m512 acc = _mm512_setzero_ps();
380  int j = 0;
381  for (; j <= Hff - 16; j += 16) {
382  __m512 sw_vec = _mm512_loadu_ps(&swiglu[j]);
383  __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
384  acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
385  }
386 
387  float sum = hsum512_ps(acc);
388  for (; j < Hff; j++) {
389  sum += swiglu[j] * wd_row[j];
390  }
391 
392  output[i] = sum + (b_down ? b_down[i] : 0.0f);
393  }
394 
395 #else
396  // Scalar fallback with stack buffer
397  float swiglu[MAX_SWIGLU_STACK];
398 
399  for (int j = 0; j < Hff; j++) {
400  float gate = 0.0f, up = 0.0f;
401  for (int k = 0; k < D; k++) {
402  gate += x[k] * W_gate[j * D + k];
403  up += x[k] * W_up[j * D + k];
404  }
405  if (b_gate) gate += b_gate[j];
406  if (b_up) up += b_up[j];
407  swiglu[j] = silu_scalar(gate) * up;
408  }
409 
410  for (int i = 0; i < D; i++) {
411  float sum = 0.0f;
412  for (int j = 0; j < Hff; j++) {
413  sum += swiglu[j] * W_down[i * Hff + j];
414  }
415  output[i] = sum + (b_down ? b_down[i] : 0.0f);
416  }
417 #endif
418 }
419 
420 // =============================================================================
421 // Version 3: True Zero-Copy Tiled Fusion (Best for Large L2)
422 // =============================================================================
423 //
424 // This version processes tiles of the intermediate dimension and immediately
425 // accumulates into output, without any intermediate buffer allocation.
426 //
427 // Optimal for Xeon 5th gen with 2MB L2 per core.
428 //
/**
 * Zero-copy tiled fused MLP decode: processes the intermediate dimension in
 * 256-wide tiles, accumulating each tile into the output immediately via
 * W_down. Uses `#pragma omp atomic` for the cross-tile accumulation.
 *
 * NOTE(review): the leading `void fused_mlp_swiglu_decode(` line was missing
 * from the extracted listing and has been reconstructed from the Doxygen
 * symbol index — confirm against the public header.
 */
void fused_mlp_swiglu_decode(
    const float *x,      // [D]
    const float *W_gate, // [Hff, D]
    const float *W_up,   // [Hff, D]
    const float *W_down, // [D, Hff]
    const float *b_gate, // [Hff] or NULL
    const float *b_up,   // [Hff] or NULL
    const float *b_down, // [D] or NULL
    float *output,       // [D]
    int D,
    int Hff)
{
    // Tile size chosen to fit in L2 with the W_down tile:
    //   swiglu tile: 256 floats = 1KB
    //   W_down tile: 256 * D floats (e.g. 256 * 896 * 4 = 896KB for Qwen2)
    // Must match the fixed swiglu_tile[256] buffers below.
    const int TILE = 256;

#if defined(__AVX512F__)
    // Initialize output with the down bias (or zero)
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < D; i++) {
        output[i] = b_down ? b_down[i] : 0.0f;
    }

    // Process tiles of the intermediate dimension
    for (int t = 0; t < Hff; t += TILE) {
        int tile_end = (t + TILE < Hff) ? t + TILE : Hff;
        int tile_size = tile_end - t;

        // SwiGLU values for this tile (stack, shared across threads;
        // each thread writes a disjoint set of indices jj)
        float swiglu_tile[256] __attribute__((aligned(64)));

        #pragma omp parallel for schedule(static)
        for (int jj = 0; jj < tile_size; jj++) {
            int j = t + jj;
            const float *wg_row = &W_gate[j * D];
            const float *wu_row = &W_up[j * D];

            __m512 gate_acc = _mm512_setzero_ps();
            __m512 up_acc = _mm512_setzero_ps();

            int k = 0;
            for (; k <= D - 16; k += 16) {
                __m512 x_vec = _mm512_loadu_ps(&x[k]);
                __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
                __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);

                gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
                up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
            }

            float gate = hsum512_ps(gate_acc);
            float up = hsum512_ps(up_acc);

            // Scalar remainder when D is not a multiple of 16
            for (; k < D; k++) {
                gate += x[k] * wg_row[k];
                up += x[k] * wu_row[k];
            }

            if (b_gate) gate += b_gate[j];
            if (b_up) up += b_up[j];

            swiglu_tile[jj] = silu_scalar(gate) * up;
        }

        // Accumulate into output (parallelize over D)
        #pragma omp parallel for schedule(static)
        for (int i = 0; i < D; i++) {
            const float *wd_row = &W_down[i * Hff + t];

            __m512 acc = _mm512_setzero_ps();
            int j = 0;
            for (; j <= tile_size - 16; j += 16) {
                __m512 sw_vec = _mm512_loadu_ps(&swiglu_tile[j]);
                __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
                acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
            }

            float sum = hsum512_ps(acc);
            for (; j < tile_size; j++) {
                sum += swiglu_tile[j] * wd_row[j];
            }

            // Atomic add (thread-local buffers would be faster but need memory)
            #pragma omp atomic
            output[i] += sum;
        }
    }

#else
    // Scalar fallback
    for (int i = 0; i < D; i++) {
        output[i] = b_down ? b_down[i] : 0.0f;
    }

    for (int t = 0; t < Hff; t += TILE) {
        int tile_end = (t + TILE < Hff) ? t + TILE : Hff;

        float swiglu_tile[256];

        for (int j = t; j < tile_end; j++) {
            float gate = 0.0f, up = 0.0f;
            for (int k = 0; k < D; k++) {
                gate += x[k] * W_gate[j * D + k];
                up += x[k] * W_up[j * D + k];
            }
            if (b_gate) gate += b_gate[j];
            if (b_up) up += b_up[j];
            swiglu_tile[j - t] = silu_scalar(gate) * up;
        }

        for (int i = 0; i < D; i++) {
            for (int j = t; j < tile_end; j++) {
                output[i] += swiglu_tile[j - t] * W_down[i * Hff + j];
            }
        }
    }
#endif
}
/* Doxygen symbol index (generated-docs residue), kept for reference:
 *   void fused_mlp_swiglu_decode_tiled(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
 *   void fused_mlp_swiglu_decode(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
 *   #define MLP_TILE_SIZE
 *   static float silu_scalar(float x)
 *   void fused_mlp_swiglu_decode_v2(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
 *   #define MAX_SWIGLU_STACK
 *   __attribute__((visibility("default"))) CKTokenizer *ck_tokenizer_create(CKTokenizerType type)  -- appears to belong to a different source file
 */