← Back to C-Kernel-Engine Docs Doxygen Source Documentation
fused_rmsnorm_linear.c
Go to the documentation of this file.
1 /**
2  * @file fused_rmsnorm_linear.c
3  * @brief Fused RMSNorm + Linear (GEMV) kernel
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. NO memcpy for layout - use strided access, not copies
10  * 4. API must define: inputs, outputs, workspace, and memory layouts
11  * 5. Pure computation - deterministic, no side effects
12  *
13  * After changes: make test && make llamacpp-parity-full
14  *
15  * VIOLATION: Has free() calls and memcpy in test/benchmark code at end of file.
16  * TODO: Move test code to unittest/, remove free()/memcpy from kernel file.
17  *
18  * FUSION BENEFIT:
19  * ===============
20  * Unfused:
21  * RMSNorm(x) → [DRAM write: norm_out] → Quantize → [DRAM write: q8] → GEMV
22  * Total DRAM: 2 writes + 2 reads = 4 * hidden_size bytes
23  *
24  * Fused:
25  * RMSNorm(x) → [registers] → Quantize → [stack/L1: q8] → GEMV
26  * Total DRAM: 0 intermediate writes/reads
27  *
28  * Expected: 2-4x memory traffic reduction for this operation
29  */
30 
31 #include <assert.h>
32 #include <math.h>
33 #include <stddef.h>
34 #include <stdint.h>
35 #include <string.h>
36 
37 #include "ckernel_quant.h"
38 
39 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
40 #include <immintrin.h>
41 #endif
42 
43 /* Forward declarations */
44 void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K);
45 
46 /* Inline quantization helper - same as quantize_row_q8_k but operates on
47  * normalized values that may still be in registers/cache */
/* Round a float to the nearest integer (ties-to-even) via the bias trick:
 * adding 1.5 * 2^23 pushes the integer part into the low mantissa bits of
 * the IEEE-754 representation, so the rounded value can be read straight
 * out of the bit pattern. Valid for |fval| well below 2^22; uses the
 * default round-to-nearest-even FP mode, no cvt instruction needed. */
static inline int ck_nearest_int_fused(float fval) {
    const float biased = fval + 12582912.f; /* 12582912 = 1.5 * 2^23 */
    int bits;
    memcpy(&bits, &biased, sizeof bits);    /* bit-cast without aliasing UB */
    return (bits & 0x007fffff) - 0x00400000;
}
54 
#if defined(__AVX__) && !defined(__AVX512F__)
/* Horizontal sum of the 8 float lanes of an AVX register.
 * Only needed on the AVX path: the AVX-512 path uses
 * _mm512_reduce_add_ps directly, hence the !__AVX512F__ guard.
 * The fold order (hi+lo, then two hadds) fixes the association order of
 * the float additions; callers relying on bit-exact results depend on it. */
