← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rmsnorm_qkv.c
Go to the documentation of this file.
1 /**
2  * @file rmsnorm_qkv.c
3  * @brief Fused RMSNorm + QKV Projection
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Part of C-Kernel-Engine v6.6 Fusion Kernels
15  *
16  * PROBLEM:
17  * Non-fused version does 4 DRAM round-trips for 'normed' buffer:
18  * rmsnorm(x, weight, normed); // Write normed to DRAM
19  * gemv(wq, normed, q); // Read normed from DRAM
20  * gemv(wk, normed, k); // Read normed from DRAM
21  * gemv(wv, normed, v); // Read normed from DRAM
22  *
23  * SOLUTION:
24  * Fused version keeps 'normed' in registers/L1, zero DRAM access:
25  * rmsnorm_qkv_fused(x, weight, wq, wk, wv, q, k, v);
26  *
27  * EXPECTED SPEEDUP: 1.5-2x for this operation
28  */
29 
30 #include <stdint.h>
31 #include <stddef.h>
32 #include <math.h>
33 #include <string.h>
34 
35 #ifdef __AVX2__
36 #include <immintrin.h>
37 #endif
38 
39 #include "ckernel_quant.h"
40 
41 /* ============================================================================
42  * HELPER: RMSNorm computation (inline, result stays in registers)
43  * ============================================================================ */
44 
45 static inline float compute_rms_scale(const float *x, int n, float eps) {
46  float sum_sq = 0.0f;
47 
48 #ifdef __AVX2__
49  __m256 vsum = _mm256_setzero_ps();
50  int i = 0;
51  for (; i + 7 < n; i += 8) {
52  __m256 vx = _mm256_loadu_ps(x + i);
53  vsum = _mm256_fmadd_ps(vx, vx, vsum);
54  }
55  // Horizontal sum
56  __m128 vlow = _mm256_castps256_ps128(vsum);
57  __m128 vhigh = _mm256_extractf128_ps(vsum, 1);
58  vlow = _mm_add_ps(vlow, vhigh);
59  vlow = _mm_hadd_ps(vlow, vlow);
60  vlow = _mm_hadd_ps(vlow, vlow);
61  sum_sq = _mm_cvtss_f32(vlow);
62  // Remainder
63  for (; i < n; i++) {
64  sum_sq += x[i] * x[i];
65  }
66 #else
67  for (int i = 0; i < n; i++) {
68  sum_sq += x[i] * x[i];
69  }
70 #endif
71 
72  float rms = sqrtf(sum_sq / (float)n + eps);
73  return 1.0f / rms;
74 }
75 
76 /* ============================================================================
77  * FUSED KERNEL: RMSNorm + QKV Projection (FP32 weights)
78  * ============================================================================ */
79 
/* One fused row dot product: sum_i( w[i] * x[i] * rms_weight[i] ).
 * The per-element normalized value is formed in registers inside the FMA
 * and is never stored to memory. */
static float rmsnorm_fused_dot(const float *w, const float *x,
                               const float *rms_weight, int n) {
    float sum = 0.0f;
#ifdef __AVX2__
    __m256 vsum = _mm256_setzero_ps();
    int i = 0;
    for (; i + 7 < n; i += 8) {
        __m256 vx = _mm256_loadu_ps(x + i);
        __m256 vrms = _mm256_loadu_ps(rms_weight + i);
        __m256 vw = _mm256_loadu_ps(w + i);
        vsum = _mm256_fmadd_ps(vw, _mm256_mul_ps(vx, vrms), vsum);
    }
    /* Horizontal sum of the 8 lanes (same reduction order as before). */
    __m128 vlow = _mm256_castps256_ps128(vsum);
    __m128 vhigh = _mm256_extractf128_ps(vsum, 1);
    vlow = _mm_add_ps(vlow, vhigh);
    vlow = _mm_hadd_ps(vlow, vlow);
    vlow = _mm_hadd_ps(vlow, vlow);
    sum = _mm_cvtss_f32(vlow);
    /* Scalar tail for n % 8 != 0. */
    for (; i < n; i++) {
        sum += w[i] * x[i] * rms_weight[i];
    }
#else
    for (int i = 0; i < n; i++) {
        sum += w[i] * x[i] * rms_weight[i];
    }
#endif
    return sum;
}

