← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mlp_fused_decode.c File Reference

Fully fused MLP decode kernel (T=1 token generation) More...

#include "ckernel_engine.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>

Go to the source code of this file.

Macros

#define MAX_SWIGLU_STACK   8192
 
#define MLP_TILE_SIZE   64
 
#define OUTPUT_TILE_SIZE   32
 

Functions

void fused_mlp_swiglu_decode (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
 
void fused_mlp_swiglu_decode_tiled (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
 
void fused_mlp_swiglu_decode_v2 (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
 
static float silu_scalar (float x)
 

Detailed Description

Fully fused MLP decode kernel (T=1 token generation)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

LEGACY: This file is from v6/v6.5 and kept for backward compatibility. Note that these legacy implementations still use OpenMP pragmas internally, predating rule 2 (parallelization at the orchestrator/codegen layer).

This kernel fuses the ENTIRE MLP block into a single pass: output = Down(SwiGLU(Gate(x), Up(x))). Note that none of the kernel signatures accept a residual input — any residual addition is performed by the caller.

Key optimization: The intermediate SwiGLU values (~4864 floats = 19KB for Qwen2) NEVER touch DRAM. They stay in L1/L2 cache through tiling.

Target: Intel Xeon 5th Gen (Emerald Rapids) with AVX-512 and AMX

Memory traffic comparison (Qwen2-0.5B, D=896, Hff=4864): Unfused: 76 KB activation traffic (38KB write + 38KB read) Fused: 0 KB activation traffic (tiles stay in L1)

Weight layout expected: Row-major, transposed for matvec W_gate[Hff, D], W_up[Hff, D], W_down[D, Hff]

Definition in file mlp_fused_decode.c.

Macro Definition Documentation

◆ MAX_SWIGLU_STACK

#define MAX_SWIGLU_STACK   8192

Definition at line 316 of file mlp_fused_decode.c.

◆ MLP_TILE_SIZE

#define MLP_TILE_SIZE   64

Definition at line 52 of file mlp_fused_decode.c.

◆ OUTPUT_TILE_SIZE

#define OUTPUT_TILE_SIZE   32

Definition at line 55 of file mlp_fused_decode.c.

Function Documentation

◆ fused_mlp_swiglu_decode()

void fused_mlp_swiglu_decode ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  W_down,
const float *  b_gate,
const float *  b_up,
const float *  b_down,
float *  output,
int  D,
int  Hff 
)

Definition at line 154 of file mlp_fused_decode.c.

165 {
166 #if defined(__AVX512F__)
167  // Initialize output with bias or zero
168  if (b_down) {
169  memcpy(output, b_down, D * sizeof(float));
170  } else {
171  memset(output, 0, D * sizeof(float));
172  }
173 
174  // Process intermediate dimension in tiles
175  // Each tile computes MLP_TILE_SIZE swiglu values and immediately
176  // accumulates them into the output
177 
178  /* Bounds check for stack allocation */
179  if (D > 4096) return;
180 
181  #pragma omp parallel
182  {
183  /* Thread-local accumulator on stack (no malloc!) */
184  float local_output[4096] __attribute__((aligned(64)));
185  memset(local_output, 0, D * sizeof(float));
186 
187  #pragma omp for schedule(static)
188  for (int t = 0; t < Hff; t += MLP_TILE_SIZE) {
189  int tile_end = (t + MLP_TILE_SIZE < Hff) ? t + MLP_TILE_SIZE : Hff;
190  int tile_size = tile_end - t;
191 
192  // Compute SwiGLU for this tile (stays in L1 cache)
193  float swiglu_tile[MLP_TILE_SIZE] __attribute__((aligned(64)));
194 
195  for (int j = t; j < tile_end; j++) {
196  const float *wg_row = &W_gate[j * D];
197  const float *wu_row = &W_up[j * D];
198 
199  // Compute gate = x @ W_gate[j] using AVX-512
200  __m512 gate_acc = _mm512_setzero_ps();
201  __m512 up_acc = _mm512_setzero_ps();
202 
203  int k = 0;
204  for (; k <= D - 16; k += 16) {
205  __m512 x_vec = _mm512_loadu_ps(&x[k]);
206  __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
207  __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
208 
209  gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
210  up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
211  }
212 
213  float gate = hsum512_ps(gate_acc);
214  float up = hsum512_ps(up_acc);
215 
216  // Scalar remainder
217  for (; k < D; k++) {
218  gate += x[k] * wg_row[k];
219  up += x[k] * wu_row[k];
220  }
221 
222  // Add biases
223  if (b_gate) gate += b_gate[j];
224  if (b_up) up += b_up[j];
225 
226  // SwiGLU: SiLU(gate) * up
227  swiglu_tile[j - t] = silu_scalar(gate) * up;
228  }
229 
230  // Accumulate into output via W_down
231  // output[i] += sum_j(swiglu_tile[j] * W_down[i, t+j])
232  for (int i = 0; i < D; i++) {
233  const float *wd_row = &W_down[i * Hff + t];
234 
235  __m512 acc = _mm512_setzero_ps();
236  int j = 0;
237  for (; j <= tile_size - 16; j += 16) {
238  __m512 sw_vec = _mm512_loadu_ps(&swiglu_tile[j]);
239  __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
240  acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
241  }
242 
243  float sum = hsum512_ps(acc);
244  for (; j < tile_size; j++) {
245  sum += swiglu_tile[j] * wd_row[j];
246  }
247 
248  local_output[i] += sum;
249  }
250  }
251 
252  // Reduce thread-local outputs
253  #pragma omp critical
254  {
255  for (int i = 0; i < D; i++) {
256  output[i] += local_output[i];
257  }
258  }
259  /* No free - stack buffer auto-deallocates */
260  }
261 
262 #else
263  // Scalar fallback (same algorithm, no SIMD)
264  if (b_down) {
265  memcpy(output, b_down, D * sizeof(float));
266  } else {
267  memset(output, 0, D * sizeof(float));
268  }
269 
270  for (int t = 0; t < Hff; t += MLP_TILE_SIZE) {
271  int tile_end = (t + MLP_TILE_SIZE < Hff) ? t + MLP_TILE_SIZE : Hff;
272  int tile_size = tile_end - t;
273 
274  float swiglu_tile[MLP_TILE_SIZE];
275 
276  for (int j = t; j < tile_end; j++) {
277  float gate = 0.0f;
278  float up = 0.0f;
279 
280  for (int k = 0; k < D; k++) {
281  gate += x[k] * W_gate[j * D + k];
282  up += x[k] * W_up[j * D + k];
283  }
284 
285  if (b_gate) gate += b_gate[j];
286  if (b_up) up += b_up[j];
287 
288  swiglu_tile[j - t] = silu_scalar(gate) * up;
289  }
290 
291  for (int i = 0; i < D; i++) {
292  for (int j = 0; j < tile_size; j++) {
293  output[i] += swiglu_tile[j] * W_down[i * Hff + t + j];
294  }
295  }
296  }
297 #endif
298 }
#define MLP_TILE_SIZE
static float silu_scalar(float x)
__attribute__((visibility("default"))) CKTokenizer *ck_tokenizer_create(CKTokenizerType type)

References __attribute__(), MLP_TILE_SIZE, and silu_scalar().

◆ fused_mlp_swiglu_decode_tiled()

void fused_mlp_swiglu_decode_tiled ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  W_down,
const float *  b_gate,
const float *  b_up,
const float *  b_down,
float *  output,
int  D,
int  Hff 
)

Definition at line 429 of file mlp_fused_decode.c.

440 {
441  // Tile size chosen to fit in L2 with W_down tile
442  // Tile of swiglu: 256 floats = 1KB
443  // Tile of W_down: 256 * D floats = 256 * 896 * 4 = 896KB
444  // Fits in 2MB L2 with room for x and prefetch
445  const int TILE = 256;
446 
447 #if defined(__AVX512F__)
448  // Initialize output
449  #pragma omp parallel for schedule(static)
450  for (int i = 0; i < D; i++) {
451  output[i] = b_down ? b_down[i] : 0.0f;
452  }
453 
454  // Process tiles of intermediate dimension
455  for (int t = 0; t < Hff; t += TILE) {
456  int tile_end = (t + TILE < Hff) ? t + TILE : Hff;
457  int tile_size = tile_end - t;
458 
459  // Compute swiglu tile
460  float swiglu_tile[256] __attribute__((aligned(64)));
461 
462  #pragma omp parallel for schedule(static)
463  for (int jj = 0; jj < tile_size; jj++) {
464  int j = t + jj;
465  const float *wg_row = &W_gate[j * D];
466  const float *wu_row = &W_up[j * D];
467 
468  __m512 gate_acc = _mm512_setzero_ps();
469  __m512 up_acc = _mm512_setzero_ps();
470 
471  int k = 0;
472  for (; k <= D - 16; k += 16) {
473  __m512 x_vec = _mm512_loadu_ps(&x[k]);
474  __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
475  __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
476 
477  gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
478  up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
479  }
480 
481  float gate = hsum512_ps(gate_acc);
482  float up = hsum512_ps(up_acc);
483 
484  for (; k < D; k++) {
485  gate += x[k] * wg_row[k];
486  up += x[k] * wu_row[k];
487  }
488 
489  if (b_gate) gate += b_gate[j];
490  if (b_up) up += b_up[j];
491 
492  swiglu_tile[jj] = silu_scalar(gate) * up;
493  }
494 
495  // Accumulate into output (parallelize over D)
496  #pragma omp parallel for schedule(static)
497  for (int i = 0; i < D; i++) {
498  const float *wd_row = &W_down[i * Hff + t];
499 
500  __m512 acc = _mm512_setzero_ps();
501  int j = 0;
502  for (; j <= tile_size - 16; j += 16) {
503  __m512 sw_vec = _mm512_loadu_ps(&swiglu_tile[j]);
504  __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
505  acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
506  }
507 
508  float sum = hsum512_ps(acc);
509  for (; j < tile_size; j++) {
510  sum += swiglu_tile[j] * wd_row[j];
511  }
512 
513  // Atomic add (or use thread-local buffers for better perf)
514  #pragma omp atomic
515  output[i] += sum;
516  }
517  }
518 
519 #else
520  // Scalar fallback
521  for (int i = 0; i < D; i++) {
522  output[i] = b_down ? b_down[i] : 0.0f;
523  }
524 
525  for (int t = 0; t < Hff; t += TILE) {
526  int tile_end = (t + TILE < Hff) ? t + TILE : Hff;
527 
528  float swiglu_tile[256];
529 
530  for (int j = t; j < tile_end; j++) {
531  float gate = 0.0f, up = 0.0f;
532  for (int k = 0; k < D; k++) {
533  gate += x[k] * W_gate[j * D + k];
534  up += x[k] * W_up[j * D + k];
535  }
536  if (b_gate) gate += b_gate[j];
537  if (b_up) up += b_up[j];
538  swiglu_tile[j - t] = silu_scalar(gate) * up;
539  }
540 
541  for (int i = 0; i < D; i++) {
542  for (int j = t; j < tile_end; j++) {
543  output[i] += swiglu_tile[j - t] * W_down[i * Hff + j];
544  }
545  }
546  }
547 #endif
548 }

References __attribute__(), and silu_scalar().

Referenced by fused_mlp_swiglu_decode_v2().

◆ fused_mlp_swiglu_decode_v2()

void fused_mlp_swiglu_decode_v2 ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  W_down,
const float *  b_gate,
const float *  b_up,
const float *  b_down,
float *  output,
int  D,
int  Hff 
)

