← Back to C-Kernel-Engine Docs Doxygen Source Documentation
fused_rmsnorm_linear.c File Reference

Fused RMSNorm + Linear (GEMV) kernel. More...

#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include "ckernel_quant.h"

Go to the source code of this file.

Functions

static int ck_nearest_int_fused (float fval)
 
void fused_rmsnorm_linear_q4k (float *y, const float *x, const float *gamma, const void *W_q4k, int M, int K, float eps)
 Fused RMSNorm + Q4_K Linear projection. More...
 
void gemv_q4_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K)
 
void unfused_rmsnorm_linear_q4k_ref (float *y, const float *x, const float *gamma, const void *W_q4k, int M, int K, float eps)
 Reference (unfused) implementation for correctness testing. More...
 

Detailed Description

Fused RMSNorm + Linear (GEMV) kernel.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. NO memcpy for layout - use strided access, not copies
  4. API must define: inputs, outputs, workspace, and memory layouts
  5. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

VIOLATION: This file contains free() calls and memcpy usage in the test/benchmark code at the end of the file. TODO: Move the test code to unittest/ and remove free()/memcpy from this kernel file.

FUSION BENEFIT:

Unfused: RMSNorm(x) → [DRAM write: norm_out] → Quantize → [DRAM write: q8] → GEMV Total DRAM: 2 writes + 2 reads = 4 * hidden_size bytes

Fused: RMSNorm(x) → [registers] → Quantize → [stack/L1: q8] → GEMV Total DRAM: 0 intermediate writes/reads

Expected: 2-4x memory traffic reduction for this operation

Definition in file fused_rmsnorm_linear.c.

Function Documentation

◆ ck_nearest_int_fused()

static int ck_nearest_int_fused ( float  fval)
inlinestatic

Definition at line 48 of file fused_rmsnorm_linear.c.

/*
 * Fast round-to-nearest(-even) float -> int conversion.
 *
 * Adding 1.5 * 2^23 (12582912.0f) forces the rounded integer into the low
 * mantissa bits of the float; masking and re-biasing recovers it. Valid only
 * for inputs well inside +/- 2^22, which holds for Q8_K quantization values
 * (iscale * norm is bounded by ~127).
 *
 * Uses union type punning (legal in C, C11 6.5.2.3) instead of memcpy so the
 * kernel file contains no memcpy calls, per CK-ENGINE kernel rule 3.
 *
 * @param fval  Value to round (|fval| must be << 2^22).
 * @return      fval rounded to the nearest integer (ties to even).
 */
static inline int ck_nearest_int_fused(float fval) {
    union { float f; int32_t i; } pun;
    pun.f = fval + 12582912.f; /* 1.5 * 2^23: pins the float ulp to exactly 1.0 */
    return (pun.i & 0x007fffff) - 0x00400000;
}

Referenced by fused_rmsnorm_linear_q4k().

◆ fused_rmsnorm_linear_q4k()

void fused_rmsnorm_linear_q4k ( float *  y,
const float *  x,
const float *  gamma,
const void *  W_q4k,
int  M,
int  K,
float  eps 
)

Fused RMSNorm + Q4_K Linear projection.

Computes: y = Linear(RMSNorm(x)) where Linear uses Q4_K weights and Q8_K activations internally.

The key optimization is that the normalized values never touch DRAM - they go directly from RMSNorm computation to Q8_K quantization to GEMV.

Parameters
yOutput (FP32), shape [M]
xInput hidden state (FP32), shape [K]
gammaRMSNorm scale weights (FP32), shape [K]
W_q4kLinear weights in Q4_K format, shape [M, K]
MOutput dimension (e.g., 3 * hidden for QKV)
KInput dimension (hidden_size)
epsRMSNorm epsilon (typically 1e-5 or 1e-6)

Definition at line 83 of file fused_rmsnorm_linear.c.