/**
 * Fused RMSNorm + QKV projection, FP32 weights.
 *
 * Computes q_out = Wq @ (x * rms_weight * scale) — and likewise for K, V —
 * with scale = 1/sqrt(mean(x^2) + eps).  The normalized vector is computed
 * on the fly inside each row dot product, so it never round-trips through
 * memory (see file header for the motivation).
 *
 * @param x          [embed_dim]            input hidden state
 * @param rms_weight [embed_dim]            RMSNorm gamma
 * @param wq         [q_dim  x embed_dim]   Q projection (row-major)
 * @param wk         [kv_dim x embed_dim]   K projection (row-major)
 * @param wv         [kv_dim x embed_dim]   V projection (row-major)
 * @param q_out      [q_dim]                output Q
 * @param k_out      [kv_dim]               output K
 * @param v_out      [kv_dim]               output V
 * @param embed_dim  hidden dimension
 * @param q_dim      Q output dimension (num_heads * head_dim)
 * @param kv_dim     KV output dimension (num_kv_heads * head_dim)
 * @param eps        RMSNorm epsilon (typically 1e-6)
 *
 * NOTE(review): the opening signature line was lost in the doc export; it
 * is reconstructed here to match the declaration used by callers.  The
 * three near-identical Q/K/V loops of the original are folded into
 * rmsnorm_fused_dot(); numeric accumulation order is unchanged.
 */
void rmsnorm_qkv_fp32_fused(
    const float *x,          /* [embed_dim] input hidden state */
    const float *rms_weight, /* [embed_dim] RMSNorm gamma */
    const float *wq,         /* [q_dim, embed_dim] Q projection */
    const float *wk,         /* [kv_dim, embed_dim] K projection */
    const float *wv,         /* [kv_dim, embed_dim] V projection */
    float *q_out,            /* [q_dim] output Q */
    float *k_out,            /* [kv_dim] output K */
    float *v_out,            /* [kv_dim] output V */
    int embed_dim,           /* Hidden dimension */
    int q_dim,               /* Q output dimension */
    int kv_dim,              /* KV output dimension */
    float eps                /* RMSNorm epsilon */
) {
    /* Step 1: scalar RMS factor — one pass over x, stays in a register. */
    float scale = compute_rms_scale(x, embed_dim, eps);

    /* Step 2: q[j] = scale * sum_i( wq[j,i] * x[i] * rms_weight[i] )
     * (scale factored out of the sum, exactly as in the original). */
    for (int j = 0; j < q_dim; j++) {
        q_out[j] = scale * rmsnorm_fused_dot(wq + (size_t)j * embed_dim,
                                             x, rms_weight, embed_dim);
    }
    for (int j = 0; j < kv_dim; j++) {
        k_out[j] = scale * rmsnorm_fused_dot(wk + (size_t)j * embed_dim,
                                             x, rms_weight, embed_dim);
    }
    for (int j = 0; j < kv_dim; j++) {
        v_out[j] = scale * rmsnorm_fused_dot(wv + (size_t)j * embed_dim,
                                             x, rms_weight, embed_dim);
    }
}
205 
206 /* ============================================================================
207  * FUSED KERNEL: RMSNorm + QKV Projection (Q4_K quantized weights)
208  *
209  * This is the production version - weights are Q4_K quantized.
210  * We dequantize on-the-fly during the fused operation.
211  * ============================================================================ */
212 
/* Maximum embed_dim supported by the stack-resident normed[] buffer.
 * 4096 floats = 16 KiB: fits typical thread stacks and stays L1-resident. */
enum { RMSNORM_QKV_MAX_EMBED = 4096 };

/* Existing Q4_K GEMV kernel: y[M] = W[M,K] @ x[K] (dequant on the fly). */
extern void gemv_q4_k(float *y, const void *W, const float *x, int M, int K);