Definition at line 318 of file mlp_fused_decode.c.

329 {
330  // For large Hff, use tiled version to avoid stack overflow
331  if (Hff > MAX_SWIGLU_STACK) {
332  fused_mlp_swiglu_decode_tiled(x, W_gate, W_up, W_down,
333  b_gate, b_up, b_down, output, D, Hff);
334  return;
335  }
336 
337 #if defined(__AVX512F__)
338  // Stack-allocated swiglu buffer (max 32KB)
339  float swiglu[MAX_SWIGLU_STACK] __attribute__((aligned(64)));
340 
341  // Phase 1: Compute all swiglu values (parallelize over Hff)
342  #pragma omp parallel for schedule(static)
343  for (int j = 0; j < Hff; j++) {
344  const float *wg_row = &W_gate[j * D];
345  const float *wu_row = &W_up[j * D];
346 
347  __m512 gate_acc = _mm512_setzero_ps();
348  __m512 up_acc = _mm512_setzero_ps();
349 
350  int k = 0;
351  for (; k <= D - 16; k += 16) {
352  __m512 x_vec = _mm512_loadu_ps(&x[k]);
353  __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
354  __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
355 
356  gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
357  up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
358  }
359 
360  float gate = hsum512_ps(gate_acc);
361  float up = hsum512_ps(up_acc);
362 
363  for (; k < D; k++) {
364  gate += x[k] * wg_row[k];
365  up += x[k] * wu_row[k];
366  }
367 
368  if (b_gate) gate += b_gate[j];
369  if (b_up) up += b_up[j];
370 
371  swiglu[j] = silu_scalar(gate) * up;
372  }
373 
374  // Phase 2: Down projection (parallelize over D)
375  #pragma omp parallel for schedule(static)
376  for (int i = 0; i < D; i++) {
377  const float *wd_row = &W_down[i * Hff];
378 
379  __m512 acc = _mm512_setzero_ps();
380  int j = 0;
381  for (; j <= Hff - 16; j += 16) {
382  __m512 sw_vec = _mm512_loadu_ps(&swiglu[j]);
383  __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
384  acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
385  }
386 
387  float sum = hsum512_ps(acc);
388  for (; j < Hff; j++) {
389  sum += swiglu[j] * wd_row[j];
390  }
391 
392  output[i] = sum + (b_down ? b_down[i] : 0.0f);
393  }
394 
395 #else
396  // Scalar fallback with stack buffer
397  float swiglu[MAX_SWIGLU_STACK];
398 
399  for (int j = 0; j < Hff; j++) {
400  float gate = 0.0f, up = 0.0f;
401  for (int k = 0; k < D; k++) {
402  gate += x[k] * W_gate[j * D + k];
403  up += x[k] * W_up[j * D + k];
404  }
405  if (b_gate) gate += b_gate[j];
406  if (b_up) up += b_up[j];
407  swiglu[j] = silu_scalar(gate) * up;
408  }
409 
410  for (int i = 0; i < D; i++) {
411  float sum = 0.0f;
412  for (int j = 0; j < Hff; j++) {
413  sum += swiglu[j] * W_down[i * Hff + j];
414  }
415  output[i] = sum + (b_down ? b_down[i] : 0.0f);
416  }
417 #endif
418 }
void fused_mlp_swiglu_decode_tiled(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
#define MAX_SWIGLU_STACK

References __attribute__(), fused_mlp_swiglu_decode_tiled(), MAX_SWIGLU_STACK, and silu_scalar().

Referenced by ck_mlp_swiglu_forward_fully_fused_token().

◆ silu_scalar()

static float silu_scalar ( float  x)
inline, static

Definition at line 134 of file mlp_fused_decode.c.

134  {
135  return x / (1.0f + expf(-x));
136 }

Referenced by fused_mlp_swiglu_decode(), fused_mlp_swiglu_decode_tiled(), and fused_mlp_swiglu_decode_v2().