Fused kernels for prefill phase with proper 2D tiling. More...
#include "ckernel_engine.h"
#include "ckernel_quant.h"
#include <math.h>
#include <string.h>
#include <stddef.h>
#include <stdio.h>
Go to the source code of this file.
Macros | |
| #define | PREFILL_TILE_M 64 |
| #define | PREFILL_TILE_N 256 |
Functions | |
| static void | add_bias_tile (float *out, const float *bias, int tile_m, int out_dim) |
| static size_t | align_up_size (size_t value, size_t align) |
| void | fused_mlp_swiglu_prefill (const float *x, const float *W_gate, const float *W_up, const float *W_down, float *output, int seq_len, int hidden, int intermediate, float *scratch) |
| Fused MLP (Gate + Up + SwiGLU + Down) for prefill. More... | |
| void | fused_mlp_swiglu_prefill_bias (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *B_gate, const float *B_up, const float *B_down, float *output, int seq_len, int hidden, int intermediate, float *scratch) |
| Fused MLP for prefill with proper tiling. More... | |
| void | fused_mlp_swiglu_prefill_w1w2_quant (const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch) |
| Quantized fused MLP for prefill (W1=gate+up, W2=down) More... | |
| size_t | fused_mlp_swiglu_prefill_w1w2_quant_scratch_size (int aligned_embed_dim, int aligned_intermediate_dim) |
| Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant. More... | |
| size_t | fused_mlp_swiglu_scratch_size (int intermediate) |
| Get scratch size for fused MLP. More... | |
| static void | fused_rmsnorm_gemm_2d_tiled (const float *x, const float *gamma, const float *W, float *output, int seq_len, int hidden, int out_dim, float eps, float *x_norm_scratch) |
| Fused RMSNorm + single GEMM with 2D tiling (weight reuse) More... | |
| void | fused_rmsnorm_qkv_prefill (const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps, float *scratch) |
| Fused RMSNorm + QKV projection for prefill (v3 optimized) More... | |
| void | fused_rmsnorm_qkv_prefill_head_major (const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch) |
| Fused RMSNorm + QKV projection for prefill (head-major outputs) More... | |
| void | fused_rmsnorm_qkv_prefill_head_major_quant (const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch) |
| Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations) More... | |
| size_t | fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size (int aligned_embed_dim) |
| Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant. More... | |
| size_t | fused_rmsnorm_qkv_scratch_size (int hidden) |
| Get scratch size for fused prefill. More... | |
| static void | gemm_nt_q8_0_dispatch (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt) |
| static void | gemm_nt_q8_0_mlp_dispatch (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt) |
| static void | gemm_nt_q8_k_mlp_dispatch (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt) |
| static void | gemm_nt_q8_k_qkv_dispatch (const void *A_q8k, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt) |
| static void | gemm_tile_nt_strided (const float *A, const float *B_tile, float *C, int tile_m, int tile_n, int K, int C_stride) |
| GEMM tile with N-dimension tiling (weight reuse) More... | |
| static int | mlp_q8_0_dtype_supported (CKDataType dt) |
| static int | mlp_q8_k_dtype_supported (CKDataType dt) |
| static int | qkv_q8_0_dtype_supported (CKDataType dt) |
| static int | qkv_q8_k_dtype_supported (CKDataType dt) |
| static void | rmsnorm_tile (const float *input, const float *gamma, float *output, int tile_m, int embed_dim, int aligned_embed_dim, float eps) |
| Compute RMSNorm for a tile of tokens. More... | |
| static float | silu_prefill (float x) |
| void | unfused_rmsnorm_qkv_prefill (const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *x_norm, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps) |
| Unfused version for comparison. More... | |
Fused kernels for prefill phase with proper 2D tiling.
After changes: make test && make llamacpp-parity-full
Naive M-dimension tiling (token tiles) causes weight reloading: each token tile re-streams the full weight matrix through cache.
Correct approach: Tile along N (output/weight) dimension OUTER, M (token) dimension INNER. This way:
For C[M,N] = RMSNorm(A[M,K]) × B[N,K]^T:
for n_tile in [0, N, TILE_N]:                # Outer: weight tiles
    load B[n_tile:n_tile+TILE_N, :] into L3
    for m_tile in [0, M, TILE_M]:            # Inner: token tiles
        x_norm = rmsnorm(A[m_tile])              # x_norm in L2
        C[m_tile, n_tile] = x_norm × B_tile      # Consumes B from L3