89 {
/* Guard: reject NULL pointers and non-positive dims; y is left untouched. */
90  if (!y || !x || !gamma || !W_q4k || M <= 0 || K <= 0) {
91  return;
92  }
93 
94  assert(K % QK_K == 0);
95  const int nb = K / QK_K; /* Number of Q8_K blocks */
96 
97  /* Stack-allocated Q8_K buffer - stays in L1/L2 cache */
98  /* Max supported K = 8192 (32 blocks of 256) */
99  block_q8_K q8_buffer[32]; /* 32 * ~260 bytes = ~8KB on stack */
/* NOTE(review): in release builds (NDEBUG) this assert compiles out and
 * nb > 32 would overflow the stack buffer — verify callers bound K. */
100  assert(nb <= 32 && "K too large for stack buffer");
101 
102  /* ================================================================
103  * PHASE 1: Compute RMSNorm and quantize to Q8_K
104  * Result stays in stack (L1/L2), never touches DRAM
105  * ================================================================ */
106 
107 #if defined(__AVX512F__)
108  /* AVX-512: Compute sum of squares */
109  __m512 sum_sq_vec = _mm512_setzero_ps();
110  int d = 0;
111  for (; d + 16 <= K; d += 16) {
112  __m512 xv = _mm512_loadu_ps(&x[d]);
113  sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
114  }
115  float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
/* Scalar tail for K not a multiple of 16 (unreachable when K % QK_K == 0,
 * since QK_K is a multiple of 16, but kept for safety). */
116  for (; d < K; ++d) {
117  sum_sq += x[d] * x[d];
118  }
119 
120 #elif defined(__AVX__)
121  /* AVX: Compute sum of squares */
122  __m256 sum_sq_vec = _mm256_setzero_ps();
123  int d = 0;
124  for (; d + 8 <= K; d += 8) {
125  __m256 xv = _mm256_loadu_ps(&x[d]);
126  __m256 xv_sq = _mm256_mul_ps(xv, xv);
127  sum_sq_vec = _mm256_add_ps(sum_sq_vec, xv_sq);
128  }
129  float sum_sq = hsum256_ps_fused(sum_sq_vec);
130  for (; d < K; ++d) {
131  sum_sq += x[d] * x[d];
132  }
133 
134 #else
/* Scalar fallback accumulates in double; the SIMD paths accumulate in
 * float, so results may differ slightly across build configurations. */
135  /* Scalar fallback */
136  double sum_sq = 0.0;
137  for (int d = 0; d < K; ++d) {
138  double v = (double)x[d];
139  sum_sq += v * v;
140  }
141 #endif
142 
/* rstd = 1 / sqrt(mean(x^2) + eps): the RMSNorm reciprocal std factor. */
143  float mean_sq = (float)sum_sq / (float)K;
144  float rstd = 1.0f / sqrtf(mean_sq + eps);
145 
146  /* ================================================================
147  * PHASE 2: Apply RMSNorm and quantize to Q8_K in one pass
148  * Normalized values go directly to Q8_K blocks
149  * ================================================================ */
150 
151  for (int i = 0; i < nb; ++i) {
152  const float *x_block = x + i * QK_K;
153  const float *g_block = gamma + i * QK_K;
154 
155  /* Find max absolute value for this block's normalized output */
/* max_val keeps the SIGN of the extreme element (needed for the -127
 * scale convention below); amax is its absolute value. */
156  float max_val = 0.0f;
157  float amax = 0.0f;
158 
159 #if defined(__AVX512F__)
160  __m512 rstd_vec = _mm512_set1_ps(rstd);
161  __m512 max_vec = _mm512_setzero_ps();
162  __m512 sign_mask = _mm512_set1_ps(-0.0f);
163 
164  for (int j = 0; j < QK_K; j += 16) {
165  __m512 xv = _mm512_loadu_ps(&x_block[j]);
166  __m512 gv = _mm512_loadu_ps(&g_block[j]);
167  __m512 norm = _mm512_mul_ps(_mm512_mul_ps(xv, rstd_vec), gv);
/* andnot with -0.0f clears the sign bit: abs_norm = |norm| */
168  __m512 abs_norm = _mm512_andnot_ps(sign_mask, norm);
169  max_vec = _mm512_max_ps(max_vec, abs_norm);
170 
171  /* Track max with sign for scale computation */
/* NOTE(review): this signed-max recovery is convoluted — it rescans the
 * current 16-lane chunk scalar-side with a 1e-6 tolerance, which could
 * pick a near-tie element whose sign differs from the true max; that
 * would flip the sign of q8_buffer[i].d. Consider the simpler scalar
 * tracking used in the AVX/scalar paths — verify against parity tests. */
172  __mmask16 gt_mask = _mm512_cmp_ps_mask(abs_norm, _mm512_set1_ps(amax), _CMP_GT_OQ);
173  if (gt_mask) {
174  float temp_amax = _mm512_reduce_max_ps(abs_norm);
175  if (temp_amax > amax) {
176  amax = temp_amax;
177  /* Find the actual max value with sign */
178  for (int k = 0; k < 16; ++k) {
179  float v = x_block[j + k] * rstd * g_block[j + k];
180  if (fabsf(v) >= amax - 1e-6f) {
181  max_val = v;
182  break;
183  }
184  }
185  }
186  }
187  }
/* Final amax from the vector accumulator (same value the loop tracked). */
188  amax = _mm512_reduce_max_ps(max_vec);
189 
190 #elif defined(__AVX__)
191  __m256 rstd_vec = _mm256_set1_ps(rstd);
192 
193  for (int j = 0; j < QK_K; j += 8) {
194  __m256 xv = _mm256_loadu_ps(&x_block[j]);
195  __m256 gv = _mm256_loadu_ps(&g_block[j]);
196  __m256 norm = _mm256_mul_ps(_mm256_mul_ps(xv, rstd_vec), gv);
197 
198  /* Check each element for max */
199  float norm_arr[8];
200  _mm256_storeu_ps(norm_arr, norm);
201  for (int k = 0; k < 8; ++k) {
202  float av = fabsf(norm_arr[k]);
203  if (av > amax) {
204  amax = av;
205  max_val = norm_arr[k];
206  }
207  }
208  }
209 
210 #else
211  for (int j = 0; j < QK_K; ++j) {
212  float norm = x_block[j] * rstd * g_block[j];
213  float av = fabsf(norm);
214  if (av > amax) {
215  amax = av;
216  max_val = norm;
217  }
218  }
219 #endif
220 
221  /* Handle zero block */
222  if (amax < 1e-10f) {
223  q8_buffer[i].d = 0.0f;
224  memset(q8_buffer[i].qs, 0, sizeof(q8_buffer[i].qs));
225  memset(q8_buffer[i].bsums, 0, sizeof(q8_buffer[i].bsums));
226  continue;
227  }
228 
229  /* Compute scale and quantize */
/* -127/max_val maps the signed extreme to -127; d = 1/iscale is the
 * dequantization scale — presumably matching the llama.cpp Q8_K
 * convention used by gemv_q4_k_q8_k; confirm against parity tests. */
230  const float iscale = -127.0f / max_val;
231  q8_buffer[i].d = 1.0f / iscale;
232 
233  /* Quantize and compute bsums */
234  for (int j = 0; j < QK_K; ++j) {
235  float norm = x_block[j] * rstd * g_block[j];
236  int v = ck_nearest_int_fused(iscale * norm);
/* Clamp to int8 range before narrowing. */
237  v = (v > 127) ? 127 : ((v < -128) ? -128 : v);
238  q8_buffer[i].qs[j] = (int8_t)v;
239  }
240 
241  /* Compute block sums (16 elements each) */
242  for (int j = 0; j < QK_K / 16; ++j) {
243  int sum = 0;
244  const int8_t *qs = &q8_buffer[i].qs[j * 16];
245  for (int k = 0; k < 16; ++k) {
246  sum += qs[k];
247  }
248  q8_buffer[i].bsums[j] = (int16_t)sum;
249  }
250  }
251 
252  /* ================================================================
253  * PHASE 3: GEMV with Q4_K weights and Q8_K activations
254  * Q8_K data is in stack (L1/L2), not DRAM
255  * ================================================================ */
256 
257  gemv_q4_k_q8_k(y, W_q4k, q8_buffer, M, K);
258 }
#define QK_K
static int ck_nearest_int_fused(float fval)
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
static float hsum256_ps_fused(__m256 v)
int8_t qs[256]
int16_t bsums[256/16]

