Fused RMSNorm + Linear (GEMV) kernel.
#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include "ckernel_quant.h"
Functions
| Return type | Function | Description |
|---|---|---|
| static int | ck_nearest_int_fused (float fval) | |
| void | fused_rmsnorm_linear_q4k (float *y, const float *x, const float *gamma, const void *W_q4k, int M, int K, float eps) | Fused RMSNorm + Q4_K Linear projection. |
| void | gemv_q4_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K) | |
| void | unfused_rmsnorm_linear_q4k_ref (float *y, const float *x, const float *gamma, const void *W_q4k, int M, int K, float eps) | Reference (unfused) implementation for correctness testing. |
After making changes, run: make test && make llamacpp-parity-full
VIOLATION: The test/benchmark code at the end of this file calls free() and memcpy(). TODO: Move the test code to unittest/ and remove the free()/memcpy() calls from the kernel file.
Unfused: RMSNorm(x) → [DRAM write: norm_out] → Quantize → [DRAM write: q8] → GEMV
Total DRAM: 2 writes + 2 reads = 4 * hidden_size bytes
Fused: RMSNorm(x) → [registers] → Quantize → [stack/L1: q8] → GEMV
Total DRAM: 0 intermediate writes/reads
Expected: 2-4x reduction in memory traffic for this operation
Definition in file fused_rmsnorm_linear.c.
inline static
void fused_rmsnorm_linear_q4k (float *y, const float *x, const float *gamma, const void *W_q4k, int M, int K, float eps)
Fused RMSNorm + Q4_K Linear projection.
Computes: y = Linear(RMSNorm(x)) where Linear uses Q4_K weights and Q8_K activations internally.
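Written out, assuming the standard RMSNorm definition with eps added inside the square root (the convention llama.cpp uses). In the kernel both W and the normalized activations are quantized (Q4_K and Q8_K), so the computed y_m approximates this sum:

$$
\mathrm{RMSNorm}(x)_k = \frac{\gamma_k \, x_k}{\sqrt{\frac{1}{K}\sum_{j=0}^{K-1} x_j^2 + \varepsilon}},
\qquad
y_m = \sum_{k=0}^{K-1} W_{m,k}\,\mathrm{RMSNorm}(x)_k, \quad m = 0,\dots,M-1
$$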
The key optimization is that the normalized values never touch DRAM: they go directly from the RMSNorm computation into Q8_K quantization and then into the GEMV.
| y | Output (FP32), shape [M] |
| x | Input hidden state (FP32), shape [K] |
| gamma | RMSNorm scale weights (FP32), shape [K] |
| W_q4k | Linear weights in Q4_K format, shape [M, K] |
| M | Output dimension (e.g., 3 * hidden for QKV) |
| K | Input dimension (hidden_size) |
| eps | RMSNorm epsilon (typically 1e-5 or 1e-6) |
Definition at line 83 of file fused_rmsnorm_linear.c.
References block_q8_K::bsums, ck_nearest_int_fused(), block_q8_K::d, gemv_q4_k_q8_k(), hsum256_ps_fused(), QK_K, and block_q8_K::qs.
void gemv_q4_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K)
Definition at line 239 of file gemm_kernels_q4k_q8k.c.
Referenced by fused_rmsnorm_linear_q4k(), gemm_q4_k_q8_k(), and unfused_rmsnorm_linear_q4k_ref().
void unfused_rmsnorm_linear_q4k_ref (float *y, const float *x, const float *gamma, const void *W_q4k, int M, int K, float eps)
Reference (unfused) implementation for correctness testing.
This is the SLOW version: it runs RMSNorm, Q8_K quantization, and the GEMV as separate steps, with intermediate results written to and read back from DRAM.
Definition at line 266 of file fused_rmsnorm_linear.c.
References gemv_q4_k_q8_k(), QK_K, and quantize_row_q8_k().