/**
 * Fused RMSNorm + QKV projection with Q4_K quantized weights (production
 * path).  The Q4_K block structure makes true register-level fusion
 * impractical, so normed[] is materialized once on the stack and consumed
 * by three back-to-back GEMVs while it is still L1-resident — this removes
 * the DRAM round-trip of the non-fused version.
 *
 * @param x          [embed_dim] input hidden state
 * @param rms_weight [embed_dim] RMSNorm gamma
 * @param wq,wk,wv   Q4_K quantized projection matrices
 * @param q_out      [q_dim]  output Q
 * @param k_out      [kv_dim] output K
 * @param v_out      [kv_dim] output V
 * @param embed_dim  hidden dimension (must be <= RMSNORM_QKV_MAX_EMBED)
 * @param q_dim      Q output dimension
 * @param kv_dim     KV output dimension
 * @param eps        RMSNorm epsilon
 *
 * NOTE(review): for embed_dim > RMSNORM_QKV_MAX_EMBED this kernel returns
 * WITHOUT writing q_out/k_out/v_out (kernel rules forbid malloc here).
 * Callers MUST guard this; the proper fix is a caller-provided workspace
 * from the bump allocator.
 * NOTE(review): the opening signature line was lost in the doc export and
 * is reconstructed to match the declaration used by callers.
 */
void rmsnorm_qkv_q4k_fused(
    const float *x,          /* [embed_dim] input hidden state */
    const float *rms_weight, /* [embed_dim] RMSNorm gamma */
    const void *wq,          /* Q4_K quantized Q projection */
    const void *wk,          /* Q4_K quantized K projection */
    const void *wv,          /* Q4_K quantized V projection */
    float *q_out,            /* [q_dim] output Q */
    float *k_out,            /* [kv_dim] output K */
    float *v_out,            /* [kv_dim] output V */
    int embed_dim,           /* Hidden dimension */
    int q_dim,               /* Q output dimension */
    int kv_dim,              /* KV output dimension */
    float eps                /* RMSNorm epsilon */
) {
    /* Guard first: outputs are left untouched for oversized models.
     * TODO: accept a caller-provided workspace instead of bailing out. */
    if (embed_dim > RMSNORM_QKV_MAX_EMBED) {
        return;
    }

    /* Step 1: scalar RMS factor. */
    float scale = compute_rms_scale(x, embed_dim, eps);

    /* Step 2: normed = x * rms_weight * scale, on the stack. */
    float normed[RMSNORM_QKV_MAX_EMBED];
#ifdef __AVX2__
    __m256 vscale = _mm256_set1_ps(scale);
    int i = 0;
    for (; i + 7 < embed_dim; i += 8) {
        __m256 vx = _mm256_loadu_ps(x + i);
        __m256 vrms = _mm256_loadu_ps(rms_weight + i);
        _mm256_storeu_ps(normed + i,
                         _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale));
    }
    for (; i < embed_dim; i++) {
        normed[i] = x[i] * rms_weight[i] * scale;
    }
#else
    for (int i = 0; i < embed_dim; i++) {
        normed[i] = x[i] * rms_weight[i] * scale;
    }
#endif

    /* Step 3: three Q4_K GEMVs against the L1-resident normed[] buffer. */
    gemv_q4_k(q_out, wq, normed, q_dim, embed_dim);
    gemv_q4_k(k_out, wk, normed, kv_dim, embed_dim);
    gemv_q4_k(v_out, wv, normed, kv_dim, embed_dim);
}
288 
289 /* ============================================================================
290  * TRUE SIMD FUSION: RMSNorm + QKV (FP32 weights) - VARIATION 2
291  *
292  * KEY INSIGHT: Process OUTPUT cache-line by cache-line.
293  * For each output cache line:
294  * - Keep accumulators in YMM/ZMM registers
295  * - For each INPUT cache line:
296  * - Compute normed chunk IN REGISTER (never stored to memory!)
297  * - Use immediately for all output accumulators via FMADD
298  * - Only store when output cache line is complete
299  *
300  * This is TRUE register-level fusion:
301  * - normed[] NEVER touches L1 cache
302  * - Each input cache line loaded ONCE, used for multiple outputs
303  * - Memory traffic: input + weights + output (no intermediate!)
304  *
305  * Expected speedup: 1.5-2x over separate kernels
306  * ============================================================================ */
307 
#ifdef __AVX2__
/* Horizontal sum of all 8 float lanes of an AVX2 register.
 * Reduction order: high 128-bit half folded onto the low half, then
 * pairwise adds via movehdup/movehl (cheaper than two _mm_hadd_ps on most
 * microarchitectures).  Callers rely on this exact order — float addition
 * is not associative, so do not reorder. */