static inline float hsum256_ps_fused(__m256 v) {
 __m128 hi = _mm256_extractf128_ps(v, 1); /* upper 4 lanes */
 __m128 lo = _mm256_castps256_ps128(v); /* lower 4 lanes (zero-cost cast) */
 __m128 sum128 = _mm_add_ps(lo, hi); /* fold 8 -> 4 */
 sum128 = _mm_hadd_ps(sum128, sum128); /* fold 4 -> 2 */
 sum128 = _mm_hadd_ps(sum128, sum128); /* fold 2 -> 1 */
 return _mm_cvtss_f32(sum128); /* extract lane 0 */
}
#endif
65 
66 /**
67  * @brief Fused RMSNorm + Q4_K Linear projection
68  *
69  * Computes: y = Linear(RMSNorm(x))
70  * where Linear uses Q4_K weights and Q8_K activations internally.
71  *
72  * The key optimization is that the normalized values never touch DRAM -
73  * they go directly from RMSNorm computation to Q8_K quantization to GEMV.
74  *
75  * @param y Output (FP32), shape [M]
76  * @param x Input hidden state (FP32), shape [K]
77  * @param gamma RMSNorm scale weights (FP32), shape [K]
78  * @param W_q4k Linear weights in Q4_K format, shape [M, K]
79  * @param M Output dimension (e.g., 3 * hidden for QKV)
80  * @param K Input dimension (hidden_size)
81  * @param eps RMSNorm epsilon (typically 1e-5 or 1e-6)
82  */
84  const float *x,
85  const float *gamma,
86  const void *W_q4k,
87  int M, int K,
88  float eps)
89 {
90  if (!y || !x || !gamma || !W_q4k || M <= 0 || K <= 0) {
91  return;
92  }
93 
94  assert(K % QK_K == 0);
95  const int nb = K / QK_K; /* Number of Q8_K blocks */
96 
97  /* Stack-allocated Q8_K buffer - stays in L1/L2 cache */
98  /* Max supported K = 8192 (8 blocks of 256) */
99  block_q8_K q8_buffer[32]; /* 32 * ~260 bytes = ~8KB on stack */
100  assert(nb <= 32 && "K too large for stack buffer");
101 
102  /* ================================================================
103  * PHASE 1: Compute RMSNorm and quantize to Q8_K
104  * Result stays in stack (L1/L2), never touches DRAM
105  * ================================================================ */
106 
107 #if defined(__AVX512F__)
108  /* AVX-512: Compute sum of squares */
109  __m512 sum_sq_vec = _mm512_setzero_ps();
110  int d = 0;
111  for (; d + 16 <= K; d += 16) {
112  __m512 xv = _mm512_loadu_ps(&x[d]);
113  sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
114  }
115  float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
116  for (; d < K; ++d) {
117  sum_sq += x[d] * x[d];
118  }
119 
120 #elif defined(__AVX__)
121  /* AVX: Compute sum of squares */
122  __m256 sum_sq_vec = _mm256_setzero_ps();
123  int d = 0;
124  for (; d + 8 <= K; d += 8) {
125  __m256 xv = _mm256_loadu_ps(&x[d]);
126  __m256 xv_sq = _mm256_mul_ps(xv, xv);
127  sum_sq_vec = _mm256_add_ps(sum_sq_vec, xv_sq);
128  }
129  float sum_sq = hsum256_ps_fused(sum_sq_vec);
130  for (; d < K; ++d) {
131  sum_sq += x[d] * x[d];
132  }
133 
134 #else
135  /* Scalar fallback */
136  double sum_sq = 0.0;
137  for (int d = 0; d < K; ++d) {
138  double v = (double)x[d];
139  sum_sq += v * v;
140  }
141 #endif
142 
143  float mean_sq = (float)sum_sq / (float)K;
144  float rstd = 1.0f / sqrtf(mean_sq + eps);
145 
146  /* ================================================================
147  * PHASE 2: Apply RMSNorm and quantize to Q8_K in one pass
148  * Normalized values go directly to Q8_K blocks
149  * ================================================================ */
150 
151  for (int i = 0; i < nb; ++i) {
152  const float *x_block = x + i * QK_K;
153  const float *g_block = gamma + i * QK_K;
154 
155  /* Find max absolute value for this block's normalized output */
156  float max_val = 0.0f;
157  float amax = 0.0f;
158 
159 #if defined(__AVX512F__)
160  __m512 rstd_vec = _mm512_set1_ps(rstd);
161  __m512 max_vec = _mm512_setzero_ps();
162  __m512 sign_mask = _mm512_set1_ps(-0.0f);
163 
164  for (int j = 0; j < QK_K; j += 16) {
165  __m512 xv = _mm512_loadu_ps(&x_block[j]);
166  __m512 gv = _mm512_loadu_ps(&g_block[j]);
167  __m512 norm = _mm512_mul_ps(_mm512_mul_ps(xv, rstd_vec), gv);
168  __m512 abs_norm = _mm512_andnot_ps(sign_mask, norm);
169  max_vec = _mm512_max_ps(max_vec, abs_norm);
170 
171  /* Track max with sign for scale computation */
172  __mmask16 gt_mask = _mm512_cmp_ps_mask(abs_norm, _mm512_set1_ps(amax), _CMP_GT_OQ);
173  if (gt_mask) {
174  float temp_amax = _mm512_reduce_max_ps(abs_norm);
175  if (temp_amax > amax) {
176  amax = temp_amax;
177  /* Find the actual max value with sign */
178  for (int k = 0; k < 16; ++k) {
179  float v = x_block[j + k] * rstd * g_block[j + k];
180  if (fabsf(v) >= amax - 1e-6f) {
181  max_val = v;
182  break;
183  }
184  }
185  }
186  }
187  }
188  amax = _mm512_reduce_max_ps(max_vec);
189 
190 #elif defined(__AVX__)
191  __m256 rstd_vec = _mm256_set1_ps(rstd);
192 
193  for (int j = 0; j < QK_K; j += 8) {
194  __m256 xv = _mm256_loadu_ps(&x_block[j]);
195  __m256 gv = _mm256_loadu_ps(&g_block[j]);
196  __m256 norm = _mm256_mul_ps(_mm256_mul_ps(xv, rstd_vec), gv);
197 
198  /* Check each element for max */
199  float norm_arr[8];
200  _mm256_storeu_ps(norm_arr, norm);
201  for (int k = 0; k < 8; ++k) {
202  float av = fabsf(norm_arr[k]);
203  if (av > amax) {
204  amax = av;
205  max_val = norm_arr[k];
206  }
207  }
208  }
209 
210 #else
211  for (int j = 0; j < QK_K; ++j) {
212  float norm = x_block[j] * rstd * g_block[j];
213  float av = fabsf(norm);
214  if (av > amax) {
215  amax = av;
216  max_val = norm;
217  }
218  }
219 #endif
220 
221  /* Handle zero block */
222  if (amax < 1e-10f) {
223  q8_buffer[i].d = 0.0f;
224  memset(q8_buffer[i].qs, 0, sizeof(q8_buffer[i].qs));
225  memset(q8_buffer[i].bsums, 0, sizeof(q8_buffer[i].bsums));
226  continue;
227  }
228 
229  /* Compute scale and quantize */
230  const float iscale = -127.0f / max_val;
231  q8_buffer[i].d = 1.0f / iscale;
232 
233  /* Quantize and compute bsums */
234  for (int j = 0; j < QK_K; ++j) {
235  float norm = x_block[j] * rstd * g_block[j];
236  int v = ck_nearest_int_fused(iscale * norm);
237  v = (v > 127) ? 127 : ((v < -128) ? -128 : v);
238  q8_buffer[i].qs[j] = (int8_t)v;
239  }
240 
241  /* Compute block sums (16 elements each) */
242  for (int j = 0; j < QK_K / 16; ++j) {
243  int sum = 0;
244  const int8_t *qs = &q8_buffer[i].qs[j * 16];
245  for (int k = 0; k < 16; ++k) {
246  sum += qs[k];
247  }
248  q8_buffer[i].bsums[j] = (int16_t)sum;
249  }
250  }
251 
252  /* ================================================================
253  * PHASE 3: GEMV with Q4_K weights and Q8_K activations
254  * Q8_K data is in stack (L1/L2), not DRAM
255  * ================================================================ */
256 
257  gemv_q4_k_q8_k(y, W_q4k, q8_buffer, M, K);
258 }
259 
260 /**
261  * @brief Reference (unfused) implementation for correctness testing
262  *
263  * This is the SLOW version that does separate RMSNorm and GEMV calls,
264  * with intermediate results going to DRAM.
265  */
267  const float *x,
268  const float *gamma,
269  const void *W_q4k,
270  int M, int K,
271  float eps)
272 {
273  if (!y || !x || !gamma || !W_q4k || M <= 0 || K <= 0) {
274  return;
275  }
276 
277  assert(K % QK_K == 0);
278 
279  /* Stack-allocated buffers (no malloc!) - stays in L1/L2 cache */
280  /* Max supported: K=4096 (16KB), 16 blocks (~5KB) */
281  if (K > 4096) return;
282 
283  float norm_out[4096];
284  block_q8_K q8_buffer[16]; /* 16 blocks for K=4096, K/QK_K */
285 
286  /* Step 1: RMSNorm (stays in cache via stack buffer) */
287  double sum_sq = 0.0;
288  for (int d = 0; d < K; ++d) {
289  sum_sq += (double)x[d] * (double)x[d];
290  }
291  float rstd = 1.0f / sqrtf((float)(sum_sq / K) + eps);
292 
293  for (int d = 0; d < K; ++d) {
294  norm_out[d] = x[d] * rstd * gamma[d]; /* DRAM WRITE */
295  }
296 
297  /* Step 2: Quantize (reads DRAM, writes DRAM) */
298  extern void quantize_row_q8_k(const float *x, void *vy, int k);
299  quantize_row_q8_k(norm_out, q8_buffer, K); /* DRAM READ + WRITE */
300 
301  /* Step 3: GEMV (reads Q8_K from cache) */
302  gemv_q4_k_q8_k(y, W_q4k, q8_buffer, M, K);
303 
304  /* No free needed - stack buffers auto-deallocate */
305 }
306 
#ifdef FUSED_KERNEL_TEST
/* Simple correctness test: fused vs. unfused reference output.
 * NOTE(review): per the kernel rules in the file header, this belongs in
 * unittest/ - kept behind FUSED_KERNEL_TEST until it is moved. */
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