References block_q8_K::bsums, ck_nearest_int_fused(), block_q8_K::d, gemv_q4_k_q8_k(), hsum256_ps_fused(), QK_K, and block_q8_K::qs.

◆ gemv_q4_k_q8_k()

void gemv_q4_k_q8_k ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Definition at line 239 of file gemm_kernels_q4k_q8k.c.

/* Compile-time dispatcher: selects the best GEMV implementation available
 * for the target ISA (VNNI > AVX2 > AVX > SSE4.1 > scalar reference).
 * All variants share the same contract: y[M] = W[M,K] (Q4_K) * x_q8 (Q8_K). */
243 {
244 #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
245  /* VNNI: Best for decode (single token) - INT8 dot product acceleration */
246  gemv_q4_k_q8_k_vnni(y, W, x_q8, M, K);
247 #elif defined(__AVX2__)
248  gemv_q4_k_q8_k_avx2(y, W, x_q8, M, K);
249 #elif defined(__AVX__)
250  /* AVX version uses maddubs_epi16 (more efficient than SSE) */
251  gemv_q4_k_q8_k_avx(y, W, x_q8, M, K);
252 #elif defined(__SSE4_1__)
253  gemv_q4_k_q8_k_sse(y, W, x_q8, M, K);
254 #else
/* Portable scalar fallback — also the correctness reference. */
255  gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
256 #endif
257 }
void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_vnni(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)