static inline float hsum256_ps(__m256 v) {
    __m128 vlow = _mm256_castps256_ps128(v);
    __m128 vhigh = _mm256_extractf128_ps(v, 1);
    vlow = _mm_add_ps(vlow, vhigh);       /* lane i += lane i+4 */
    __m128 shuf = _mm_movehdup_ps(vlow);  /* duplicate odd lanes */
    vlow = _mm_add_ps(vlow, shuf);
    shuf = _mm_movehl_ps(shuf, vlow);     /* move upper pair down */
    vlow = _mm_add_ss(vlow, shuf);
    return _mm_cvtss_f32(vlow);
}
#endif
320 
#ifdef __AVX2__
/* One projection for the v2 scheme: out[0..out_dim) = W @ normed, where
 * normed[i] = x[i] * rms_weight[i] * scale is recomputed per 8-wide input
 * chunk in a register and shared by up to 8 output rows ("rows"). */
static void rmsnorm_qkv_v2_proj(const float *W, const float *x,
                                const float *rms_weight, float scale,
                                float *out, int out_dim, int embed_dim) {
    __m256 vscale = _mm256_set1_ps(scale);

    for (int j = 0; j < out_dim; j += 8) {
        int rows = (out_dim - j < 8) ? (out_dim - j) : 8;
        __m256 acc[8];
        for (int r = 0; r < rows; r++) {
            acc[r] = _mm256_setzero_ps();
        }

        int i = 0;
        for (; i + 7 < embed_dim; i += 8) {
            /* Normalize one input chunk IN REGISTER — never stored. */
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);

            for (int r = 0; r < rows; r++) {
                __m256 w = _mm256_loadu_ps(W + (size_t)(j + r) * embed_dim + i);
                acc[r] = _mm256_fmadd_ps(w, normed, acc[r]);
            }
        }

        /* BUG FIX (review): the original remainder path did
         *   acc = _mm256_add_ps(acc, _mm256_set1_ps(w * normed));
         * which adds the tail term to ALL 8 lanes, so after the horizontal
         * sum each tail element was counted 8x whenever embed_dim % 8 != 0.
         * Correct order: reduce lanes first, then fold the scalar tail. */
        float sums[8];
        for (int r = 0; r < rows; r++) {
            sums[r] = hsum256_ps(acc[r]);
        }
        for (; i < embed_dim; i++) {
            float normed_scalar = x[i] * rms_weight[i] * scale;
            for (int r = 0; r < rows; r++) {
                sums[r] += W[(size_t)(j + r) * embed_dim + i] * normed_scalar;
            }
        }
        for (int r = 0; r < rows; r++) {
            out[j + r] = sums[r];
        }
    }
}
#endif

/**
 * Fused RMSNorm + QKV, FP32 weights — variation 2: processes outputs in
 * groups of 8 rows so each normalized input chunk (computed in a register,
 * never stored) is reused for up to 8 FMAs.
 *
 * Same contract as rmsnorm_qkv_fp32_fused:
 * @param x          [embed_dim]            input hidden state
 * @param rms_weight [embed_dim]            RMSNorm gamma
 * @param wq         [q_dim  x embed_dim]   Q projection (row-major)
 * @param wk         [kv_dim x embed_dim]   K projection
 * @param wv         [kv_dim x embed_dim]   V projection
 * @param q_out      [q_dim]  output Q
 * @param k_out      [kv_dim] output K
 * @param v_out      [kv_dim] output V
 * @param embed_dim  hidden dimension
 * @param q_dim      Q output dimension
 * @param kv_dim     KV output dimension
 * @param eps        RMSNorm epsilon (typically 1e-6)
 *
 * NOTE(review): the opening signature line was lost in the doc export and
 * is reconstructed to match the declaration used by callers.  The scalar
 * remainder bug (tail elements counted 8x — see rmsnorm_qkv_v2_proj) is
 * fixed, and the three copy-pasted Q/K/V loops are folded into one helper.
 */