int main(void) {
    const int K = 512;   /* Hidden size */
    const int M = 1536;  /* Output size (3 * hidden for QKV) */
    const int nb = K / QK_K;

    printf("Fused RMSNorm+Linear Test\n");
    printf("K=%d, M=%d, blocks=%d\n", K, M, nb);

    /* Allocate all buffers up front so failure handling is centralized.
     * C11 aligned_alloc requires size % alignment == 0: K, M, and M*nb
     * are all multiples of 64, so every size below qualifies. */
    float *x = aligned_alloc(64, K * sizeof(float));
    float *gamma = aligned_alloc(64, K * sizeof(float));
    float *y_fused = aligned_alloc(64, M * sizeof(float));
    float *y_unfused = aligned_alloc(64, M * sizeof(float));
    block_q4_K *W = aligned_alloc(64, M * nb * sizeof(block_q4_K));

    /* Check every allocation before use (the originals were unchecked) */
    if (!x || !gamma || !y_fused || !y_unfused || !W) {
        fprintf(stderr, "allocation failure\n");
        free(x);
        free(gamma);
        free(y_fused);
        free(y_unfused);
        free(W);
        return 1;
    }

    /* Initialize with deterministic pseudo-random data */
    srand(42);
    for (int i = 0; i < K; ++i) {
        x[i] = (float)rand() / RAND_MAX * 2.0f - 1.0f;      /* [-1, 1] */
        gamma[i] = (float)rand() / RAND_MAX * 0.5f + 0.75f; /* [0.75, 1.25] */
    }

    /* Create dummy Q4_K weights (in real usage, these come from model) */
    memset(W, 0, M * nb * sizeof(block_q4_K));
    for (int i = 0; i < M * nb; ++i) {
        W[i].d = 0x3C00; /* 1.0 in FP16 */
        W[i].dmin = 0x0000;
    }

    /* Run both versions */
    printf("Running fused version...\n");
    fused_rmsnorm_linear_q4k(y_fused, x, gamma, W, M, K, 1e-5f);

    printf("Running unfused version...\n");
    unfused_rmsnorm_linear_q4k_ref(y_unfused, x, gamma, W, M, K, 1e-5f);

    /* Compare results element-wise */
    float max_diff = 0.0f;
    for (int i = 0; i < M; ++i) {
        float diff = fabsf(y_fused[i] - y_unfused[i]);
        if (diff > max_diff) max_diff = diff;
    }

    printf("Max difference: %e\n", max_diff);
    const int passed = max_diff < 1e-3f; /* single source of truth for pass/fail */
    printf("Test %s\n", passed ? "PASSED" : "FAILED");

    free(x);
    free(gamma);
    free(y_fused);
    free(y_unfused);
    free(W);

    return passed ? 0 : 1;
}
#endif
int main(int argc, char **argv)
Definition: ck_cli_v5.c:110
void quantize_row_q8_k(const float *x, void *y, int k)
Quantization block structures for weight-only quantization.
#define QK_K
void unfused_rmsnorm_linear_q4k_ref(float *y, const float *x, const float *gamma, const void *W_q4k, int M, int K, float eps)
Reference (unfused) implementation for correctness testing.
void fused_rmsnorm_linear_q4k(float *y, const float *x, const float *gamma, const void *W_q4k, int M, int K, float eps)
Fused RMSNorm + Q4_K Linear projection.
static int ck_nearest_int_fused(float fval)
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
static float hsum256_ps_fused(__m256 v)
ck_half dmin
int8_t qs[256]
int16_t bsums[256/16]