Referenced by fused_rmsnorm_linear_q4k(), gemm_q4_k_q8_k(), and unfused_rmsnorm_linear_q4k_ref().

◆ unfused_rmsnorm_linear_q4k_ref()

void unfused_rmsnorm_linear_q4k_ref ( float *  y,
const float *  x,
const float *  gamma,
const void *  W_q4k,
int  M,
int  K,
float  eps 
)

Reference (unfused) implementation for correctness testing.

This is the SLOW version that does separate RMSNorm and GEMV calls, with intermediate results going to DRAM.

Definition at line 266 of file fused_rmsnorm_linear.c.

272 {
/* Guard: reject NULL pointers and non-positive dims; y is left untouched. */
273  if (!y || !x || !gamma || !W_q4k || M <= 0 || K <= 0) {
274  return;
275  }
276 
277  assert(K % QK_K == 0);
278 
279  /* Stack-allocated buffers (no malloc!) - stays in L1/L2 cache */
280  /* Max supported: K=4096 (16KB), 16 blocks (~5KB) */
/* NOTE(review): silent failure — when K > 4096 the function returns with
 * y unmodified and no error indication. Since this is test-only reference
 * code, consider an assert so misuse fails loudly. */
281  if (K > 4096) return;
282 
283  float norm_out[4096];
284  block_q8_K q8_buffer[16]; /* 16 blocks for K=4096, K/QK_K */
285 
286  /* Step 1: RMSNorm (stays in cache via stack buffer) */
/* Sum of squares in double for a high-precision reference value. */
287  double sum_sq = 0.0;
288  for (int d = 0; d < K; ++d) {
289  sum_sq += (double)x[d] * (double)x[d];
290  }
291  float rstd = 1.0f / sqrtf((float)(sum_sq / K) + eps);
292 
/* The DRAM READ/WRITE labels below model the memory traffic of a real
 * unfused pipeline (heap intermediates); here the buffers are on the
 * stack, so the labels are conceptual, not literal. */
293  for (int d = 0; d < K; ++d) {
294  norm_out[d] = x[d] * rstd * gamma[d]; /* DRAM WRITE */
295  }
296 
297  /* Step 2: Quantize (reads DRAM, writes DRAM) */
/* NOTE(review): function-local extern declaration — prefer declaring
 * quantize_row_q8_k in ckernel_quant.h so the signature is checked. */
298  extern void quantize_row_q8_k(const float *x, void *vy, int k);
299  quantize_row_q8_k(norm_out, q8_buffer, K); /* DRAM READ + WRITE */
300 
301  /* Step 3: GEMV (reads Q8_K from cache) */
302  gemv_q4_k_q8_k(y, W_q4k, q8_buffer, M, K);
303 
304  /* No free needed - stack buffers auto-deallocate */
305 }
void quantize_row_q8_k(const float *x, void *y, int k)

References gemv_q4_k_q8_k(), QK_K, and quantize_row_q8_k().