void rmsnorm_qkv_fp32_fused_v2(
    const float *x,          /* [embed_dim] input hidden state */
    const float *rms_weight, /* [embed_dim] RMSNorm gamma */
    const float *wq,         /* [q_dim, embed_dim] Q projection (row-major) */
    const float *wk,         /* [kv_dim, embed_dim] K projection */
    const float *wv,         /* [kv_dim, embed_dim] V projection */
    float *q_out,            /* [q_dim] output Q */
    float *k_out,            /* [kv_dim] output K */
    float *v_out,            /* [kv_dim] output V */
    int embed_dim,           /* Hidden dimension */
    int q_dim,               /* Q output dimension */
    int kv_dim,              /* KV output dimension */
    float eps                /* RMSNorm epsilon */
) {
    /* Step 1: RMS scale (full pass over x — unavoidable). */
    float scale = compute_rms_scale(x, embed_dim, eps);

#ifdef __AVX2__
    rmsnorm_qkv_v2_proj(wq, x, rms_weight, scale, q_out, q_dim, embed_dim);
    rmsnorm_qkv_v2_proj(wk, x, rms_weight, scale, k_out, kv_dim, embed_dim);
    rmsnorm_qkv_v2_proj(wv, x, rms_weight, scale, v_out, kv_dim, embed_dim);
#else
    /* Scalar fallback — same math, no SIMD. */
    for (int j = 0; j < q_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            sum += wq[j * embed_dim + i] * normed;
        }
        q_out[j] = sum;
    }
    for (int j = 0; j < kv_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            sum += wk[j * embed_dim + i] * normed;
        }
        k_out[j] = sum;
    }
    for (int j = 0; j < kv_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            sum += wv[j * embed_dim + i] * normed;
        }
        v_out[j] = sum;
    }
#endif
}
564 
565 /* ============================================================================
566  * TRUE SIMD FUSION V3: RMSNorm + QKV (FP32 weights)
567  *
568  * KEY FIX: Process Q, K, V SIMULTANEOUSLY in one pass through input!
569  *
570  * Previous versions had a flaw:
571  * v1: Recomputes normed for each Q row, then each K row, then each V row
572  * = 3 * q_dim * embed_dim FMA operations (tripled work!)
573  * v2: Same issue, just with 8-at-a-time grouping
574  *
575  * v3 approach:
576  * For each OUTPUT index j (0 to max(q_dim, kv_dim)):
577  * q_acc = k_acc = v_acc = 0
578  * For each INPUT chunk [i:i+8]:
579  * normed = x[i:i+8] * rms_weight[i:i+8] * scale (computed ONCE!)
580  * q_acc += wq[j,i:i+8] · normed
581  * k_acc += wk[j,i:i+8] · normed (if j < kv_dim)
582  * v_acc += wv[j,i:i+8] · normed (if j < kv_dim)
583  * Store q_out[j], k_out[j], v_out[j]
584  *
585  * Benefits:
586  * - normed computed ONCE per input chunk, used 3x
587  * - Sequential weight access (good prefetch)
588  * - Minimal register pressure (3 accumulators + 1 normed)
589  * ============================================================================ */
590 
/**
 * Fused RMSNorm + QKV, FP32 weights — v3: Q, K and V rows sharing an output
 * index are computed in ONE pass over the input, so each normalized 8-wide
 * chunk is computed once in a register and consumed three times.
 *
 * @param x          [embed_dim]            input hidden state
 * @param rms_weight [embed_dim]            RMSNorm gamma
 * @param wq         [q_dim  x embed_dim]   Q projection (row-major)
 * @param wk         [kv_dim x embed_dim]   K projection
 * @param wv         [kv_dim x embed_dim]   V projection
 * @param q_out      [q_dim]  output Q
 * @param k_out      [kv_dim] output K
 * @param v_out      [kv_dim] output V
 * @param embed_dim  hidden dimension
 * @param q_dim      Q output dimension (num_heads * head_dim)
 * @param kv_dim     KV output dimension (num_kv_heads * head_dim)
 * @param eps        RMSNorm epsilon (typically 1e-6)
 *
 * NOTE(review): assumes q_dim >= kv_dim (standard MHA/GQA); if kv_dim ever
 * exceeded q_dim, phase 1 would read past the end of wq — confirm with
 * callers.
 * NOTE(review): the opening signature line was lost in the doc export and
 * is reconstructed to match the declaration used by callers; the body is
 * otherwise unchanged.
 */
