← Back to C-Kernel-Engine Docs Doxygen Source Documentation
attention_mlp_fused.c File Reference

Mega-Fused Attention + MLP Block. More...

#include <stdint.h>
#include <stddef.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include "ckernel_quant.h"

Go to the source code of this file.

Functions

void attention_mlp_fused_fp32 (const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float *wo, const float *residual_1, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
 
void attention_mlp_fused_q4k (const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const void *wo, const float *residual_1, const float *rms_weight, float eps, const void *w_gate, const void *w_up, const void *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
 
void attention_mlp_separate_fp32 (const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float *wo, const float *residual_1, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *attn_out_buf, float *hidden_after_attn_buf, float *normed_buf, float *gate_buf, float *up_buf, float *mlp_out_buf, float *hidden_out)
 
static float compute_rms_scale_internal (const float *x, int n, float eps)
 
void layer_fused_attn_mlp_qkv_q4k (const float *q, const float *k_cache, const float *v_cache, int seq_len, float attn_scale, const void *wo, const float *rms_weight_mlp, const void *w_gate, const void *w_up, const void *w_down, const float *rms_weight_attn, const void *wq_next, const void *wk_next, const void *wv_next, const float *residual_in, int embed_dim, int intermediate_dim, int num_heads, int num_kv_heads, int head_dim, float eps, float *q_next, float *k_next, float *v_next, float *hidden_out)
 
void mlp_fused_fp32_v2 (const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
 
void mlp_fused_fp32_v3 (const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out)
 
void mlp_separate_fp32 (const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, float *normed_buf, float *gate_buf, float *up_buf, int embed_dim, int intermediate_dim, float *hidden_out)
 
static float silu_scalar (float x)
 
static void softmax_inplace (float *x, int n)
 

Detailed Description

Mega-Fused Attention + MLP Block.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. NO memcpy for layout - use strided access, not copies
  4. API must define: inputs, outputs, workspace, and memory layouts
  5. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

VIOLATION: Uses memcpy for layout conversion. TODO: Use strided access.

Part of C-Kernel-Engine v6.6 Fusion Kernels

FUSES THE ENTIRE BLOCK from Attention output to next layer input:

Attention(Q, K_cache, V_cache)
  -> Output Projection (attn @ Wo)
  -> + residual_1
  -> RMSNorm
  -> MLP: gate --> SwiGLU <-- up, then down
  -> + residual_2
  -> hidden_out (ready for next layer)

NON-FUSED version writes these buffers to DRAM:

  • attn_output [embed_dim]
  • projected [embed_dim]
  • hidden_after_attn [embed_dim]
  • normed [embed_dim]
  • gate [intermediate_dim]
  • up [intermediate_dim]
  • swiglu [intermediate_dim]
  • mlp_out [embed_dim]

That is 8 DRAM round-trips per token!

FUSED version: ALL intermediates stay in L1/L2, ZERO DRAM writes

EXPECTED SPEEDUP: 2-3x for this block

Definition in file attention_mlp_fused.c.

Function Documentation

◆ attention_mlp_fused_fp32()

/**
 * attention_mlp_fused_fp32 - fully fused FP32 attention + MLP block for a
 * single decode token.
 *
 * Pipeline (all intermediates in stack buffers, L1/L2 resident; only
 * hidden_out is written to caller memory):
 *   1. Multi-head attention over the K/V cache (GQA-aware)
 *   2. Output projection (attn_out @ Wo) + residual_1
 *   3. RMSNorm
 *   4. Gate / Up projections
 *   5. SwiGLU: silu(gate) * up
 *   6. Down projection + final residual -> hidden_out
 *
 * Row-major layouts (as used by the indexing below):
 *   q        [num_heads * head_dim]
 *   k_cache  [seq_len][num_kv_heads * head_dim]
 *   v_cache  [seq_len][num_kv_heads * head_dim]
 *   wo       [embed_dim][num_heads * head_dim]
 *   w_gate   [intermediate_dim][embed_dim]
 *   w_up     [intermediate_dim][embed_dim]
 *   w_down   [embed_dim][intermediate_dim]
 *
 * Capacity limits: embed_dim <= 4096, num_heads*head_dim <= 4096,
 * intermediate_dim <= 16384, seq_len <= 8192. All limits are validated up
 * front, BEFORE any computation. (Previously seq_len was checked inside the
 * head loop, and q_dim was not checked at all, so attn_out[4096] could
 * overflow when num_heads*head_dim > 4096.)
 * On violation the function returns without touching hidden_out.
 * TODO: heap allocation / error reporting for large models.
 */
void attention_mlp_fused_fp32(const float *q, const float *k_cache,
                              const float *v_cache, int seq_len,
                              int num_heads, int num_kv_heads, int head_dim,
                              float attn_scale, const float *wo,
                              const float *residual_1,
                              const float *rms_weight, float eps,
                              const float *w_gate, const float *w_up,
                              const float *w_down, int embed_dim,
                              int intermediate_dim, float *hidden_out)
{
    const int heads_per_kv = num_heads / num_kv_heads;
    const int q_dim = num_heads * head_dim;
    const int kv_dim = num_kv_heads * head_dim;

    /* Stack buffers - all stay in L1/L2 */
    float attn_out[4096];          /* per-head attention output, concatenated */
    float hidden_after_attn[4096]; /* Wo projection + residual_1 */
    float normed[4096];            /* RMSNorm output */
    float gate_out[16384];         /* gate projection, then reused for SwiGLU */
    float up_out[16384];
    float scores[8192];            /* one head's attention scores (max seq_len) */

    /* Validate EVERY capacity before doing any work, so we never return
     * with partially computed state. */
    if (embed_dim > 4096 || q_dim > 4096 ||
        intermediate_dim > 16384 || seq_len > 8192) {
        return; /* TODO: heap allocation for large models */
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 1: Multi-Head Attention (Q @ K^T -> softmax -> @ V)
     * ═══════════════════════════════════════════════════════════════════════ */

    memset(attn_out, 0, q_dim * sizeof(float));

    for (int h = 0; h < num_heads; h++) {
        int kv_h = h / heads_per_kv; /* GQA: map query head to KV head */

        const float *q_head = q + h * head_dim;
        float *out_head = attn_out + h * head_dim;

        /* Attention scores: Q @ K^T, scaled */
        for (int t = 0; t < seq_len; t++) {
            const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_head[d] * k_t[d];
            }
            scores[t] = score * attn_scale;
        }

        softmax_inplace(scores, seq_len);

        /* Weighted sum of V: scores @ V */
        for (int t = 0; t < seq_len; t++) {
            const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
            float w = scores[t];
            for (int d = 0; d < head_dim; d++) {
                out_head[d] += w * v_t[d];
            }
        }
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 2: Output Projection (attn_out @ Wo) + Residual
     * ═══════════════════════════════════════════════════════════════════════ */

    for (int i = 0; i < embed_dim; i++) {
        float sum = 0.0f;
        const float *wo_row = wo + i * q_dim;
        for (int j = 0; j < q_dim; j++) {
            sum += wo_row[j] * attn_out[j];
        }
        hidden_after_attn[i] = sum + residual_1[i]; /* Residual add */
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 3: RMSNorm
     * ═══════════════════════════════════════════════════════════════════════ */

    float rms_scale = compute_rms_scale_internal(hidden_after_attn, embed_dim, eps);

#ifdef __AVX2__
    __m256 vscale = _mm256_set1_ps(rms_scale);
    int i = 0;
    for (; i + 7 < embed_dim; i += 8) {
        __m256 vh = _mm256_loadu_ps(hidden_after_attn + i);
        __m256 vw = _mm256_loadu_ps(rms_weight + i);
        __m256 vn = _mm256_mul_ps(_mm256_mul_ps(vh, vw), vscale);
        _mm256_storeu_ps(normed + i, vn);
    }
    for (; i < embed_dim; i++) {
        normed[i] = hidden_after_attn[i] * rms_weight[i] * rms_scale;
    }
#else
    for (int i = 0; i < embed_dim; i++) {
        normed[i] = hidden_after_attn[i] * rms_weight[i] * rms_scale;
    }
#endif

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 4: MLP Gate + Up projections
     * ═══════════════════════════════════════════════════════════════════════ */

    for (int i = 0; i < intermediate_dim; i++) {
        float sum = 0.0f;
        const float *wg_row = w_gate + i * embed_dim;
        for (int j = 0; j < embed_dim; j++) {
            sum += wg_row[j] * normed[j];
        }
        gate_out[i] = sum;
    }

    for (int i = 0; i < intermediate_dim; i++) {
        float sum = 0.0f;
        const float *wu_row = w_up + i * embed_dim;
        for (int j = 0; j < embed_dim; j++) {
            sum += wu_row[j] * normed[j];
        }
        up_out[i] = sum;
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 5: SwiGLU activation: silu(gate) * up (in place in gate_out)
     * ═══════════════════════════════════════════════════════════════════════ */

#ifdef __AVX2__
    i = 0;
    for (; i + 7 < intermediate_dim; i += 8) {
        __m256 vg = _mm256_loadu_ps(gate_out + i);
        __m256 vu = _mm256_loadu_ps(up_out + i);
        __m256 vswiglu = _mm256_mul_ps(silu_avx2(vg), vu);
        _mm256_storeu_ps(gate_out + i, vswiglu); /* Reuse gate_out buffer */
    }
    for (; i < intermediate_dim; i++) {
        gate_out[i] = silu_scalar(gate_out[i]) * up_out[i];
    }
#else
    for (int i = 0; i < intermediate_dim; i++) {
        gate_out[i] = silu_scalar(gate_out[i]) * up_out[i];
    }
#endif

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 6: Down projection + Final Residual
     * ═══════════════════════════════════════════════════════════════════════ */

    for (int i = 0; i < embed_dim; i++) {
        float sum = 0.0f;
        const float *wd_row = w_down + i * intermediate_dim;
        for (int j = 0; j < intermediate_dim; j++) {
            sum += wd_row[j] * gate_out[j]; /* gate_out now holds SwiGLU output */
        }
        hidden_out[i] = sum + hidden_after_attn[i]; /* Final residual */
    }
}
static float silu_scalar(float x)
static float compute_rms_scale_internal(const float *x, int n, float eps)
static void softmax_inplace(float *x, int n)
score — local float accumulator in the attention score loops above.
(Note: Doxygen's auto-link of `score` to tokenizer.h:327 is spurious; the symbols are unrelated.)

References compute_rms_scale_internal(), score, silu_scalar(), and softmax_inplace().

◆ attention_mlp_fused_q4k()

/**
 * attention_mlp_fused_q4k - fused attention + MLP block with Q4_K-quantized
 * weights.
 *
 * Same pipeline and activation layouts as attention_mlp_fused_fp32, except
 * that wo / w_gate / w_up / w_down are opaque Q4_K weight blobs consumed by
 * gemv_q4_k(); all activations stay FP32.
 *
 * Capacity limits: embed_dim <= 4096, num_heads*head_dim <= 4096,
 * intermediate_dim <= 16384, seq_len <= 8192. All limits are now validated
 * up front. (Previously intermediate_dim was only checked AFTER attention,
 * output projection and RMSNorm had already run — wasted work and a silent
 * partial bail-out — seq_len was checked inside the head loop, and
 * q_dim was not checked at all, risking overflow of attn_out[4096].)
 * On violation the function returns without touching hidden_out.
 */
void attention_mlp_fused_q4k(const float *q, const float *k_cache,
                             const float *v_cache, int seq_len,
                             int num_heads, int num_kv_heads, int head_dim,
                             float attn_scale, const void *wo,
                             const float *residual_1, const float *rms_weight,
                             float eps, const void *w_gate, const void *w_up,
                             const void *w_down, int embed_dim,
                             int intermediate_dim, float *hidden_out)
{
    /* Auto-dispatch Q4_K GEMV (defined in the quant kernels). */
    extern void gemv_q4_k(float *y, const void *W, const float *x, int M, int K);

    const int heads_per_kv = num_heads / num_kv_heads;
    const int q_dim = num_heads * head_dim;
    const int kv_dim = num_kv_heads * head_dim;

    /* Stack buffers - all stay in L1/L2 */
    float attn_out[4096];
    float hidden_after_attn[4096];
    float normed[4096];
    float mlp_out[4096];
    float gate_out[16384];
    float up_out[16384];
    float scores[8192]; /* one head's attention scores (max seq_len) */

    /* Validate EVERY capacity before doing any work. */
    if (embed_dim > 4096 || q_dim > 4096 ||
        intermediate_dim > 16384 || seq_len > 8192) {
        return;
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 1: Multi-Head Attention (same as FP32 version)
     * ═══════════════════════════════════════════════════════════════════════ */

    memset(attn_out, 0, q_dim * sizeof(float));

    for (int h = 0; h < num_heads; h++) {
        int kv_h = h / heads_per_kv; /* GQA: map query head to KV head */
        const float *q_head = q + h * head_dim;
        float *out_head = attn_out + h * head_dim;

        for (int t = 0; t < seq_len; t++) {
            const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_head[d] * k_t[d];
            }
            scores[t] = score * attn_scale;
        }

        softmax_inplace(scores, seq_len);

        for (int t = 0; t < seq_len; t++) {
            const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
            float w = scores[t];
            for (int d = 0; d < head_dim; d++) {
                out_head[d] += w * v_t[d];
            }
        }
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 2: Output Projection (Q4_K) + Residual
     * ═══════════════════════════════════════════════════════════════════════ */

    gemv_q4_k(hidden_after_attn, wo, attn_out, embed_dim, q_dim);

    for (int i = 0; i < embed_dim; i++) {
        hidden_after_attn[i] += residual_1[i];
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 3: RMSNorm
     * ═══════════════════════════════════════════════════════════════════════ */

    float rms_scale = compute_rms_scale_internal(hidden_after_attn, embed_dim, eps);

    for (int i = 0; i < embed_dim; i++) {
        normed[i] = hidden_after_attn[i] * rms_weight[i] * rms_scale;
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 4-6: MLP with Q4_K weights
     *   gate_out = normed @ W_gate
     *   up_out   = normed @ W_up
     *   swiglu   = silu(gate_out) * up_out   (in place in gate_out)
     *   mlp_out  = swiglu @ W_down
     * ═══════════════════════════════════════════════════════════════════════ */

    gemv_q4_k(gate_out, w_gate, normed, intermediate_dim, embed_dim);
    gemv_q4_k(up_out, w_up, normed, intermediate_dim, embed_dim);

#ifdef __AVX2__
    int i = 0;
    for (; i + 7 < intermediate_dim; i += 8) {
        __m256 vg = _mm256_loadu_ps(gate_out + i);
        __m256 vu = _mm256_loadu_ps(up_out + i);
        _mm256_storeu_ps(gate_out + i, _mm256_mul_ps(silu_avx2(vg), vu));
    }
    for (; i < intermediate_dim; i++) {
        gate_out[i] = silu_scalar(gate_out[i]) * up_out[i];
    }
#else
    for (int i = 0; i < intermediate_dim; i++) {
        gate_out[i] = silu_scalar(gate_out[i]) * up_out[i];
    }
#endif

    /* Down projection */
    gemv_q4_k(mlp_out, w_down, gate_out, embed_dim, intermediate_dim);

    /* Final residual add */
    for (int i = 0; i < embed_dim; i++) {
        hidden_out[i] = mlp_out[i] + hidden_after_attn[i];
    }
}
void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV based on available SIMD.

References compute_rms_scale_internal(), gemv_q4_k(), score, silu_scalar(), and softmax_inplace().

◆ attention_mlp_separate_fp32()

/**
 * attention_mlp_separate_fp32 - NON-fused reference path.
 *
 * Computes exactly the same result as attention_mlp_fused_fp32, but every
 * intermediate (attention output, projection+residual, norm, gate, up,
 * MLP output) is written to a caller-provided buffer, simulating the DRAM
 * traffic of layer-by-layer execution. Used as a baseline against the
 * fused kernel.
 *
 * Supports seq_len <= 8192 (stack scratch for one head's scores); returns
 * without producing output when exceeded.
 */
void attention_mlp_separate_fp32(const float *q, const float *k_cache,
                                 const float *v_cache, int seq_len,
                                 int num_heads, int num_kv_heads,
                                 int head_dim, float attn_scale,
                                 const float *wo, const float *residual_1,
                                 const float *rms_weight, float eps,
                                 const float *w_gate, const float *w_up,
                                 const float *w_down, int embed_dim,
                                 int intermediate_dim, float *attn_out_buf,
                                 float *hidden_after_attn_buf,
                                 float *normed_buf, float *gate_buf,
                                 float *up_buf, float *mlp_out_buf,
                                 float *hidden_out)
{
    const int group = num_heads / num_kv_heads; /* GQA group size */
    const int q_dim = num_heads * head_dim;
    const int kv_dim = num_kv_heads * head_dim;

    /* Attention accumulates directly into the caller's buffer. */
    memset(attn_out_buf, 0, q_dim * sizeof(float));

    /* Stack-allocated scores scratch (no malloc!) */
    float scores[8192];
    if (seq_len > 8192) {
        return;
    }

    for (int h = 0; h < num_heads; h++) {
        const float *qh = q + h * head_dim;
        const float *kbase = k_cache + (h / group) * head_dim;
        const float *vbase = v_cache + (h / group) * head_dim;
        float *oh = attn_out_buf + h * head_dim;

        for (int t = 0; t < seq_len; t++) {
            const float *kt = kbase + t * kv_dim;
            float dot = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                dot += qh[d] * kt[d];
            }
            scores[t] = dot * attn_scale;
        }

        softmax_inplace(scores, seq_len);

        for (int t = 0; t < seq_len; t++) {
            const float *vt = vbase + t * kv_dim;
            const float p = scores[t];
            for (int d = 0; d < head_dim; d++) {
                oh[d] += p * vt[d];
            }
        }
    }

    /* Output projection + residual_1 -> DRAM. */
    for (int r = 0; r < embed_dim; r++) {
        const float *row = wo + r * q_dim;
        float acc = 0.0f;
        for (int c = 0; c < q_dim; c++) {
            acc += row[c] * attn_out_buf[c];
        }
        hidden_after_attn_buf[r] = acc + residual_1[r];
    }

    /* RMSNorm -> DRAM. */
    const float inv_rms =
        compute_rms_scale_internal(hidden_after_attn_buf, embed_dim, eps);
    for (int r = 0; r < embed_dim; r++) {
        normed_buf[r] = hidden_after_attn_buf[r] * rms_weight[r] * inv_rms;
    }

    /* Gate projection -> DRAM. */
    for (int r = 0; r < intermediate_dim; r++) {
        const float *row = w_gate + r * embed_dim;
        float acc = 0.0f;
        for (int c = 0; c < embed_dim; c++) {
            acc += row[c] * normed_buf[c];
        }
        gate_buf[r] = acc;
    }

    /* Up projection -> DRAM. */
    for (int r = 0; r < intermediate_dim; r++) {
        const float *row = w_up + r * embed_dim;
        float acc = 0.0f;
        for (int c = 0; c < embed_dim; c++) {
            acc += row[c] * normed_buf[c];
        }
        up_buf[r] = acc;
    }

    /* SwiGLU, in place in gate_buf. */
    for (int r = 0; r < intermediate_dim; r++) {
        gate_buf[r] = silu_scalar(gate_buf[r]) * up_buf[r];
    }

    /* Down projection -> DRAM, then the final residual. */
    for (int r = 0; r < embed_dim; r++) {
        const float *row = w_down + r * intermediate_dim;
        float acc = 0.0f;
        for (int c = 0; c < intermediate_dim; c++) {
            acc += row[c] * gate_buf[c];
        }
        mlp_out_buf[r] = acc;
        hidden_out[r] = mlp_out_buf[r] + hidden_after_attn_buf[r];
    }
}

References compute_rms_scale_internal(), score, silu_scalar(), and softmax_inplace().

◆ compute_rms_scale_internal()

static float compute_rms_scale_internal ( const float *  x,
int  n,
float  eps 
)
inlinestatic

Definition at line 76 of file attention_mlp_fused.c.

76  {
77  float sum_sq = 0.0f;
78 
79 #ifdef __AVX2__
80  __m256 vsum = _mm256_setzero_ps();
81  int i = 0;
82  for (; i + 7 < n; i += 8) {
83  __m256 vx = _mm256_loadu_ps(x + i);
84  vsum = _mm256_fmadd_ps(vx, vx, vsum);
85  }
86  __m128 vlow = _mm256_castps256_ps128(vsum);
87  __m128 vhigh = _mm256_extractf128_ps(vsum, 1);
88  vlow = _mm_add_ps(vlow, vhigh);
89  vlow = _mm_hadd_ps(vlow, vlow);
90  vlow = _mm_hadd_ps(vlow, vlow);
91  sum_sq = _mm_cvtss_f32(vlow);
92  for (; i < n; i++) {
93  sum_sq += x[i] * x[i];
94  }
95 #else
96  for (int i = 0; i < n; i++) {
97  sum_sq += x[i] * x[i];
98  }
99 #endif
100 
101  float rms = sqrtf(sum_sq / (float)n + eps);
102  return 1.0f / rms;
103 }

Referenced by attention_mlp_fused_fp32(), attention_mlp_fused_q4k(), attention_mlp_separate_fp32(), mlp_fused_fp32_v2(), mlp_fused_fp32_v3(), and mlp_separate_fp32().

◆ layer_fused_attn_mlp_qkv_q4k()

/**
 * layer_fused_attn_mlp_qkv_q4k - mega-fusion of one full transformer layer
 * with Q4_K weights, PLUS the next layer's Q/K/V projections.
 *
 * Pipeline: attention -> Wo (Q4_K) + residual_in -> RMSNorm(mlp) ->
 * SwiGLU MLP (Q4_K) -> residual (written directly into hidden_out) ->
 * RMSNorm(attn, next layer) -> Q/K/V projections (Q4_K) into
 * q_next/k_next/v_next.
 *
 * Outputs: hidden_out [embed_dim] final hidden state; q_next [q_dim];
 * k_next/v_next [kv_dim] (KV-cache writes — that DRAM traffic is
 * intentional). All other intermediates stay in stack buffers.
 *
 * Capacity limits: embed_dim <= 4096, num_heads*head_dim <= 4096,
 * intermediate_dim <= 16384, seq_len <= 8192 — all validated up front.
 * (Previously seq_len was checked inside the head loop and q_dim not at
 * all, which could overflow attn_out[4096].) On violation the function
 * returns without writing any output.
 */
void layer_fused_attn_mlp_qkv_q4k(const float *q, const float *k_cache,
                                  const float *v_cache, int seq_len,
                                  float attn_scale, const void *wo,
                                  const float *rms_weight_mlp,
                                  const void *w_gate, const void *w_up,
                                  const void *w_down,
                                  const float *rms_weight_attn,
                                  const void *wq_next, const void *wk_next,
                                  const void *wv_next,
                                  const float *residual_in, int embed_dim,
                                  int intermediate_dim, int num_heads,
                                  int num_kv_heads, int head_dim, float eps,
                                  float *q_next, float *k_next,
                                  float *v_next, float *hidden_out)
{
    extern void gemv_q4_k(float *y, const void *W, const float *x, int M, int K);

    const int heads_per_kv = num_heads / num_kv_heads;
    const int q_dim = num_heads * head_dim;
    const int kv_dim = num_kv_heads * head_dim;

    /* All intermediate buffers on stack - stay in L1/L2.
     * No hidden_after_mlp buffer: the down projection writes straight
     * into hidden_out. */
    float attn_out[4096];
    float hidden_after_attn[4096];
    float normed_mlp[4096];
    float normed_attn[4096];
    float gate_out[16384];
    float up_out[16384];
    float scores[8192];

    /* Validate EVERY capacity before doing any work. */
    if (embed_dim > 4096 || q_dim > 4096 ||
        intermediate_dim > 16384 || seq_len > 8192) {
        return;
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 1: Multi-Head Attention
     * ═══════════════════════════════════════════════════════════════════════ */

    memset(attn_out, 0, q_dim * sizeof(float));

    for (int h = 0; h < num_heads; h++) {
        int kv_h = h / heads_per_kv; /* GQA: map query head to KV head */
        const float *q_head = q + h * head_dim;
        float *out_head = attn_out + h * head_dim;

        for (int t = 0; t < seq_len; t++) {
            const float *k_t = k_cache + t * kv_dim + kv_h * head_dim;
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += q_head[d] * k_t[d];
            }
            scores[t] = score * attn_scale;
        }

        /* Inline numerically-stable softmax (max-subtracted).
         * NOTE(review): duplicates softmax_inplace(); consider calling the
         * shared helper once parity is re-verified. */
        float max_score = scores[0];
        for (int t = 1; t < seq_len; t++) {
            if (scores[t] > max_score) max_score = scores[t];
        }
        float sum_exp = 0.0f;
        for (int t = 0; t < seq_len; t++) {
            scores[t] = expf(scores[t] - max_score);
            sum_exp += scores[t];
        }
        float inv_sum = 1.0f / sum_exp;
        for (int t = 0; t < seq_len; t++) {
            scores[t] *= inv_sum;
        }

        /* Weighted sum of V */
        for (int t = 0; t < seq_len; t++) {
            const float *v_t = v_cache + t * kv_dim + kv_h * head_dim;
            float w = scores[t];
            for (int d = 0; d < head_dim; d++) {
                out_head[d] += w * v_t[d];
            }
        }
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 2: Output Projection (Q4_K) + Residual
     * ═══════════════════════════════════════════════════════════════════════ */

    gemv_q4_k(hidden_after_attn, wo, attn_out, embed_dim, q_dim);

    for (int i = 0; i < embed_dim; i++) {
        hidden_after_attn[i] += residual_in[i];
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 3: RMSNorm (for MLP)
     * NOTE(review): duplicates compute_rms_scale_internal(); kept inline to
     * preserve exact behavior.
     * ═══════════════════════════════════════════════════════════════════════ */

    float sum_sq = 0.0f;
    for (int i = 0; i < embed_dim; i++) {
        sum_sq += hidden_after_attn[i] * hidden_after_attn[i];
    }
    float rms_scale = 1.0f / sqrtf(sum_sq / embed_dim + eps);

    for (int i = 0; i < embed_dim; i++) {
        normed_mlp[i] = hidden_after_attn[i] * rms_weight_mlp[i] * rms_scale;
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 4-6: MLP (gate + up + SwiGLU + down)
     * ═══════════════════════════════════════════════════════════════════════ */

    gemv_q4_k(gate_out, w_gate, normed_mlp, intermediate_dim, embed_dim);
    gemv_q4_k(up_out, w_up, normed_mlp, intermediate_dim, embed_dim);

    /* SwiGLU: silu(gate) * up, in place in gate_out */
    for (int i = 0; i < intermediate_dim; i++) {
        float g = gate_out[i];
        float silu_g = g / (1.0f + expf(-g));
        gate_out[i] = silu_g * up_out[i];
    }

    /* Down projection writes DIRECTLY into hidden_out (no extra buffer) */
    gemv_q4_k(hidden_out, w_down, gate_out, embed_dim, intermediate_dim);

    /* MLP residual - hidden_out now contains the final hidden state */
    for (int i = 0; i < embed_dim; i++) {
        hidden_out[i] += hidden_after_attn[i];
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 7: RMSNorm (for the NEXT layer's attention), read from hidden_out
     * ═══════════════════════════════════════════════════════════════════════ */

    sum_sq = 0.0f;
    for (int i = 0; i < embed_dim; i++) {
        sum_sq += hidden_out[i] * hidden_out[i];
    }
    rms_scale = 1.0f / sqrtf(sum_sq / embed_dim + eps);

    for (int i = 0; i < embed_dim; i++) {
        normed_attn[i] = hidden_out[i] * rms_weight_attn[i] * rms_scale;
    }

    /* ═══════════════════════════════════════════════════════════════════════
     * STEP 8: NEXT LAYER's Q, K, V projections.
     * Q goes to the caller; K, V go to the KV cache (intentional DRAM write).
     * ═══════════════════════════════════════════════════════════════════════ */

    gemv_q4_k(q_next, wq_next, normed_attn, q_dim, embed_dim);
    gemv_q4_k(k_next, wk_next, normed_attn, kv_dim, embed_dim);
    gemv_q4_k(v_next, wv_next, normed_attn, kv_dim, embed_dim);
}

References gemv_q4_k(), and score.

◆ mlp_fused_fp32_v2()

#ifdef __AVX2__
/* Horizontal sum of one __m256 accumulator (low/high fold, then pairwise
 * movehdup/movehl adds). Shared by the gate and up reductions below. */
static inline float mlp_v2_hsum256(__m256 v) {
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 hi = _mm256_extractf128_ps(v, 1);
    lo = _mm_add_ps(lo, hi);
    __m128 sh = _mm_movehdup_ps(lo);
    lo = _mm_add_ps(lo, sh);
    sh = _mm_movehl_ps(sh, lo);
    lo = _mm_add_ss(lo, sh);
    return _mm_cvtss_f32(lo);
}
#endif

/**
 * mlp_fused_fp32_v2 - RMSNorm + SwiGLU MLP + residual, fully fused.
 *
 * For each intermediate row, the gate and up dot products are computed
 * together and SwiGLU is applied immediately, so no separate gate/up
 * buffers exist — only the normed input and the activated intermediate
 * live on the stack.
 *
 * Limits: embed_dim <= 4096, intermediate_dim <= 16384; returns without
 * writing hidden_out when exceeded (TODO: handle larger models).
 */
void mlp_fused_fp32_v2(const float *hidden_in, const float *rms_weight,
                       float eps, const float *w_gate, const float *w_up,
                       const float *w_down, int embed_dim,
                       int intermediate_dim, float *hidden_out)
{
    float x_norm[4096];  /* RMSNorm output */
    float act[16384];    /* silu(gate) * up, per intermediate row */

    if (embed_dim > 4096 || intermediate_dim > 16384) {
        return; /* TODO: handle larger models */
    }

    /* Step 1: RMSNorm. */
    float inv_rms = compute_rms_scale_internal(hidden_in, embed_dim, eps);

#ifdef __AVX2__
    __m256 vinv = _mm256_set1_ps(inv_rms);
    int idx = 0;
    for (; idx + 7 < embed_dim; idx += 8) {
        __m256 vx = _mm256_loadu_ps(hidden_in + idx);
        __m256 vwt = _mm256_loadu_ps(rms_weight + idx);
        _mm256_storeu_ps(x_norm + idx,
                         _mm256_mul_ps(_mm256_mul_ps(vx, vwt), vinv));
    }
    for (; idx < embed_dim; idx++) {
        x_norm[idx] = hidden_in[idx] * rms_weight[idx] * inv_rms;
    }
#else
    for (int idx = 0; idx < embed_dim; idx++) {
        x_norm[idx] = hidden_in[idx] * rms_weight[idx] * inv_rms;
    }
#endif

    /* Step 2: gate + up computed together per row, SwiGLU applied on the
     * spot — this is the fusion that removes the gate/up buffers. */
#ifdef __AVX2__
    for (int row = 0; row < intermediate_dim; row++) {
        const float *g_row = w_gate + row * embed_dim;
        const float *u_row = w_up + row * embed_dim;

        __m256 acc_g = _mm256_setzero_ps();
        __m256 acc_u = _mm256_setzero_ps();

        int col = 0;
        for (; col + 7 < embed_dim; col += 8) {
            __m256 vn = _mm256_loadu_ps(x_norm + col);
            acc_g = _mm256_fmadd_ps(_mm256_loadu_ps(g_row + col), vn, acc_g);
            acc_u = _mm256_fmadd_ps(_mm256_loadu_ps(u_row + col), vn, acc_u);
        }

        float g_val = mlp_v2_hsum256(acc_g);
        float u_val = mlp_v2_hsum256(acc_u);

        /* Scalar tail for embed_dim % 8. */
        for (; col < embed_dim; col++) {
            g_val += g_row[col] * x_norm[col];
            u_val += u_row[col] * x_norm[col];
        }

        act[row] = silu_scalar(g_val) * u_val;
    }
#else
    for (int row = 0; row < intermediate_dim; row++) {
        const float *g_row = w_gate + row * embed_dim;
        const float *u_row = w_up + row * embed_dim;
        float g_val = 0.0f, u_val = 0.0f;

        for (int col = 0; col < embed_dim; col++) {
            g_val += g_row[col] * x_norm[col];
            u_val += u_row[col] * x_norm[col];
        }

        act[row] = silu_scalar(g_val) * u_val;
    }
#endif

    /* Step 3: down projection + residual. */
#ifdef __AVX2__
    for (int row = 0; row < embed_dim; row++) {
        float dot = gemv_fp32_row_avx2(w_down + row * intermediate_dim,
                                       act, intermediate_dim);
        hidden_out[row] = dot + hidden_in[row];
    }
#else
    for (int row = 0; row < embed_dim; row++) {
        float dot = 0.0f;
        const float *d_row = w_down + row * intermediate_dim;
        for (int col = 0; col < intermediate_dim; col++) {
            dot += d_row[col] * act[col];
        }
        hidden_out[row] = dot + hidden_in[row];
    }
#endif
}

References compute_rms_scale_internal(), and silu_scalar().

◆ mlp_fused_fp32_v3()

/**
 * Fused RMSNorm -> SwiGLU MLP block (v3), single token.
 *
 *   normed      = RMSNorm(hidden_in; eps) * rms_weight
 *   swiglu[j]   = silu(w_gate[j] . normed) * (w_up[j] . normed)
 *   hidden_out  = hidden_in + W_down @ swiglu        (residual add)
 *
 * v3 ordering: the full gate projection runs first (sequential weight
 * reads), then the up projection is fused with the SwiGLU activation so
 * no separate up_out buffer is ever materialized.
 *
 * @param hidden_in        [embed_dim] input activations; also the residual.
 * @param rms_weight       [embed_dim] RMSNorm scale weights.
 * @param eps              RMSNorm epsilon.
 * @param w_gate, w_up     [intermediate_dim x embed_dim], row-major.
 * @param w_down           [embed_dim x intermediate_dim], row-major.
 * @param embed_dim        model width; no-op if <= 0.
 * @param intermediate_dim MLP hidden width; no-op if <= 0.
 * @param hidden_out       [embed_dim] result buffer.
 *
 * Fix vs. original: scratch moved from ~144 KB of fixed-size stack arrays
 * (normed[4096], gate_out[16384], swiglu[16384]) to one heap allocation,
 * removing both the stack-overflow risk and the silent refusal of
 * embed_dim > 4096 / intermediate_dim > 16384. Row offsets are computed
 * in size_t to avoid int overflow for large j * embed_dim products.
 * Still returns silently (void contract, no error channel) on bad dims
 * or OOM, matching the original's silent-bail behavior.
 */
void mlp_fused_fp32_v3(const float *hidden_in, const float *rms_weight,
                       float eps, const float *w_gate, const float *w_up,
                       const float *w_down, int embed_dim,
                       int intermediate_dim, float *hidden_out)
{
    if (embed_dim <= 0 || intermediate_dim <= 0) {
        return;
    }

    /* One allocation, three scratch regions: normed | gate_out | swiglu. */
    const size_t ne = (size_t)embed_dim;
    const size_t ni = (size_t)intermediate_dim;
    float *scratch = malloc((ne + 2u * ni) * sizeof *scratch);
    if (scratch == NULL) {
        return; /* OOM: keep the original's silent-bail contract */
    }
    float *normed = scratch;
    float *gate_out = scratch + ne;
    float *swiglu = gate_out + ni;

    /* ═══════════════════════════════════════════════════════════════════
     * STEP 1: RMSNorm (SIMD when available)
     * ═══════════════════════════════════════════════════════════════════ */

    float rms_scale = compute_rms_scale_internal(hidden_in, embed_dim, eps);

#ifdef __AVX2__
    __m256 vscale = _mm256_set1_ps(rms_scale);
    int i = 0;
    for (; i + 7 < embed_dim; i += 8) {
        __m256 vh = _mm256_loadu_ps(hidden_in + i);
        __m256 vw = _mm256_loadu_ps(rms_weight + i);
        __m256 vn = _mm256_mul_ps(_mm256_mul_ps(vh, vw), vscale);
        _mm256_storeu_ps(normed + i, vn);
    }
    for (; i < embed_dim; i++) {
        normed[i] = hidden_in[i] * rms_weight[i] * rms_scale;
    }
#else
    for (int i = 0; i < embed_dim; i++) {
        normed[i] = hidden_in[i] * rms_weight[i] * rms_scale;
    }
#endif

    /* ═══════════════════════════════════════════════════════════════════
     * STEP 2: Gate projection (GEMV, sequential weight access)
     * ═══════════════════════════════════════════════════════════════════ */

#ifdef __AVX2__
    for (int j = 0; j < intermediate_dim; j++) {
        gate_out[j] = gemv_fp32_row_avx2(w_gate + (size_t)j * ne, normed, embed_dim);
    }
#else
    for (int j = 0; j < intermediate_dim; j++) {
        float sum = 0.0f;
        const float *wg_row = w_gate + (size_t)j * ne;
        for (int k = 0; k < embed_dim; k++) {
            sum += wg_row[k] * normed[k];
        }
        gate_out[j] = sum;
    }
#endif

    /* ═══════════════════════════════════════════════════════════════════
     * STEP 3: Up projection + FUSED SwiGLU
     * Key: compute up[j], then immediately apply SwiGLU with gate[j];
     * this avoids storing a full up_out buffer.
     * ═══════════════════════════════════════════════════════════════════ */

#ifdef __AVX2__
    for (int j = 0; j < intermediate_dim; j++) {
        float up_val = gemv_fp32_row_avx2(w_up + (size_t)j * ne, normed, embed_dim);
        swiglu[j] = silu_scalar(gate_out[j]) * up_val;
    }
#else
    for (int j = 0; j < intermediate_dim; j++) {
        float up_val = 0.0f;
        const float *wu_row = w_up + (size_t)j * ne;
        for (int k = 0; k < embed_dim; k++) {
            up_val += wu_row[k] * normed[k];
        }
        swiglu[j] = silu_scalar(gate_out[j]) * up_val;
    }
#endif

    /* ═══════════════════════════════════════════════════════════════════
     * STEP 4: Down projection + Residual
     * ═══════════════════════════════════════════════════════════════════ */

#ifdef __AVX2__
    for (int j = 0; j < embed_dim; j++) {
        float sum = gemv_fp32_row_avx2(w_down + (size_t)j * ni, swiglu, intermediate_dim);
        hidden_out[j] = sum + hidden_in[j]; /* Residual */
    }
#else
    for (int j = 0; j < embed_dim; j++) {
        float sum = 0.0f;
        const float *wd_row = w_down + (size_t)j * ni;
        for (int k = 0; k < intermediate_dim; k++) {
            sum += wd_row[k] * swiglu[k];
        }
        hidden_out[j] = sum + hidden_in[j];
    }
#endif

    free(scratch);
}

References compute_rms_scale_internal(), and silu_scalar().

◆ mlp_separate_fp32()

/*
 * Unfused reference MLP path: RMSNorm -> gate/up projections -> SwiGLU
 * -> down projection with residual add. The caller supplies every
 * scratch buffer: normed_buf [embed_dim], gate_buf and up_buf
 * [intermediate_dim]. On return, up_buf holds the raw up-projection
 * values and gate_buf has been reused to hold the SwiGLU activations;
 * hidden_out[d] = hidden_in[d] + (W_down @ swiglu)[d].
 */
void mlp_separate_fp32(const float *hidden_in, const float *rms_weight,
                       float eps, const float *w_gate, const float *w_up,
                       const float *w_down, float *normed_buf,
                       float *gate_buf, float *up_buf, int embed_dim,
                       int intermediate_dim, float *hidden_out)
{
    /* Step 1: normalize — x * w * (1 / rms(x)) */
    const float inv_rms = compute_rms_scale_internal(hidden_in, embed_dim, eps);
    for (int d = 0; d < embed_dim; d++) {
        normed_buf[d] = hidden_in[d] * rms_weight[d] * inv_rms;
    }

    /* Step 2: gate projection — gate_buf = W_gate @ normed */
    for (int row = 0; row < intermediate_dim; row++) {
        const float *w = w_gate + row * embed_dim;
        float acc = 0.0f;
        for (int col = 0; col < embed_dim; col++) {
            acc += w[col] * normed_buf[col];
        }
        gate_buf[row] = acc;
    }

    /* Step 3: up projection — up_buf = W_up @ normed */
    for (int row = 0; row < intermediate_dim; row++) {
        const float *w = w_up + row * embed_dim;
        float acc = 0.0f;
        for (int col = 0; col < embed_dim; col++) {
            acc += w[col] * normed_buf[col];
        }
        up_buf[row] = acc;
    }

    /* Step 4: SwiGLU activation, written back over gate_buf */
    for (int row = 0; row < intermediate_dim; row++) {
        gate_buf[row] = silu_scalar(gate_buf[row]) * up_buf[row];
    }

    /* Step 5: down projection plus residual from hidden_in */
    for (int row = 0; row < embed_dim; row++) {
        const float *w = w_down + row * intermediate_dim;
        float acc = 0.0f;
        for (int col = 0; col < intermediate_dim; col++) {
            acc += w[col] * gate_buf[col];
        }
        hidden_out[row] = acc + hidden_in[row];
    }
}

References compute_rms_scale_internal(), and silu_scalar().

◆ silu_scalar()

static float silu_scalar ( float  x)
inlinestatic

Definition at line 109 of file attention_mlp_fused.c.

109  {
110  return x / (1.0f + expf(-x));
111 }

Referenced by attention_mlp_fused_fp32(), attention_mlp_fused_q4k(), attention_mlp_separate_fp32(), mlp_fused_fp32_v2(), mlp_fused_fp32_v3(), and mlp_separate_fp32().

◆ softmax_inplace()

static void softmax_inplace ( float *  x,
int  n 
)
static

Definition at line 150 of file attention_mlp_fused.c.

150  {
151  float max_val = x[0];
152  for (int i = 1; i < n; i++) {
153  if (x[i] > max_val) max_val = x[i];
154  }
155 
156  float sum = 0.0f;
157  for (int i = 0; i < n; i++) {
158  x[i] = expf(x[i] - max_val);
159  sum += x[i];
160  }
161 
162  float inv_sum = 1.0f / sum;
163  for (int i = 0; i < n; i++) {
164  x[i] *= inv_sum;
165  }
166 }

Referenced by attention_mlp_fused_fp32(), attention_mlp_fused_q4k(), and attention_mlp_separate_fp32().