← Back to C-Kernel-Engine Docs Doxygen Source Documentation
attention_mlp_fused.c
Go to the documentation of this file.
1 /**
2  * @file attention_mlp_fused.c
3  * @brief Mega-Fused Attention + MLP Block
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. NO memcpy for layout - use strided access, not copies
10  * 4. API must define: inputs, outputs, workspace, and memory layouts
11  * 5. Pure computation - deterministic, no side effects
12  *
13  * After changes: make test && make llamacpp-parity-full
14  *
15  * VIOLATION: Uses memcpy for layout conversion. TODO: Use strided access.
16  *
17  * Part of C-Kernel-Engine v6.6 Fusion Kernels
18  *
19  * FUSES THE ENTIRE BLOCK from Attention output to next layer input:
20  *
21  * Attention(Q, K_cache, V_cache)
22  * │
23  * ▼
24  * Output Projection (attn @ Wo)
25  * │
26  * ▼
27  * + residual_1
28  * │
29  * ▼
30  * RMSNorm
31  * │
32  * ▼
33  * MLP: gate ──► SwiGLU ◄── up
34  * │
35  * ▼
36  * down
37  * │
38  * ▼
39  * + residual_2
40  * │
41  * ▼
42  * hidden_out (ready for next layer)
43  *
44  * NON-FUSED version writes these buffers to DRAM:
45  * - attn_output [embed_dim]
46  * - projected [embed_dim]
47  * - hidden_after_attn [embed_dim]
48  * - normed [embed_dim]
49  * - gate [intermediate_dim]
50  * - up [intermediate_dim]
51  * - swiglu [intermediate_dim]
52  * - mlp_out [embed_dim]
53  * = 8 DRAM round-trips!
54  *
55  * FUSED version: ALL intermediates stay in L1/L2, ZERO DRAM writes
56  *
57  * EXPECTED SPEEDUP: 2-3x for this block
58  */
59 
60 #include <stdint.h>
61 #include <stddef.h>
62 #include <stdlib.h>
63 #include <math.h>
64 #include <string.h>
65 
66 #ifdef __AVX2__
67 #include <immintrin.h>
68 #endif
69 
70 #include "ckernel_quant.h"
71 
72 /* ============================================================================
73  * HELPER: RMSNorm computation (inline, result stays in registers)
74  * ============================================================================ */
75 
76 static inline float compute_rms_scale_internal(const float *x, int n, float eps) {
77  float sum_sq = 0.0f;
78 
79 #ifdef __AVX2__
80  __m256 vsum = _mm256_setzero_ps();
81  int i = 0;
82  for (; i + 7 < n; i += 8) {
83  __m256 vx = _mm256_loadu_ps(x + i);
84  vsum = _mm256_fmadd_ps(vx, vx, vsum);
85  }
86  __m128 vlow = _mm256_castps256_ps128(vsum);
87  __m128 vhigh = _mm256_extractf128_ps(vsum, 1);
88  vlow = _mm_add_ps(vlow, vhigh);
89  vlow = _mm_hadd_ps(vlow, vlow);
90  vlow = _mm_hadd_ps(vlow, vlow);
91  sum_sq = _mm_cvtss_f32(vlow);
92  for (; i < n; i++) {
93  sum_sq += x[i] * x[i];
94  }
95 #else
96  for (int i = 0; i < n; i++) {
97  sum_sq += x[i] * x[i];
98  }
99 #endif
100 
101  float rms = sqrtf(sum_sq / (float)n + eps);
102  return 1.0f / rms;
103 }
104 
105 /* ============================================================================
106  * HELPER: SiLU activation (x * sigmoid(x))
107  * ============================================================================ */
108 
109 static inline float silu_scalar(float x) {
110  return x / (1.0f + expf(-x));
111 }
112 
#ifdef __AVX2__
/*
 * Vectorized exp(x) for 8 floats, Cephes-style range reduction:
 *   exp(x) = 2^k * exp(r),  r = x - k*ln2,  |r| <= ln2/2
 * with a degree-6 polynomial for exp(r) and 2^k built directly in the
 * float exponent bits. Accurate across the whole clamped range, unlike
 * a raw Taylor expansion which is only valid near 0.
 */
static inline __m256 exp256_ps_internal(__m256 x) {
    const __m256 exp_hi = _mm256_set1_ps(88.3762626647949f);
    const __m256 exp_lo = _mm256_set1_ps(-88.3762626647949f);
    const __m256 log2e  = _mm256_set1_ps(1.44269504088896341f);
    const __m256 ln2_hi = _mm256_set1_ps(0.693359375f);
    const __m256 ln2_lo = _mm256_set1_ps(-2.12194440e-4f);
    const __m256 one    = _mm256_set1_ps(1.0f);

    /* Clamp so 2^k stays inside float range. */
    x = _mm256_min_ps(_mm256_max_ps(x, exp_lo), exp_hi);

    /* k = round-to-nearest(x / ln2) via floor(x*log2e + 0.5). */
    __m256 fx = _mm256_fmadd_ps(x, log2e, _mm256_set1_ps(0.5f));
    fx = _mm256_floor_ps(fx);

    /* r = x - k*ln2, with ln2 split hi/lo for extra precision. */
    x = _mm256_sub_ps(x, _mm256_mul_ps(fx, ln2_hi));
    x = _mm256_sub_ps(x, _mm256_mul_ps(fx, ln2_lo));

    /* Polynomial approximation of exp(r), r in [-ln2/2, ln2/2]. */
    __m256 y = _mm256_set1_ps(1.9875691500e-4f);
    y = _mm256_fmadd_ps(y, x, _mm256_set1_ps(1.3981999507e-3f));
    y = _mm256_fmadd_ps(y, x, _mm256_set1_ps(8.3334519073e-3f));
    y = _mm256_fmadd_ps(y, x, _mm256_set1_ps(4.1665795894e-2f));
    y = _mm256_fmadd_ps(y, x, _mm256_set1_ps(1.6666665459e-1f));
    y = _mm256_fmadd_ps(y, x, _mm256_set1_ps(5.0000001201e-1f));
    __m256 x2 = _mm256_mul_ps(x, x);
    y = _mm256_fmadd_ps(y, x2, _mm256_add_ps(x, one));

    /* Scale by 2^k: fx is integral after floor, so truncation is exact. */
    __m256i k = _mm256_cvttps_epi32(fx);
    k = _mm256_add_epi32(k, _mm256_set1_epi32(127));
    k = _mm256_slli_epi32(k, 23);
    return _mm256_mul_ps(y, _mm256_castsi256_ps(k));
}