void rmsnorm_qkv_fp32_fused_v3(
    const float *x,          /* [embed_dim] input hidden state */
    const float *rms_weight, /* [embed_dim] RMSNorm gamma */
    const float *wq,         /* [q_dim, embed_dim] Q projection (row-major) */
    const float *wk,         /* [kv_dim, embed_dim] K projection */
    const float *wv,         /* [kv_dim, embed_dim] V projection */
    float *q_out,            /* [q_dim] output Q */
    float *k_out,            /* [kv_dim] output K */
    float *v_out,            /* [kv_dim] output V */
    int embed_dim,           /* Hidden dimension */
    int q_dim,               /* Q output dimension */
    int kv_dim,              /* KV output dimension */
    float eps                /* RMSNorm epsilon */
) {
    /* Step 1: RMS scale (requires a full pass over x — unavoidable). */
    float scale = compute_rms_scale(x, embed_dim, eps);

#ifdef __AVX2__
    __m256 vscale = _mm256_set1_ps(scale);

    /* Phase 1: j < kv_dim — compute Q, K, V rows together so each
     * normalized chunk is reused three times while still in a register. */
    for (int j = 0; j < kv_dim; j++) {
        __m256 q_acc = _mm256_setzero_ps();
        __m256 k_acc = _mm256_setzero_ps();
        __m256 v_acc = _mm256_setzero_ps();

        int i = 0;
        for (; i + 7 < embed_dim; i += 8) {
            /* Normalize once, use three times. */
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);

            /* Sequential weight-row access — prefetch friendly. */
            __m256 wq_row = _mm256_loadu_ps(wq + j * embed_dim + i);
            __m256 wk_row = _mm256_loadu_ps(wk + j * embed_dim + i);
            __m256 wv_row = _mm256_loadu_ps(wv + j * embed_dim + i);

            q_acc = _mm256_fmadd_ps(wq_row, normed, q_acc);
            k_acc = _mm256_fmadd_ps(wk_row, normed, k_acc);
            v_acc = _mm256_fmadd_ps(wv_row, normed, v_acc);
        }

        /* Reduce lanes first, then fold the scalar tail into the sums. */
        float q_sum = hsum256_ps(q_acc);
        float k_sum = hsum256_ps(k_acc);
        float v_sum = hsum256_ps(v_acc);

        for (; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            q_sum += wq[j * embed_dim + i] * normed;
            k_sum += wk[j * embed_dim + i] * normed;
            v_sum += wv[j * embed_dim + i] * normed;
        }

        q_out[j] = q_sum;
        k_out[j] = k_sum;
        v_out[j] = v_sum;
    }

    /* Phase 2: remaining Q rows (j >= kv_dim) — GQA where q_dim > kv_dim. */
    for (int j = kv_dim; j < q_dim; j++) {
        __m256 q_acc = _mm256_setzero_ps();

        int i = 0;
        for (; i + 7 < embed_dim; i += 8) {
            __m256 vx = _mm256_loadu_ps(x + i);
            __m256 vrms = _mm256_loadu_ps(rms_weight + i);
            __m256 normed = _mm256_mul_ps(_mm256_mul_ps(vx, vrms), vscale);

            __m256 wq_row = _mm256_loadu_ps(wq + j * embed_dim + i);
            q_acc = _mm256_fmadd_ps(wq_row, normed, q_acc);
        }

        float q_sum = hsum256_ps(q_acc);
        for (; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            q_sum += wq[j * embed_dim + i] * normed;
        }

        q_out[j] = q_sum;
    }

#else
    /* Scalar fallback — same simultaneous Q,K,V structure. */
    for (int j = 0; j < kv_dim; j++) {
        float q_sum = 0.0f, k_sum = 0.0f, v_sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            q_sum += wq[j * embed_dim + i] * normed;
            k_sum += wk[j * embed_dim + i] * normed;
            v_sum += wv[j * embed_dim + i] * normed;
        }
        q_out[j] = q_sum;
        k_out[j] = k_sum;
        v_out[j] = v_sum;
    }
    for (int j = kv_dim; j < q_dim; j++) {
        float q_sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            float normed = x[i] * rms_weight[i] * scale;
            q_sum += wq[j * embed_dim + i] * normed;
        }
        q_out[j] = q_sum;
    }