Cache behavior:
Definition in file prefill_fused_gemm.c.
| #define PREFILL_TILE_M 64 |
Definition at line 64 of file prefill_fused_gemm.c.
| #define PREFILL_TILE_N 256 |
Definition at line 65 of file prefill_fused_gemm.c.
|
static |
Definition at line 310 of file prefill_fused_gemm.c.
Referenced by fused_mlp_swiglu_prefill_bias(), and fused_rmsnorm_qkv_prefill_head_major().
|
static |
Definition at line 67 of file prefill_fused_gemm.c.
Referenced by fused_mlp_swiglu_prefill_w1w2_quant(), fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(), fused_rmsnorm_qkv_prefill_head_major_quant(), and fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size().
| void fused_mlp_swiglu_prefill | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | W_down, | ||
| float * | output, | ||
| int | seq_len, | ||
| int | hidden, | ||
| int | intermediate, | ||
| float * | scratch | ||
| ) |
Fused MLP (Gate + Up + SwiGLU + Down) for prefill.
Tiles along token dimension to keep gate/up/hidden in L3 cache.
| scratch | Temporary buffer from fused_mlp_swiglu_scratch_size() |
Definition at line 879 of file prefill_fused_gemm.c.
References fused_mlp_swiglu_prefill_bias().
| void fused_mlp_swiglu_prefill_bias | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | W_down, | ||
| const float * | B_gate, | ||
| const float * | B_up, | ||
| const float * | B_down, | ||
| float * | output, | ||
| int | seq_len, | ||
| int | hidden, | ||
| int | intermediate, | ||
| float * | scratch | ||
| ) |
Fused MLP for prefill with proper tiling.
Fused MLP (Gate + Up + SwiGLU + Down) for prefill with biases.
Definition at line 746 of file prefill_fused_gemm.c.
References add_bias_tile(), gemm_tile_nt_strided(), PREFILL_TILE_M, and silu().
Referenced by fused_mlp_swiglu_prefill().
| void fused_mlp_swiglu_prefill_w1w2_quant | ( | const float * | x, |
| const void * | W1, | ||
| const float * | B1, | ||
| CKDataType | w1_dt, | ||
| const void * | W2, | ||
| const float * | B2, | ||
| CKDataType | w2_dt, | ||
| float * | output, | ||
| int | seq_len, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | intermediate_dim, | ||
| int | aligned_intermediate_dim, | ||
| void * | scratch | ||
| ) |
Quantized fused MLP for prefill (W1=gate+up, W2=down)
Uses Q8_0 activations for W1 (Q5_0/Q8_0 weights) and Q8_K activations for W2 (Q4_K/Q6_K weights).
Definition at line 965 of file prefill_fused_gemm.c.
References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), gemm_nt_q8_0_mlp_dispatch(), gemm_nt_q8_k_mlp_dispatch(), mlp_q8_0_dtype_supported(), mlp_q8_k_dtype_supported(), PREFILL_TILE_M, quantize_row_q8_0(), quantize_row_q8_k(), and silu_prefill().
Referenced by mega_fused_outproj_mlp_prefill().
| size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size | ( | int | aligned_embed_dim, |
| int | aligned_intermediate_dim | ||
| ) |
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.
Definition at line 1063 of file prefill_fused_gemm.c.
References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), and PREFILL_TILE_M.
Referenced by mega_fused_outproj_mlp_prefill_scratch_size().
| size_t fused_mlp_swiglu_scratch_size | ( | int | intermediate | ) |
Get scratch size for fused MLP.
Get scratch buffer size for fused_mlp_swiglu_prefill.
Definition at line 899 of file prefill_fused_gemm.c.
References PREFILL_TILE_M.
|
static |
Fused RMSNorm + single GEMM with 2D tiling (weight reuse)
Tiles along N (weights) OUTER, M (tokens) INNER. Weight tiles are reused across all token tiles.
Definition at line 332 of file prefill_fused_gemm.c.
References gemm_tile_nt_strided(), PREFILL_TILE_M, PREFILL_TILE_N, and rmsnorm_tile().
| void fused_rmsnorm_qkv_prefill | ( | const float * | x, |
| const float * | gamma, | ||
| const float * | Wq, | ||
| const float * | Wk, | ||
| const float * | Wv, | ||
| float * | Q, | ||
| float * | K, | ||
| float * | V, | ||
| int | seq_len, | ||
| int | hidden, | ||
| int | q_dim, | ||
| int | kv_dim, | ||
| float | eps, | ||
| float * | scratch | ||
| ) |
Fused RMSNorm + QKV projection for prefill (v3 optimized)
Fused RMSNorm + QKV projection for prefill.
KEY INSIGHT: For Qwen2-0.5B, all QKV weights fit in L3: Wq (896×896) + Wk (128×896) + Wv (128×896) = 4.1MB < 6MB L3
So we use M-tiling (tokens) only:
This avoids both:
Definition at line 393 of file prefill_fused_gemm.c.
References gemm_tile_nt_strided(), PREFILL_TILE_M, and rmsnorm_tile().
| void fused_rmsnorm_qkv_prefill_head_major | ( | const float * | x, |
| const float * | gamma, | ||
| const float * | Wq, | ||
| const float * | Bq, | ||
| const float * | Wk, | ||
| const float * | Bk, | ||
| const float * | Wv, | ||
| const float * | Bv, | ||
| float * | Q, | ||
| float * | K, | ||
| float * | V, | ||
| int | seq_len, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | kv_stride_tokens, | ||
| float | eps, | ||
| float * | scratch | ||
| ) |
Fused RMSNorm + QKV projection for prefill (head-major outputs)
Q is written as [num_heads, seq_len, aligned_head_dim]. K/V are written with kv_stride_tokens for KV-cache compatibility.
Definition at line 441 of file prefill_fused_gemm.c.
References add_bias_tile(), gemm_tile_nt_strided(), PREFILL_TILE_M, and rmsnorm_tile().
Referenced by mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().
| void fused_rmsnorm_qkv_prefill_head_major_quant | ( | const float * | x, |
| const float * | gamma, | ||
| const void * | Wq, | ||
| const float * | Bq, | ||
| CKDataType | wq_dt, | ||
| const void * | Wk, | ||
| const float * | Bk, | ||
| CKDataType | wk_dt, | ||
| const void * | Wv, | ||
| const float * | Bv, | ||
| CKDataType | wv_dt, | ||
| float * | Q, | ||
| float * | K, | ||
| float * | V, | ||
| int | seq_len, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | kv_stride_tokens, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
Supports Q5_0 or Q8_0 weights with Q8_0 activations. Writes K/V directly into KV cache layout (kv_stride_tokens).
Definition at line 519 of file prefill_fused_gemm.c.
References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), gemm_nt_q8_0_dispatch(), gemm_nt_q8_k_qkv_dispatch(), PREFILL_TILE_M, qkv_q8_0_dtype_supported(), qkv_q8_k_dtype_supported(), quantize_row_q8_0(), quantize_row_q8_k(), and rmsnorm_tile().
Referenced by mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().
| size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size | ( | int | aligned_embed_dim | ) |
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
Definition at line 651 of file prefill_fused_gemm.c.
References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), and PREFILL_TILE_M.
Referenced by mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), mega_fused_attention_prefill_q8_0_scratch_size(), and mega_fused_attention_prefill_scratch_size().
| size_t fused_rmsnorm_qkv_scratch_size | ( | int | hidden | ) |
Get scratch size for fused prefill.
Get scratch buffer size for fused_rmsnorm_qkv_prefill.
Definition at line 739 of file prefill_fused_gemm.c.
References PREFILL_TILE_M.
|
static |
Definition at line 177 of file prefill_fused_gemm.c.
References C, CK_DT_Q5_0, CK_DT_Q8_0, gemm_nt_q5_0_q8_0(), and gemm_nt_q8_0_q8_0().
Referenced by fused_rmsnorm_qkv_prefill_head_major_quant().
|
static |
Definition at line 917 of file prefill_fused_gemm.c.
References C, CK_DT_Q5_0, CK_DT_Q8_0, gemm_nt_q5_0_q8_0(), and gemm_nt_q8_0_q8_0().
Referenced by fused_mlp_swiglu_prefill_w1w2_quant().
|
static |
Definition at line 938 of file prefill_fused_gemm.c.
References C, CK_DT_Q4_K, CK_DT_Q6_K, gemm_nt_q4_k_q8_k(), and gemm_nt_q6_k_q8_k().
Referenced by fused_mlp_swiglu_prefill_w1w2_quant().
|
static |
Definition at line 198 of file prefill_fused_gemm.c.
References C, CK_DT_Q4_K, CK_DT_Q6_K, gemm_nt_q4_k_q8_k(), and gemm_nt_q6_k_q8_k().
Referenced by fused_rmsnorm_qkv_prefill_head_major_quant().
|
static |
GEMM tile with N-dimension tiling (weight reuse)
Computes: C[tile_m × tile_n] = A[tile_m × K] × B[tile_n × K]^T where B_tile is a slice of rows from the weight matrix.
Uses MKL if available for optimal performance.
| A | Input tile [tile_m × K] |
| B_tile | Weight tile [tile_n × K] (transposed layout) |
| C | Output tile [tile_m × tile_n] (column slice of full output) |
| C_stride | Stride between rows of C (= full N dimension) |
Definition at line 236 of file prefill_fused_gemm.c.
References C.
Referenced by fused_mlp_swiglu_prefill_bias(), fused_rmsnorm_gemm_2d_tiled(), fused_rmsnorm_qkv_prefill(), fused_rmsnorm_qkv_prefill_head_major(), and unfused_rmsnorm_qkv_prefill().
|
static |
Definition at line 909 of file prefill_fused_gemm.c.
References CK_DT_Q5_0, and CK_DT_Q8_0.
Referenced by fused_mlp_swiglu_prefill_w1w2_quant().
|
static |
Definition at line 913 of file prefill_fused_gemm.c.
References CK_DT_Q4_K, and CK_DT_Q6_K.
Referenced by fused_mlp_swiglu_prefill_w1w2_quant().
|
static |
Definition at line 169 of file prefill_fused_gemm.c.
References CK_DT_Q5_0, and CK_DT_Q8_0.
Referenced by fused_rmsnorm_qkv_prefill_head_major_quant().
|
static |
Definition at line 173 of file prefill_fused_gemm.c.
References CK_DT_Q4_K, and CK_DT_Q6_K.
Referenced by fused_rmsnorm_qkv_prefill_head_major_quant().
|
static |
Compute RMSNorm for a tile of tokens.
Definition at line 86 of file prefill_fused_gemm.c.
Referenced by fused_rmsnorm_gemm_2d_tiled(), fused_rmsnorm_qkv_prefill(), fused_rmsnorm_qkv_prefill_head_major(), fused_rmsnorm_qkv_prefill_head_major_quant(), and unfused_rmsnorm_qkv_prefill().
|
inlinestatic |
Definition at line 905 of file prefill_fused_gemm.c.
Referenced by fused_mlp_swiglu_prefill_w1w2_quant().
| void unfused_rmsnorm_qkv_prefill | ( | const float * | x, |
| const float * | gamma, | ||
| const float * | Wq, | ||
| const float * | Wk, | ||
| const float * | Wv, | ||
| float * | x_norm, | ||
| float * | Q, | ||
| float * | K, | ||
| float * | V, | ||
| int | seq_len, | ||
| int | hidden, | ||
| int | q_dim, | ||
| int | kv_dim, | ||
| float | eps | ||
| ) |
Unfused version for comparison.
Unfused version for benchmarking comparison.
Definition at line 667 of file prefill_fused_gemm.c.
References gemm_tile_nt_strided(), PREFILL_TILE_M, PREFILL_TILE_N, and rmsnorm_tile().