Fully fused MLP decode kernel (T=1 token generation) More...
Go to the source code of this file.
Macros | |
| #define | MAX_SWIGLU_STACK 8192 |
| #define | MLP_TILE_SIZE 64 |
| #define | OUTPUT_TILE_SIZE 32 |
Functions | |
| void | fused_mlp_swiglu_decode (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff) |
| void | fused_mlp_swiglu_decode_tiled (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff) |
| void | fused_mlp_swiglu_decode_v2 (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff) |
| static float | silu_scalar (float x) |
Fully fused MLP decode kernel (T=1 token generation)
After changes: make test && make llamacpp-parity-full
LEGACY: This file is from v6/v6.5 and kept for backward compatibility.
This kernel fuses the ENTIRE MLP block into a single pass: output = Down(SwiGLU(Gate(x), Up(x))) + residual
Key optimization: The intermediate SwiGLU values (~4864 floats = 19KB for Qwen2) NEVER touch DRAM. They stay in L1/L2 cache through tiling.
Target: Intel Xeon 5th Gen (Emerald Rapids) with AVX-512 and AMX
Memory traffic comparison (Qwen2-0.5B, D=896, Hff=4864): Unfused: 76 KB activation traffic (38KB write + 38KB read) Fused: 0 KB activation traffic (tiles stay in L1)
Weight layout expected: Row-major, transposed for matvec W_gate[Hff, D], W_up[Hff, D], W_down[D, Hff]
Definition in file mlp_fused_decode.c.
| #define MAX_SWIGLU_STACK 8192 |
Definition at line 316 of file mlp_fused_decode.c.
| #define MLP_TILE_SIZE 64 |
Definition at line 52 of file mlp_fused_decode.c.
| #define OUTPUT_TILE_SIZE 32 |
Definition at line 55 of file mlp_fused_decode.c.
| void fused_mlp_swiglu_decode | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | W_down, | ||
| const float * | b_gate, | ||
| const float * | b_up, | ||
| const float * | b_down, | ||
| float * | output, | ||
| int | D, | ||
| int | Hff | ||
| ) |
Definition at line 154 of file mlp_fused_decode.c.
References __attribute__(), MLP_TILE_SIZE, and silu_scalar().
| void fused_mlp_swiglu_decode_tiled | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | W_down, | ||
| const float * | b_gate, | ||
| const float * | b_up, | ||
| const float * | b_down, | ||
| float * | output, | ||
| int | D, | ||
| int | Hff | ||
| ) |
Definition at line 429 of file mlp_fused_decode.c.
References __attribute__(), and silu_scalar().
Referenced by fused_mlp_swiglu_decode_v2().
| void fused_mlp_swiglu_decode_v2 | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | W_down, | ||
| const float * | b_gate, | ||
| const float * | b_up, | ||
| const float * | b_down, | ||
| float * | output, | ||
| int | D, | ||
| int | Hff | ||
| ) |
Definition at line 318 of file mlp_fused_decode.c.
References __attribute__(), fused_mlp_swiglu_decode_tiled(), MAX_SWIGLU_STACK, and silu_scalar().
Referenced by ck_mlp_swiglu_forward_fully_fused_token().
| static float silu_scalar | ( | float | x | ) |
| inline static |
Definition at line 134 of file mlp_fused_decode.c.
Referenced by fused_mlp_swiglu_decode(), fused_mlp_swiglu_decode_tiled(), and fused_mlp_swiglu_decode_v2().