#endif
}
705 
706 /* ============================================================================
707  * NON-FUSED REFERENCE: For benchmarking comparison
708  *
709  * This is what we're comparing against. Call rmsnorm + 3x GEMV separately.
710  * ============================================================================ */
711 
/**
 * Non-fused reference: RMSNorm followed by three separate GEMVs, with the
 * intermediate normed[] round-tripping through memory.  Kept only as the
 * benchmarking/parity baseline for the fused kernels above.
 *
 * @param x          [embed_dim]            input hidden state
 * @param rms_weight [embed_dim]            RMSNorm gamma
 * @param wq         [q_dim  x embed_dim]   Q projection (row-major)
 * @param wk         [kv_dim x embed_dim]   K projection
 * @param wv         [kv_dim x embed_dim]   V projection
 * @param normed     [embed_dim]            caller-provided scratch buffer
 *                                          (written here, then re-read)
 * @param q_out      [q_dim]  output Q
 * @param k_out      [kv_dim] output K
 * @param v_out      [kv_dim] output V
 * @param embed_dim  hidden dimension
 * @param q_dim      Q output dimension
 * @param kv_dim     KV output dimension
 * @param eps        RMSNorm epsilon
 *
 * NOTE(review): the opening signature line was lost in the doc export and
 * is reconstructed to match the declaration used by callers; the body is
 * otherwise unchanged.
 */
void rmsnorm_qkv_separate_fp32(
    const float *x,
    const float *rms_weight,
    const float *wq,
    const float *wk,
    const float *wv,
    float *normed,           /* [embed_dim] intermediate buffer - DRAM write! */
    float *q_out,
    float *k_out,
    float *v_out,
    int embed_dim,
    int q_dim,
    int kv_dim,
    float eps
) {
    /* Step 1: RMSNorm — writes normed[] out to memory. */
    float scale = compute_rms_scale(x, embed_dim, eps);
    for (int i = 0; i < embed_dim; i++) {
        normed[i] = x[i] * rms_weight[i] * scale;
    }

    /* Step 2: Q projection — reads normed[] back. */
    for (int j = 0; j < q_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            sum += wq[j * embed_dim + i] * normed[i];
        }
        q_out[j] = sum;
    }

    /* Step 3: K projection — reads normed[] back. */
    for (int j = 0; j < kv_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            sum += wk[j * embed_dim + i] * normed[i];
        }
        k_out[j] = sum;
    }

    /* Step 4: V projection — reads normed[] back. */
    for (int j = 0; j < kv_dim; j++) {
        float sum = 0.0f;
        for (int i = 0; i < embed_dim; i++) {
            sum += wv[j * embed_dim + i] * normed[i];
        }
        v_out[j] = sum;
    }
}
void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV based on available SIMD.
Quantization block structures for weight-only quantization.
static float compute_rms_scale(const float *x, int n, float eps)
Definition: rmsnorm_qkv.c:45
void rmsnorm_qkv_fp32_fused_v3(const float *x, const float *rms_weight, const float *wq, const float *wk, const float *wv, float *q_out, float *k_out, float *v_out, int embed_dim, int q_dim, int kv_dim, float eps)
Definition: rmsnorm_qkv.c:591
void rmsnorm_qkv_separate_fp32(const float *x, const float *rms_weight, const float *wq, const float *wk, const float *wv, float *normed, float *q_out, float *k_out, float *v_out, int embed_dim, int q_dim, int kv_dim, float eps)
Definition: rmsnorm_qkv.c:712
void rmsnorm_qkv_fp32_fused_v2(const float *x, const float *rms_weight, const float *wq, const float *wk, const float *wv, float *q_out, float *k_out, float *v_out, int embed_dim, int q_dim, int kv_dim, float eps)
Definition: rmsnorm_qkv.c:321
void rmsnorm_qkv_fp32_fused(const float *x, const float *rms_weight, const float *wq, const float *wk, const float *wv, float *q_out, float *k_out, float *v_out, int embed_dim, int q_dim, int kv_dim, float eps)
Definition: rmsnorm_qkv.c:80
void rmsnorm_qkv_q4k_fused(const float *x, const float *rms_weight, const void *wq, const void *wk, const void *wv, float *q_out, float *k_out, float *v_out, int embed_dim, int q_dim, int kv_dim, float eps)
Definition: rmsnorm_qkv.c:213