/*
 * SiLU on 8 lanes: x * sigmoid(x) = x / (1 + exp(-x)).
 *
 * FIX: the previous version approximated exp(-x) with a degree-4 Taylor
 * series clamped to +/-88. A Taylor series around 0 is wildly wrong at
 * those magnitudes (e.g. silu(10) came out ~0.03 instead of ~10, and the
 * polynomial can go negative, producing sigmoid outside [0, 1]). We now
 * use a range-reduced exp so silu_avx2 closely matches silu_scalar.
 */
static inline __m256 silu_avx2(__m256 x) {
    const __m256 one = _mm256_set1_ps(1.0f);
    __m256 neg_x = _mm256_sub_ps(_mm256_setzero_ps(), x);
    __m256 exp_neg = exp256_ps_internal(neg_x);

    /* sigmoid = 1 / (1 + exp(-x)) */
    __m256 sigmoid = _mm256_div_ps(one, _mm256_add_ps(one, exp_neg));

    /* silu = x * sigmoid */
    return _mm256_mul_ps(x, sigmoid);
}
#endif
145 
146 /* ============================================================================
147  * HELPER: Softmax with online computation (for attention)
148  * ============================================================================ */
149 
150 static void softmax_inplace(float *x, int n) {
151  float max_val = x[0];
152  for (int i = 1; i < n; i++) {
153  if (x[i] > max_val) max_val = x[i];
154  }
155 
156  float sum = 0.0f;
157  for (int i = 0; i < n; i++) {
158  x[i] = expf(x[i] - max_val);
159  sum += x[i];
160  }
161 
162  float inv_sum = 1.0f / sum;
163  for (int i = 0; i < n; i++) {
164  x[i] *= inv_sum;
165  }
166 }
167 
168 /* ============================================================================
169  * MEGA-FUSED KERNEL: Attention + Output + RMSNorm + MLP
170  *
171  * This fuses the entire block from attention to MLP output.
172  * All intermediates stay in L1/L2 cache.
173  * ============================================================================ */
174 
/* NOTE(review): this function's signature line (return type + name) is
 * missing from this listing -- it was dropped by the doc extraction. Only
 * the parameter list and body survive below; confirm the exact symbol
 * name against the repository source. */
    /* Attention inputs */
    const float *q,           /* [num_heads * head_dim] query vector */
    const float *k_cache,     /* [seq_len, num_kv_heads * head_dim] K cache */
    const float *v_cache,     /* [seq_len, num_kv_heads * head_dim] V cache */
    int seq_len,              /* Current sequence length */
    int num_heads,
    int num_kv_heads,
    int head_dim,
    float attn_scale,         /* 1/sqrt(head_dim) */

    /* Output projection */
    const float *wo,          /* [embed_dim, num_heads * head_dim] */

    /* Residual input */
    const float *residual_1,  /* [embed_dim] input to attention block */

    /* RMSNorm */
    const float *rms_weight,  /* [embed_dim] */
    float eps,

    /* MLP weights (FP32 for this version) */
    const float *w_gate,      /* [intermediate_dim, embed_dim] */
    const float *w_up,        /* [intermediate_dim, embed_dim] */
    const float *w_down,      /* [embed_dim, intermediate_dim] */

    /* Residual 2 input (usually same as after attention residual) */
    /* If NULL, uses the hidden_after_attn */

    /* Dimensions */
    int embed_dim,
    int intermediate_dim,

    /* Output */
    float *hidden_out         /* [embed_dim] output for next layer */
) {
    const int heads_per_kv = num_heads / num_kv_heads; /* GQA group size */
    const int q_dim = num_heads * head_dim;            /* total query width */
    const int kv_dim = num_kv_heads * head_dim;        /* K/V cache row stride */

    /* Stack buffers - all stay in L1/L2.
     * NOTE(review): together with scores[8192] below, this frame is roughly
     * 210 KB of stack. Fine on a default 8 MB main-thread stack, risky on
     * small worker-thread stacks -- confirm the orchestrator's thread config. */
    float attn_out[4096];          /* Attention output per head, then combined */
    float hidden_after_attn[4096];
    float normed[4096];
    float gate_out[16384];         /* Intermediate dim (e.g., 4864 for Qwen2) */
    float up_out[16384];

    /* NOTE(review): silent failure -- hidden_out is left unwritten on this
     * path, with no error reported to the caller. */
    if (embed_dim > 4096 || intermediate_dim > 16384) {
        return; /* TODO: heap allocation for large models */
    }

    /* ═══ STEP 1: Multi-Head Attention (Q @ K^T -> softmax -> @ V) ═══ */

    memset(attn_out, 0, q_dim * sizeof(float));

    for (int h = 0; h < num_heads; h++) {
        int kv_h = h / heads_per_kv; /* GQA: map query head to KV head */

        const float *q_head = q + h * head_dim;
        float *out_head = attn_out + h * head_dim;

        /* Compute attention scores: Q @ K^T */
        float scores[8192]; /* Max seq_len */
        /* NOTE(review): this bound is invariant across heads but checked
         * inside the loop; also a silent-failure return (see above). */
        if (seq_len > 8192) return;

        for (int t = 0; t < seq_len; t++) {
            const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_head[d] * k_t[d];
            }
            scores[t] = score * attn_scale;
        }

        /* Softmax */
        softmax_inplace(scores, seq_len);

        /* Weighted sum of V: scores @ V */
        for (int t = 0; t < seq_len; t++) {
            const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
            float w = scores[t];
            for (int d = 0; d < head_dim; d++) {
                out_head[d] += w * v_t[d];
            }
        }
    }

    /* ═══ STEP 2: Output Projection (attn_out @ Wo) + Residual ═══ */

    for (int i = 0; i < embed_dim; i++) {
        float sum = 0.0f;
        const float *wo_row = wo + i * q_dim; /* row-major Wo */
        for (int j = 0; j < q_dim; j++) {
            sum += wo_row[j] * attn_out[j];
        }
        hidden_after_attn[i] = sum + residual_1[i]; /* Residual add */
    }

    /* ═══ STEP 3: RMSNorm ═══ */

    float rms_scale = compute_rms_scale_internal(hidden_after_attn, embed_dim, eps);

#ifdef __AVX2__
    __m256 vscale = _mm256_set1_ps(rms_scale);
    int i = 0;
    for (; i + 7 < embed_dim; i += 8) {
        __m256 vh = _mm256_loadu_ps(hidden_after_attn + i);
        __m256 vw = _mm256_loadu_ps(rms_weight + i);
        __m256 vn = _mm256_mul_ps(_mm256_mul_ps(vh, vw), vscale);
        _mm256_storeu_ps(normed + i, vn);
    }
    /* Scalar tail for embed_dim not divisible by 8. */
    for (; i < embed_dim; i++) {
        normed[i] = hidden_after_attn[i] * rms_weight[i] * rms_scale;
    }
#else
    for (int i = 0; i < embed_dim; i++) {
        normed[i] = hidden_after_attn[i] * rms_weight[i] * rms_scale;
    }
#endif

    /* ═══ STEP 4: MLP Gate + Up projections (can be parallelized) ═══ */

    /* Gate projection: gate_out = normed @ W_gate^T */
    for (int i = 0; i < intermediate_dim; i++) {
        float sum = 0.0f;
        const float *wg_row = w_gate + i * embed_dim;
        for (int j = 0; j < embed_dim; j++) {
            sum += wg_row[j] * normed[j];
        }
        gate_out[i] = sum;
    }

    /* Up projection: up_out = normed @ W_up^T */
    for (int i = 0; i < intermediate_dim; i++) {
        float sum = 0.0f;
        const float *wu_row = w_up + i * embed_dim;
        for (int j = 0; j < embed_dim; j++) {
            sum += wu_row[j] * normed[j];
        }
        up_out[i] = sum;
    }

    /* ═══ STEP 5: SwiGLU activation: silu(gate) * up ═══ */

#ifdef __AVX2__
    /* Reuses `i` declared in the STEP 3 AVX2 block above. */
    i = 0;
    for (; i + 7 < intermediate_dim; i += 8) {
        __m256 vg = _mm256_loadu_ps(gate_out + i);
        __m256 vu = _mm256_loadu_ps(up_out + i);
        __m256 vsilu = silu_avx2(vg);
        __m256 vswiglu = _mm256_mul_ps(vsilu, vu);
        _mm256_storeu_ps(gate_out + i, vswiglu); /* Reuse gate_out buffer */
    }
    for (; i < intermediate_dim; i++) {
        gate_out[i] = silu_scalar(gate_out[i]) * up_out[i];
    }
#else
    for (int i = 0; i < intermediate_dim; i++) {
        gate_out[i] = silu_scalar(gate_out[i]) * up_out[i];
    }
#endif

    /* ═══ STEP 6: Down projection + Final Residual ═══ */

    for (int i = 0; i < embed_dim; i++) {
        float sum = 0.0f;
        const float *wd_row = w_down + i * intermediate_dim;
        for (int j = 0; j < intermediate_dim; j++) {
            sum += wd_row[j] * gate_out[j]; /* gate_out now holds SwiGLU output */
        }
        hidden_out[i] = sum + hidden_after_attn[i]; /* Final residual */
    }
}
360 
361 /* ============================================================================
362  * V2: MLP-ONLY FUSED KERNEL with SIMD GEMV
363  *
364  * Key optimizations over v1:
365  * 1. AVX2 SIMD for ALL GEMVs (not just RMSNorm/SwiGLU)
366  * 2. Gate + Up computed TOGETHER (one pass through normed)
367  * 3. Horizontal sums done efficiently
368  *
369  * This isolates the MLP portion for benchmarking.
370  * ============================================================================ */
371 
#ifdef __AVX2__
/*
 * Dot product of one FP32 weight row against the input vector, both of
 * length K: 8-wide FMA main loop, SSE horizontal reduction of the 8
 * accumulator lanes, then a scalar tail for K not divisible by 8.
 */
static inline float gemv_fp32_row_avx2(
    const float *row, /* [K] weight row */
    const float *x,   /* [K] input vector */
    int K
) {
    __m256 vacc = _mm256_setzero_ps();

    int k = 0;
    while (k + 7 < K) {
        vacc = _mm256_fmadd_ps(_mm256_loadu_ps(row + k),
                               _mm256_loadu_ps(x + k),
                               vacc);
        k += 8;
    }

    /* Reduce 8 lanes: fold 256 -> 128, then movehdup/movehl pairwise adds. */
    __m128 lo = _mm256_castps256_ps128(vacc);
    lo = _mm_add_ps(lo, _mm256_extractf128_ps(vacc, 1));
    __m128 sh = _mm_movehdup_ps(lo);
    lo = _mm_add_ps(lo, sh);
    sh = _mm_movehl_ps(sh, lo);
    lo = _mm_add_ss(lo, sh);
    float dot = _mm_cvtss_f32(lo);

    /* Remainder elements. */
    while (k < K) {
        dot += row[k] * x[k];
        k++;
    }

    return dot;
}
#endif
406 
/* NOTE(review): this function's signature line (return type + name) is
 * missing from this listing (dropped by the doc extraction); only the
 * parameter list and body survive. Per the banner above, this is the v2
 * MLP-only fused kernel -- confirm the symbol name against the repo. */
    /* Input (after attention + residual) */
    const float *hidden_in,   /* [embed_dim] */

    /* RMSNorm */
    const float *rms_weight,  /* [embed_dim] */
    float eps,

    /* MLP weights (FP32) */
    const float *w_gate,      /* [intermediate_dim, embed_dim] */
    const float *w_up,        /* [intermediate_dim, embed_dim] */
    const float *w_down,      /* [embed_dim, intermediate_dim] */

    /* Dimensions */
    int embed_dim,
    int intermediate_dim,

    /* Output */
    float *hidden_out         /* [embed_dim] */
) {
    /* Stack buffers - sized for typical models (~80 KB total on stack) */
    float normed[4096];
    float swiglu[16384]; /* intermediate_dim */

    /* NOTE(review): silent failure -- hidden_out left unwritten, no error
     * reported to the caller. */
    if (embed_dim > 4096 || intermediate_dim > 16384) {
        return; /* TODO: handle larger models */
    }

    /* ═══ STEP 1: RMSNorm (SIMD) ═══ */

    float rms_scale = compute_rms_scale_internal(hidden_in, embed_dim, eps);

#ifdef __AVX2__
    __m256 vscale = _mm256_set1_ps(rms_scale);
    int i = 0;
    for (; i + 7 < embed_dim; i += 8) {
        __m256 vh = _mm256_loadu_ps(hidden_in + i);
        __m256 vw = _mm256_loadu_ps(rms_weight + i);
        __m256 vn = _mm256_mul_ps(_mm256_mul_ps(vh, vw), vscale);
        _mm256_storeu_ps(normed + i, vn);
    }
    /* Scalar tail for embed_dim not divisible by 8. */
    for (; i < embed_dim; i++) {
        normed[i] = hidden_in[i] * rms_weight[i] * rms_scale;
    }
#else
    for (int i = 0; i < embed_dim; i++) {
        normed[i] = hidden_in[i] * rms_weight[i] * rms_scale;
    }
#endif

    /* ═══ STEP 2: Gate + Up projections with TRUE FUSION + SwiGLU
     *
     * Key insight: Compute gate[i] and up[i] together, then immediately
     * apply SwiGLU. This eliminates separate gate_out and up_out buffers.
     * (The v3 banner below reports this interleaving hurt cache prefetch
     * in benchmarks -- kept here for comparison.) ═══ */

#ifdef __AVX2__
    for (int j = 0; j < intermediate_dim; j++) {
        /* Compute gate and up for output j using SIMD GEMV */
        const float *wg_row = w_gate + j * embed_dim;
        const float *wu_row = w_up + j * embed_dim;

        __m256 gate_acc = _mm256_setzero_ps();
        __m256 up_acc = _mm256_setzero_ps();

        int k = 0;
        for (; k + 7 < embed_dim; k += 8) {
            __m256 vn = _mm256_loadu_ps(normed + k);
            __m256 vwg = _mm256_loadu_ps(wg_row + k);
            __m256 vwu = _mm256_loadu_ps(wu_row + k);

            gate_acc = _mm256_fmadd_ps(vwg, vn, gate_acc);
            up_acc = _mm256_fmadd_ps(vwu, vn, up_acc);
        }

        /* Horizontal sums (standard 256 -> 128 -> scalar reduction) */
        __m128 glow = _mm256_castps256_ps128(gate_acc);
        __m128 ghigh = _mm256_extractf128_ps(gate_acc, 1);
        glow = _mm_add_ps(glow, ghigh);
        __m128 gshuf = _mm_movehdup_ps(glow);
        glow = _mm_add_ps(glow, gshuf);
        gshuf = _mm_movehl_ps(gshuf, glow);
        glow = _mm_add_ss(glow, gshuf);
        float gate_val = _mm_cvtss_f32(glow);

        __m128 ulow = _mm256_castps256_ps128(up_acc);
        __m128 uhigh = _mm256_extractf128_ps(up_acc, 1);
        ulow = _mm_add_ps(ulow, uhigh);
        __m128 ushuf = _mm_movehdup_ps(ulow);
        ulow = _mm_add_ps(ulow, ushuf);
        ushuf = _mm_movehl_ps(ushuf, ulow);
        ulow = _mm_add_ss(ulow, ushuf);
        float up_val = _mm_cvtss_f32(ulow);

        /* Remainder (embed_dim not divisible by 8) */
        for (; k < embed_dim; k++) {
            gate_val += wg_row[k] * normed[k];
            up_val += wu_row[k] * normed[k];
        }

        /* Fused SwiGLU: silu(gate) * up */
        swiglu[j] = silu_scalar(gate_val) * up_val;
    }
#else
    for (int j = 0; j < intermediate_dim; j++) {
        const float *wg_row = w_gate + j * embed_dim;
        const float *wu_row = w_up + j * embed_dim;
        float gate_val = 0.0f, up_val = 0.0f;

        for (int k = 0; k < embed_dim; k++) {
            gate_val += wg_row[k] * normed[k];
            up_val += wu_row[k] * normed[k];
        }

        swiglu[j] = silu_scalar(gate_val) * up_val;
    }
#endif

    /* ═══ STEP 3: Down projection + Residual (SIMD GEMV) ═══ */

#ifdef __AVX2__
    for (int j = 0; j < embed_dim; j++) {
        float sum = gemv_fp32_row_avx2(w_down + j * intermediate_dim, swiglu, intermediate_dim);
        hidden_out[j] = sum + hidden_in[j]; /* Residual */
    }
#else
    for (int j = 0; j < embed_dim; j++) {
        float sum = 0.0f;
        const float *wd_row = w_down + j * intermediate_dim;
        for (int k = 0; k < intermediate_dim; k++) {
            sum += wd_row[k] * swiglu[k];
        }
        hidden_out[j] = sum + hidden_in[j];
    }
#endif
}
548 
549 
550 /* ============================================================================
551  * V3: MLP with SIMD GEMV but SEQUENTIAL weight access
552  *
553  * Key insight from v2 benchmark: fusing gate+up HURTS performance because
554  * interleaved weight loading destroys cache prefetch patterns.
555  *
556  * v3 approach:
557  * 1. Use SIMD GEMV for all projections
558  * 2. Keep SEQUENTIAL weight access (gate first, then up)
559  * 3. Still fuse SwiGLU immediately after projections
560  *
561  * This should be faster than v2 AND faster than scalar separate.
562  * ============================================================================ */
563 
/* NOTE(review): this function's signature line (return type + name) is
 * missing from this listing (dropped by the doc extraction). Per the v3
 * banner above: SIMD GEMV with SEQUENTIAL weight access (gate pass, then
 * up pass), avoiding v2's interleaved loads. Confirm name against repo. */
    const float *hidden_in,   /* [embed_dim] post-attention hidden state */
    const float *rms_weight,  /* [embed_dim] RMSNorm gain */
    float eps,                /* RMSNorm epsilon */
    const float *w_gate,      /* [intermediate_dim, embed_dim] */
    const float *w_up,        /* [intermediate_dim, embed_dim] */
    const float *w_down,      /* [embed_dim, intermediate_dim] */
    int embed_dim,
    int intermediate_dim,
    float *hidden_out         /* [embed_dim] */
) {
    /* Stack buffers (~145 KB; unlike v2 this keeps a full gate_out) */
    float normed[4096];
    float gate_out[16384];
    float swiglu[16384];

    /* NOTE(review): silent failure -- hidden_out left unwritten. */
    if (embed_dim > 4096 || intermediate_dim > 16384) {
        return;
    }

    /* ═══ STEP 1: RMSNorm (SIMD) ═══ */

    float rms_scale = compute_rms_scale_internal(hidden_in, embed_dim, eps);

#ifdef __AVX2__
    __m256 vscale = _mm256_set1_ps(rms_scale);
    int i = 0;
    for (; i + 7 < embed_dim; i += 8) {
        __m256 vh = _mm256_loadu_ps(hidden_in + i);
        __m256 vw = _mm256_loadu_ps(rms_weight + i);
        __m256 vn = _mm256_mul_ps(_mm256_mul_ps(vh, vw), vscale);
        _mm256_storeu_ps(normed + i, vn);
    }
    /* Scalar tail for embed_dim not divisible by 8. */
    for (; i < embed_dim; i++) {
        normed[i] = hidden_in[i] * rms_weight[i] * rms_scale;
    }
#else
    for (int i = 0; i < embed_dim; i++) {
        normed[i] = hidden_in[i] * rms_weight[i] * rms_scale;
    }
#endif

    /* ═══ STEP 2: Gate projection (SIMD GEMV, sequential weight access) ═══ */

#ifdef __AVX2__
    for (int j = 0; j < intermediate_dim; j++) {
        gate_out[j] = gemv_fp32_row_avx2(w_gate + j * embed_dim, normed, embed_dim);
    }
#else
    for (int j = 0; j < intermediate_dim; j++) {
        float sum = 0.0f;
        const float *wg_row = w_gate + j * embed_dim;
        for (int k = 0; k < embed_dim; k++) {
            sum += wg_row[k] * normed[k];
        }
        gate_out[j] = sum;
    }
#endif

    /* ═══ STEP 3: Up projection + FUSED SwiGLU (SIMD GEMV, sequential access)
     *
     * Key: compute up[j], then immediately apply SwiGLU with gate[j].
     * This avoids storing the full up_out buffer. ═══ */

#ifdef __AVX2__
    for (int j = 0; j < intermediate_dim; j++) {
        float up_val = gemv_fp32_row_avx2(w_up + j * embed_dim, normed, embed_dim);
        /* Fused SwiGLU: silu(gate) * up */
        swiglu[j] = silu_scalar(gate_out[j]) * up_val;
    }
#else
    for (int j = 0; j < intermediate_dim; j++) {
        float up_val = 0.0f;
        const float *wu_row = w_up + j * embed_dim;
        for (int k = 0; k < embed_dim; k++) {
            up_val += wu_row[k] * normed[k];
        }
        swiglu[j] = silu_scalar(gate_out[j]) * up_val;
    }
#endif

    /* ═══ STEP 4: Down projection + Residual (SIMD GEMV) ═══ */

#ifdef __AVX2__
    for (int j = 0; j < embed_dim; j++) {
        float sum = gemv_fp32_row_avx2(w_down + j * intermediate_dim, swiglu, intermediate_dim);
        hidden_out[j] = sum + hidden_in[j];
    }
#else
    for (int j = 0; j < embed_dim; j++) {
        float sum = 0.0f;
        const float *wd_row = w_down + j * intermediate_dim;
        for (int k = 0; k < intermediate_dim; k++) {
            sum += wd_row[k] * swiglu[k];
        }
        hidden_out[j] = sum + hidden_in[j];
    }
#endif
}
671 
672 
673 /* ============================================================================
674  * SEPARATE MLP (for benchmarking comparison)
675  *
676  * Same operations but as separate function calls.
677  * ============================================================================ */
678 
/* NOTE(review): this function's signature line (return type + name) is
 * missing from this listing (dropped by the doc extraction). Per the
 * banner above, this is the unfused scalar MLP baseline used for
 * benchmarking; caller provides all scratch buffers (rule 1: no malloc). */
    const float *hidden_in,   /* [embed_dim] */
    const float *rms_weight,  /* [embed_dim] */
    float eps,                /* RMSNorm epsilon */
    const float *w_gate,      /* [intermediate_dim, embed_dim] */
    const float *w_up,        /* [intermediate_dim, embed_dim] */
    const float *w_down,      /* [embed_dim, intermediate_dim] */
    float *normed_buf,        /* [embed_dim] caller-provided */
    float *gate_buf,          /* [intermediate_dim] caller-provided */
    float *up_buf,            /* [intermediate_dim] caller-provided */
    int embed_dim,
    int intermediate_dim,
    float *hidden_out         /* [embed_dim] */
) {
    /* Step 1: RMSNorm */
    float rms_scale = compute_rms_scale_internal(hidden_in, embed_dim, eps);
    for (int i = 0; i < embed_dim; i++) {
        normed_buf[i] = hidden_in[i] * rms_weight[i] * rms_scale;
    }

    /* Step 2: Gate projection (gate_buf = W_gate @ normed) */
    for (int j = 0; j < intermediate_dim; j++) {
        float sum = 0.0f;
        const float *wg_row = w_gate + j * embed_dim;
        for (int k = 0; k < embed_dim; k++) {
            sum += wg_row[k] * normed_buf[k];
        }
        gate_buf[j] = sum;
    }

    /* Step 3: Up projection (up_buf = W_up @ normed) */
    for (int j = 0; j < intermediate_dim; j++) {
        float sum = 0.0f;
        const float *wu_row = w_up + j * embed_dim;
        for (int k = 0; k < embed_dim; k++) {
            sum += wu_row[k] * normed_buf[k];
        }
        up_buf[j] = sum;
    }

    /* Step 4: SwiGLU (result written back into gate_buf) */
    for (int j = 0; j < intermediate_dim; j++) {
        gate_buf[j] = silu_scalar(gate_buf[j]) * up_buf[j];
    }

    /* Step 5: Down projection + Residual */
    for (int j = 0; j < embed_dim; j++) {
        float sum = 0.0f;
        const float *wd_row = w_down + j * intermediate_dim;
        for (int k = 0; k < intermediate_dim; k++) {
            sum += wd_row[k] * gate_buf[k]; /* gate_buf now holds SwiGLU */
        }
        hidden_out[j] = sum + hidden_in[j];
    }
}
734 
735 
736 /* ============================================================================
737  * Q4_K VERSION: Attention + Output + RMSNorm + MLP with quantized weights
738  *
739  * All MLP weights are Q4_K quantized.
740  * ============================================================================ */
741 
/* NOTE(review): this function's signature line (return type + name) is
 * missing from this listing (dropped by the doc extraction). Per the
 * banner above, this is the Q4_K-quantized variant of the fused
 * attention+MLP block. Confirm the symbol name against the repo. */
    /* Attention inputs */
    const float *q,           /* [num_heads * head_dim] */
    const float *k_cache,     /* [seq_len, num_kv_heads * head_dim] */
    const float *v_cache,     /* [seq_len, num_kv_heads * head_dim] */
    int seq_len,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    float attn_scale,         /* 1/sqrt(head_dim) */

    /* Output projection (Q4_K) */
    const void *wo,

    /* Residual */
    const float *residual_1,  /* [embed_dim] */

    /* RMSNorm */
    const float *rms_weight,  /* [embed_dim] */
    float eps,

    /* MLP weights (Q4_K) */
    const void *w_gate,
    const void *w_up,
    const void *w_down,

    /* Dimensions */
    int embed_dim,
    int intermediate_dim,

    /* Output */
    float *hidden_out         /* [embed_dim] */
) {
    const int heads_per_kv = num_heads / num_kv_heads; /* GQA group size */
    const int q_dim = num_heads * head_dim;
    const int kv_dim = num_kv_heads * head_dim;

    /* Stack buffers (plus scores[8192] and gate/up below: ~210 KB frame) */
    float attn_out[4096];
    float hidden_after_attn[4096];
    float normed[4096];
    float mlp_out[4096];

    /* NOTE(review): silent failure -- hidden_out left unwritten. */
    if (embed_dim > 4096) return;

    /* ═══ STEP 1: Multi-Head Attention (same as FP32 version) ═══ */

    memset(attn_out, 0, q_dim * sizeof(float));

    for (int h = 0; h < num_heads; h++) {
        int kv_h = h / heads_per_kv; /* GQA: map query head to KV head */
        const float *q_head = q + h * head_dim;
        float *out_head = attn_out + h * head_dim;

        float scores[8192];
        /* NOTE(review): silent failure, and invariant across heads. */
        if (seq_len > 8192) return;

        /* Scores: Q @ K^T, scaled */
        for (int t = 0; t < seq_len; t++) {
            const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_head[d] * k_t[d];
            }
            scores[t] = score * attn_scale;
        }

        softmax_inplace(scores, seq_len);

        /* Weighted sum of V */
        for (int t = 0; t < seq_len; t++) {
            const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
            float w = scores[t];
            for (int d = 0; d < head_dim; d++) {
                out_head[d] += w * v_t[d];
            }
        }
    }

    /* ═══ STEP 2: Output Projection (Q4_K) + Residual ═══ */

    /* NOTE(review): block-scope extern prototype; this should come from a
     * header (e.g. ckernel_quant.h) so the signature is checked globally. */
    extern void gemv_q4_k(float *y, const void *W, const float *x, int M, int K);

    gemv_q4_k(hidden_after_attn, wo, attn_out, embed_dim, q_dim);

    /* Add residual */
    for (int i = 0; i < embed_dim; i++) {
        hidden_after_attn[i] += residual_1[i];
    }

    /* ═══ STEP 3: RMSNorm (scalar only here, unlike the FP32 variant) ═══ */

    float rms_scale = compute_rms_scale_internal(hidden_after_attn, embed_dim, eps);

    for (int i = 0; i < embed_dim; i++) {
        normed[i] = hidden_after_attn[i] * rms_weight[i] * rms_scale;
    }

    /* ═══ STEP 4-6: MLP with Q4_K weights (inline implementation)
     *
     * gate_out = normed @ W_gate
     * up_out   = normed @ W_up
     * swiglu   = silu(gate_out) * up_out
     * mlp_out  = swiglu @ W_down ═══ */

    float gate_out[16384];
    float up_out[16384];

    /* NOTE(review): this bound is checked only after attention and the
     * output projection already ran -- still a silent failure. */
    if (intermediate_dim > 16384) return;

    /* Gate projection */
    gemv_q4_k(gate_out, w_gate, normed, intermediate_dim, embed_dim);

    /* Up projection */
    gemv_q4_k(up_out, w_up, normed, intermediate_dim, embed_dim);

    /* SwiGLU: silu(gate) * up */
#ifdef __AVX2__
    int i = 0;
    for (; i + 7 < intermediate_dim; i += 8) {
        __m256 vg = _mm256_loadu_ps(gate_out + i);
        __m256 vu = _mm256_loadu_ps(up_out + i);
        __m256 vsilu = silu_avx2(vg);
        __m256 vswiglu = _mm256_mul_ps(vsilu, vu);
        _mm256_storeu_ps(gate_out + i, vswiglu); /* Reuse gate_out buffer */
    }
    for (; i < intermediate_dim; i++) {
        gate_out[i] = silu_scalar(gate_out[i]) * up_out[i];
    }
#else
    for (int i = 0; i < intermediate_dim; i++) {
        gate_out[i] = silu_scalar(gate_out[i]) * up_out[i];
    }
#endif

    /* Down projection */
    gemv_q4_k(mlp_out, w_down, gate_out, embed_dim, intermediate_dim);

    /* Final residual add */
    for (int i = 0; i < embed_dim; i++) {
        hidden_out[i] = mlp_out[i] + hidden_after_attn[i];
    }
}
891 
892 /* ============================================================================
893  * COMPLETE LAYER FUSION: Attention → MLP → Next Layer's QKV
894  *
895  * This is the TRUE mega-fusion: from one layer's attention output all the
896  * way to the next layer's Q (ready for attention) + K,V written to cache.
897  *
898  * The hidden state NEVER touches DRAM between layers!
899  * ============================================================================ */
900 
902  /* === CURRENT LAYER ATTENTION INPUTS === */
903  const float *q, /* [num_heads * head_dim] query for this layer */
904  const float *k_cache, /* [seq_len, num_kv_heads * head_dim] */
905  const float *v_cache, /* [seq_len, num_kv_heads * head_dim] */
906  int seq_len,
907  float attn_scale,
908 
909  /* === CURRENT LAYER WEIGHTS (Q4_K) === */
    const void *wo,              /* Output projection weights (Q4_K blob) */
    const float *rms_weight_mlp, /* RMSNorm gain vector applied before the MLP */
    const void *w_gate,          /* MLP gate projection (Q4_K) */
    const void *w_up,            /* MLP up projection (Q4_K) */
    const void *w_down,          /* MLP down projection (Q4_K) */

    /* === NEXT LAYER WEIGHTS (Q4_K) === */
    const float *rms_weight_attn, /* RMSNorm gain before next layer's attention */
    const void *wq_next,          /* Next layer Q projection */
    const void *wk_next,          /* Next layer K projection */
    const void *wv_next,          /* Next layer V projection */

    /* === RESIDUAL INPUT === */
    const float *residual_in, /* [embed_dim] hidden state entering this layer */

    /* === DIMENSIONS === */
    int embed_dim,        /* model width; must be <= 4096 (fixed stack buffers) */
    int intermediate_dim, /* MLP width; must be <= 16384 (fixed stack buffers) */
    int num_heads,        /* number of query heads */
    int num_kv_heads,     /* number of key/value heads (GQA: divides num_heads) */
    int head_dim,         /* per-head dimension */
    float eps,            /* RMSNorm epsilon */

    /* === OUTPUTS === */
    float *q_next,     /* [num_heads * head_dim] Q for next layer's attention */
    float *k_next,     /* [num_kv_heads * head_dim] K to write to the KV cache */
    float *v_next,     /* [num_kv_heads * head_dim] V to write to the KV cache */
    float *hidden_out  /* [embed_dim] final hidden state of this layer */
) {
    /* Q4_K quantized GEMV: y[M] = W[M x K] @ x[K]; defined in the GEMV kernel TU. */
    extern void gemv_q4_k(float *y, const void *W, const float *x, int M, int K);

    const int heads_per_kv = num_heads / num_kv_heads; /* GQA group size */
    const int q_dim = num_heads * head_dim;            /* attention output width */
    const int kv_dim = num_kv_heads * head_dim;        /* stride of one KV-cache row */

    /* All intermediate buffers on stack - stay in L1/L2.
     * hidden_out is the final output buffer - we write to it directly!
     * NOTE(review): ~190 KB of stack buffers here plus the 32 KB scores buffer
     * below - confirm the orchestrator's thread stack size can absorb this. */
    float attn_out[4096];
    float hidden_after_attn[4096];
    float normed_mlp[4096];
    float gate_out[16384];
    float up_out[16384];
    /* NOTE: No hidden_after_mlp buffer - we output directly to hidden_out */
    float normed_attn[4096];

    /* Silent refusal when dims exceed the fixed buffers: all outputs are left
     * untouched. Callers must size-check (kernel rules forbid malloc here). */
    if (embed_dim > 4096 || intermediate_dim > 16384) return;

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 1: Multi-Head Attention (one query token vs. the full KV cache)
     * ═══════════════════════════════════════════════════════════════════════ */

    memset(attn_out, 0, q_dim * sizeof(float));

    for (int h = 0; h < num_heads; h++) {
        int kv_h = h / heads_per_kv; /* GQA: map query head to its shared KV head */
        const float *q_head = q + h * head_dim;
        float *out_head = attn_out + h * head_dim;

        /* Per-head attention weights; fixed stack buffer bounds seq_len. */
        float scores[8192];
        /* NOTE(review): this bail-out fires after attn_out was zeroed, so an
         * unsupported seq_len leaves outputs partially written. */
        if (seq_len > 8192) return;

        /* QK^T row for this head: dot(q_head, K_t), scaled by attn_scale. */
        for (int t = 0; t < seq_len; t++) {
            const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_head[d] * k_t[d];
            }
            scores[t] = score * attn_scale;
        }

        /* Softmax, max-subtracted for numerical stability. */
        float max_score = scores[0];
        for (int t = 1; t < seq_len; t++) {
            if (scores[t] > max_score) max_score = scores[t];
        }
        float sum_exp = 0.0f;
        for (int t = 0; t < seq_len; t++) {
            scores[t] = expf(scores[t] - max_score);
            sum_exp += scores[t];
        }
        float inv_sum = 1.0f / sum_exp;
        for (int t = 0; t < seq_len; t++) {
            scores[t] *= inv_sum;
        }

        /* Weighted sum of V rows, accumulated into this head's output slice. */
        for (int t = 0; t < seq_len; t++) {
            const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
            float w = scores[t];
            for (int d = 0; d < head_dim; d++) {
                out_head[d] += w * v_t[d];
            }
        }
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 2: Output Projection (Q4_K) + first residual add
     * ═══════════════════════════════════════════════════════════════════════ */

    gemv_q4_k(hidden_after_attn, wo, attn_out, embed_dim, q_dim);

    for (int i = 0; i < embed_dim; i++) {
        hidden_after_attn[i] += residual_in[i];
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 3: RMSNorm feeding the MLP
     * ═══════════════════════════════════════════════════════════════════════ */

    float sum_sq = 0.0f;
    for (int i = 0; i < embed_dim; i++) {
        sum_sq += hidden_after_attn[i] * hidden_after_attn[i];
    }
    float rms_scale = 1.0f / sqrtf(sum_sq / embed_dim + eps);

    for (int i = 0; i < embed_dim; i++) {
        normed_mlp[i] = hidden_after_attn[i] * rms_weight_mlp[i] * rms_scale;
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 4-6: MLP (gate + up projections, SwiGLU, down projection)
     * ═══════════════════════════════════════════════════════════════════════ */

    gemv_q4_k(gate_out, w_gate, normed_mlp, intermediate_dim, embed_dim);
    gemv_q4_k(up_out, w_up, normed_mlp, intermediate_dim, embed_dim);

    /* SwiGLU: silu(gate) * up, fused in place into gate_out. */
    for (int i = 0; i < intermediate_dim; i++) {
        float g = gate_out[i];
        float silu_g = g / (1.0f + expf(-g)); /* silu(g) = g * sigmoid(g) */
        gate_out[i] = silu_g * up_out[i];
    }

    /* Down projection - output DIRECTLY to hidden_out (no intermediate buffer!) */
    gemv_q4_k(hidden_out, w_down, gate_out, embed_dim, intermediate_dim);

    /* Second residual - hidden_out now contains this layer's final hidden state. */
    for (int i = 0; i < embed_dim; i++) {
        hidden_out[i] += hidden_after_attn[i];
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 7: RMSNorm for the NEXT layer's attention
     * Reads from hidden_out (already contains the final hidden state).
     * ═══════════════════════════════════════════════════════════════════════ */

    sum_sq = 0.0f;
    for (int i = 0; i < embed_dim; i++) {
        sum_sq += hidden_out[i] * hidden_out[i];
    }
    rms_scale = 1.0f / sqrtf(sum_sq / embed_dim + eps);

    for (int i = 0; i < embed_dim; i++) {
        normed_attn[i] = hidden_out[i] * rms_weight_attn[i] * rms_scale;
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 8: NEXT LAYER's Q, K, V Projections
     *
     * Q goes to the caller (for attention computation); K and V go to the
     * KV cache (a DRAM write - this one is intentional!).
     * ═══════════════════════════════════════════════════════════════════════ */

    gemv_q4_k(q_next, wq_next, normed_attn, q_dim, embed_dim);
    gemv_q4_k(k_next, wk_next, normed_attn, kv_dim, embed_dim);
    gemv_q4_k(v_next, wv_next, normed_attn, kv_dim, embed_dim);

    /* hidden_out already contains the final hidden state - no memcpy needed! */
}
1079 
1080 /* ============================================================================
1081  * NON-FUSED REFERENCE: For benchmarking comparison
1082  * ============================================================================ */
1083 
1085  const float *q, const float *k_cache, const float *v_cache,
1086  int seq_len, int num_heads, int num_kv_heads, int head_dim,
1087  float attn_scale,
1088  const float *wo, const float *residual_1,
1089  const float *rms_weight, float eps,
1090  const float *w_gate, const float *w_up, const float *w_down,
1091  int embed_dim, int intermediate_dim,
1092  /* Intermediate buffers - DRAM traffic! */
1093  float *attn_out_buf,
1094  float *hidden_after_attn_buf,
1095  float *normed_buf,
1096  float *gate_buf,
1097  float *up_buf,
1098  float *mlp_out_buf,
1099  /* Output */
1100  float *hidden_out
1101 ) {
1102  /* This version writes all intermediates to the provided buffers,
1103  * simulating non-fused execution with DRAM traffic */
1104 
1105  const int heads_per_kv = num_heads / num_kv_heads;
1106  const int q_dim = num_heads * head_dim;
1107  const int kv_dim = num_kv_heads * head_dim;
1108 
1109  /* Step 1: Attention */
1110  memset(attn_out_buf, 0, q_dim * sizeof(float));
1111 
1112  /* Stack-allocated scores buffer (no malloc!) */
1113  float scores[8192]; /* Max seq_len supported */
1114  if (seq_len > 8192) return;
1115 
1116  for (int h = 0; h < num_heads; h++) {
1117  int kv_h = h / heads_per_kv;
1118  const float *q_head = q + h * head_dim;
1119  float *out_head = attn_out_buf + h * head_dim;
1120 
1121  for (int t = 0; t < seq_len; t++) {
1122  const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
1123  float score = 0.0f;
1124  for (int d = 0; d < head_dim; d++) {
1125  score += q_head[d] * k_t[d];
1126  }
1127  scores[t] = score * attn_scale;
1128  }
1129 
1130  softmax_inplace(scores, seq_len);
1131 
1132  for (int t = 0; t < seq_len; t++) {
1133  const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
1134  float w = scores[t];
1135  for (int d = 0; d < head_dim; d++) {
1136  out_head[d] += w * v_t[d];
1137  }
1138  }
1139  }
1140 
1141  /* Step 2: Output projection + residual -> DRAM write */
1142  for (int i = 0; i < embed_dim; i++) {
1143  float sum = 0.0f;
1144  const float *wo_row = wo + i * q_dim;
1145  for (int j = 0; j < q_dim; j++) {
1146  sum += wo_row[j] * attn_out_buf[j];
1147  }
1148  hidden_after_attn_buf[i] = sum + residual_1[i];
1149  }
1150 
1151  /* Step 3: RMSNorm -> DRAM write */
1152  float rms_scale = compute_rms_scale_internal(hidden_after_attn_buf, embed_dim, eps);
1153  for (int i = 0; i < embed_dim; i++) {
1154  normed_buf[i] = hidden_after_attn_buf[i] * rms_weight[i] * rms_scale;
1155  }
1156 
1157  /* Step 4: Gate projection -> DRAM write */
1158  for (int i = 0; i < intermediate_dim; i++) {
1159  float sum = 0.0f;
1160  const float *wg_row = w_gate + i * embed_dim;
1161  for (int j = 0; j < embed_dim; j++) {
1162  sum += wg_row[j] * normed_buf[j];
1163  }
1164  gate_buf[i] = sum;
1165  }
1166 
1167  /* Step 5: Up projection -> DRAM write */
1168  for (int i = 0; i < intermediate_dim; i++) {
1169  float sum = 0.0f;
1170  const float *wu_row = w_up + i * embed_dim;
1171  for (int j = 0; j < embed_dim; j++) {
1172  sum += wu_row[j] * normed_buf[j];
1173  }
1174  up_buf[i] = sum;
1175  }
1176 
1177  /* Step 6: SwiGLU (in-place in gate_buf) */
1178  for (int i = 0; i < intermediate_dim; i++) {
1179  gate_buf[i] = silu_scalar(gate_buf[i]) * up_buf[i];
1180  }
1181 
1182  /* Step 7: Down projection -> DRAM write */
1183  for (int i = 0; i < embed_dim; i++) {
1184  float sum = 0.0f;
1185  const float *wd_row = w_down + i * intermediate_dim;
1186  for (int j = 0; j < intermediate_dim; j++) {
1187  sum += wd_row[j] * gate_buf[j];
1188  }
1189  mlp_out_buf[i] = sum;
1190  }
1191 
1192  /* Step 8: Final residual */
1193  for (int i = 0; i < embed_dim; i++) {
1194  hidden_out[i] = mlp_out_buf[i] + hidden_after_attn_buf[i];
1195  }
1196 }
void attention_mlp_fused_q4k(const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const void *wo, const float *residual_1, const float *rms_weight, float eps, const void *w_gate, const void *w_up, const void *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
void mlp_fused_fp32_v2(const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
static float silu_scalar(float x)
void mlp_separate_fp32(const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, float *normed_buf, float *gate_buf, float *up_buf, int embed_dim, int intermediate_dim, float *hidden_out)
static float compute_rms_scale_internal(const float *x, int n, float eps)
void mlp_fused_fp32_v3(const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
static void softmax_inplace(float *x, int n)
void layer_fused_attn_mlp_qkv_q4k(const float *q, const float *k_cache, const float *v_cache, int seq_len, float attn_scale, const void *wo, const float *rms_weight_mlp, const void *w_gate, const void *w_up, const void *w_down, const float *rms_weight_attn, const void *wq_next, const void *wk_next, const void *wv_next, const float *residual_in, int embed_dim, int intermediate_dim, int num_heads, int num_kv_heads, int head_dim, float eps, float *q_next, float *k_next, float *v_next, float *hidden_out)
void attention_mlp_separate_fp32(const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float *wo, const float *residual_1, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *attn_out_buf, float *hidden_after_attn_buf, float *normed_buf, float *gate_buf, float *up_buf, float *mlp_out_buf, float *hidden_out)
void attention_mlp_fused_fp32(const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float *wo, const float *residual_1, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV based on available SIMD.
Quantization block structures for weight-only quantization.
int32_t float * score
Definition: tokenizer.h:327