mega_fused_attention_decode_q5_0.c File Reference

Mega-fused attention decode with Q5_0 weights. More...

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <math.h>
#include "ckernel_quant.h"


Functions

static void apply_rope_inline (float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int H, int KV, int AD)
 
void attention_forward_decode_head_major_gqa_flash (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
 
static void gemv_q5_0_from_fp32 (float *out, const void *W_q5_0, const float *x_fp32, const float *bias, int M, int K, block_q8_0 *x_q8_scratch)
 
static void gemv_q8_0_from_fp32 (float *out, const void *W_q8_0, const float *x_fp32, const float *bias, int M, int K, block_q8_0 *x_q8_scratch)
 
void mega_fused_attention_decode_q5_0 (float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch)
 Serial mega-fused attention decode kernel. More...
 
void mega_fused_attention_decode_q5_0_parallel_simd (float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch, int ith, int nth)
 Parallel SIMD mega-fused attention decode kernel (threadpool-aware) More...
 
int mega_fused_attention_decode_scratch_size (int AE, int H, int KV, int AD)
 Calculate scratch buffer size needed for the kernel. More...
 
void quantize_row_q8_0 (const float *x, void *vy, int k)
 Quantize FP32 to Q8_0 format (scalar reference) More...
 
void rmsnorm_forward (const float *input, const float *gamma, float *output, float *rstd, int T, int D, int AD, float eps)
 
void vec_dot_q5_0_q8_0 (int n, float *s, const void *vx, const void *vy)
 Auto-dispatch quantized dot product Q5_0 x Q8_0. More...
 
void vec_dot_q8_0_q8_0 (int n, float *s, const void *vx, const void *vy)
 Auto-dispatch quantized dot product Q8_0 x Q8_0. More...
 

Detailed Description

Mega-fused attention decode with Q5_0 weights.

STATUS: Serial kernel complete and correct. Parallel variant is a prototype that requires threadpool barrier support (not yet available). The non-fused decode path (ck_parallel_decode.h) already parallelizes each GEMV via row-splitting, so this fused kernel is not on the critical path. It can be enabled once threadpool parallelization is resolved (see PARALLELIZATION NOTES below).

FUSION: Combines 9 operations to minimize memory traffic. All intermediate data stays in scratch buffer (L1/L2 cache).

Operations fused:

  1. RMSNorm
  2. Q projection (Q5_0) with bias
  3. K projection (Q5_0) with bias
  4. V projection (Q8_0) with bias
  5. RoPE application
  6. KV cache store
  7. Flash attention decode (GQA-aware)
  8. O projection (Q5_0) with bias
  9. Residual add

PARALLELIZATION NOTES: The parallel_simd variant below documents the intended threading model but cannot run with the current threadpool (single dispatch, no mid-dispatch barrier). Three approaches were evaluated:

(A) Multi-dispatch (RECOMMENDED): Break into 3 ck_threadpool_dispatch() calls per layer:

  Dispatch 1: Row-split Q proj across threads. Thread 0 also does RMSNorm, K/V proj, RoPE, and the KV store (small ops that fit within the Q proj wall time).
  Dispatch 2: Split attention across heads (h_start..h_end per thread).
  Dispatch 3: Row-split O proj across threads. Thread 0 does the residual add after its rows.

Cost: ~1us total for 2 extra barrier round-trips (negligible vs ~100us GEMV). Intermediates stay in shared scratch, so the cache benefit is preserved.

(B) Redundant compute (single dispatch, no barrier): All threads redundantly compute RMSNorm + K/V proj + RoPE (~4us wasted per thread). Avoids barrier but wastes cycles on small ops. Only viable if Q/O proj dominate (true for short contexts).

(C) Skip fusion, use existing parallel GEMV: The non-fused decode path already parallelizes each GEMV call via ck_parallel_decode.h. For decode (M=1), intermediates are small (~3.5KB), so DRAM bandwidth savings from fusion are minimal. This is the current production path.

TESTING:

  make test-mega-fused-parity   # Numerical parity
  make test-mega-fused-speed    # Performance benchmark

Definition in file mega_fused_attention_decode_q5_0.c.

Function Documentation

◆ apply_rope_inline()

static void apply_rope_inline ( float *  q,
float *  k,
const float *  rope_cos,
const float *  rope_sin,
int  pos,
int  H,
int  KV,
int  AD 
)
inlinestatic

Definition at line 135 of file mega_fused_attention_decode_q5_0.c.

144 {
145  const int D = AD / 2;
146  const float *cos_row = &rope_cos[pos * D];
147  const float *sin_row = &rope_sin[pos * D];
148 
149  /* Q heads */
150  for (int h = 0; h < H; h++) {
151  float *q_head = &q[h * AD];
152  for (int d = 0; d < D; d++) {
153  float q0 = q_head[d];
154  float q1 = q_head[d + D];
155  q_head[d] = q0 * cos_row[d] - q1 * sin_row[d];
156  q_head[d + D] = q0 * sin_row[d] + q1 * cos_row[d];
157  }
158  }
159 
160  /* K heads */
161  for (int kv = 0; kv < KV; kv++) {
162  float *k_head = &k[kv * AD];
163  for (int d = 0; d < D; d++) {
164  float k0 = k_head[d];
165  float k1 = k_head[d + D];
166  k_head[d] = k0 * cos_row[d] - k1 * sin_row[d];
167  k_head[d + D] = k0 * sin_row[d] + k1 * cos_row[d];
168  }
169  }
170 }

Referenced by mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().

◆ attention_forward_decode_head_major_gqa_flash()

void attention_forward_decode_head_major_gqa_flash ( const float *  q_token,
const float *  k_cache,
const float *  v_cache,
float *  out_token,
int  num_heads,
int  num_kv_heads,
int  kv_tokens,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim 
)

Flash attention decode (single token attends to KV cache)

Test:

test_flash_attention.py::TestFlashAttention::test_flash_decode

test_kv_cache_attention.py::TestKVCacheAttention::test_flash_decode

test_fused_attention_decode.py::TestFusedAttentionDecode::test_flash_decode

test_attention.py::TestAttentionForward::test_flash_decode

Single query token attends to kv_tokens in KV cache. Uses true flash attention from attention_flash_true.c.

After changes: make test && make llamacpp-parity-full

Definition at line 1467 of file attention_kernels.c.

1477 {
1478  if (!q_token || !k_cache || !v_cache || !out_token) {
1479  return;
1480  }
1481  if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
1482  return;
1483  }
1484  if (kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
1485  return;
1486  }
1487 
1488  const float scale = 1.0f / sqrtf((float)head_dim);
1489  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
1490 
1491  for (int h = 0; h < num_heads; ++h) {
1492  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1493  const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
1494  const float *k_head = k_cache + (size_t)kv_head * head_stride;
1495  const float *v_head = v_cache + (size_t)kv_head * head_stride;
1496  float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;
1497 
1498  attention_flash_decode(out_head,
1499  q_head,
1500  k_head,
1501  v_head,
1502  1,
1503  kv_tokens,
1504  1,
1505  aligned_head_dim,
1506  scale);
1507  }
1508 }

Referenced by mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().

◆ gemv_q5_0_from_fp32()

static void gemv_q5_0_from_fp32 ( float *  out,
const void *  W_q5_0,
const float *  x_fp32,
const float *  bias,
int  M,
int  K,
block_q8_0 x_q8_scratch 
)
inlinestatic

Definition at line 84 of file mega_fused_attention_decode_q5_0.c.

92 {
93  const block_q5_0 *w_blocks = (const block_q5_0 *)W_q5_0;
94  const int blocks_per_row = K / QK5_0;
95 
96  /* Quantize input to Q8_0 (reuse existing kernel, scratch buffer) */
97  quantize_row_q8_0(x_fp32, x_q8_scratch, K);
98 
99  /* Compute dot products using optimized kernel */
100  for (int row = 0; row < M; row++) {
101  float dot;
102  vec_dot_q5_0_q8_0(K, &dot, &w_blocks[row * blocks_per_row], x_q8_scratch);
103  out[row] = dot + (bias ? bias[row] : 0.0f);
104  }
105 }

References QK5_0, quantize_row_q8_0(), and vec_dot_q5_0_q8_0().

Referenced by mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().

◆ gemv_q8_0_from_fp32()

static void gemv_q8_0_from_fp32 ( float *  out,
const void *  W_q8_0,
const float *  x_fp32,
const float *  bias,
int  M,
int  K,
block_q8_0 x_q8_scratch 
)
inlinestatic

Definition at line 108 of file mega_fused_attention_decode_q5_0.c.

116 {
117  const block_q8_0 *w_blocks = (const block_q8_0 *)W_q8_0;
118  const int blocks_per_row = K / QK8_0;
119 
120  /* Quantize input to Q8_0 (reuse existing kernel, scratch buffer) */
121  quantize_row_q8_0(x_fp32, x_q8_scratch, K);
122 
123  /* Compute dot products */
124  for (int row = 0; row < M; row++) {
125  float dot;
126  vec_dot_q8_0_q8_0(K, &dot, &w_blocks[row * blocks_per_row], x_q8_scratch);
127  out[row] = dot + (bias ? bias[row] : 0.0f);
128  }
129 }

References QK8_0, quantize_row_q8_0(), and vec_dot_q8_0_q8_0().

Referenced by mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().

◆ mega_fused_attention_decode_q5_0()

void mega_fused_attention_decode_q5_0 ( float *  output,
const float *  input,
const float *  residual,
const void *  wq_q5_0,
const void *  wk_q5_0,
const void *  wv_q8_0,
const void *  wo_q5_0,
const float *  ln_gamma,
const float *  bq,
const float *  bk,
const float *  bv,
const float *  bo,
float *  kv_cache_k,
float *  kv_cache_v,
const float *  rope_cos,
const float *  rope_sin,
int  pos,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
int  cache_capacity,
float  eps,
void *  scratch 
)

Serial mega-fused attention decode kernel.

Parameters
output               Output [AE] (final result, after residual add)
input                Input activation [AE]
residual             Residual input for add [AE]
wq_q5_0              Q projection weights [H*AD, AE] Q5_0
wk_q5_0              K projection weights [KV*AD, AE] Q5_0
wv_q8_0              V projection weights [KV*AD, AE] Q8_0
wo_q5_0              O projection weights [AE, H*AD] Q5_0
ln_gamma             RMSNorm gamma [AE]
bq                   Q bias [H*AD] or NULL
bk                   K bias [KV*AD] or NULL
bv                   V bias [KV*AD] or NULL
bo                   O bias [AE] or NULL
kv_cache_k           K cache [KV, max_T, AD]
kv_cache_v           V cache [KV, max_T, AD]
rope_cos             RoPE cos [max_T, D]
rope_sin             RoPE sin [max_T, D]
pos                  Current position (0-indexed)
embed_dim            Original embedding dimension E
aligned_embed_dim    Aligned embedding dimension AE
num_heads            Number of query heads H
num_kv_heads         Number of key/value heads KV
head_dim             Head dimension AD
aligned_head_dim     Aligned head dimension AAD
cache_capacity       Maximum cache capacity max_T
eps                  RMSNorm epsilon
scratch              Scratch buffer (>= scratch_size bytes)

Definition at line 222 of file mega_fused_attention_decode_q5_0.c.

249 {
250  const int H = num_heads;
251  const int KV = num_kv_heads;
252  const int AD = head_dim;
253  const int AE = aligned_embed_dim;
254  (void)embed_dim; /* Unused but kept for API consistency */
255 
256  /* Parse scratch buffer - all allocations from scratch, no VLAs */
257  float *scratch_ptr = (float *)scratch;
258 
259  float *rmsnorm_out = scratch_ptr;
260  scratch_ptr += AE;
261 
262  float *rstd_scratch = scratch_ptr; /* For rmsnorm rstd output - avoids VLA */
263  scratch_ptr += AE;
264 
265  float *q = scratch_ptr;
266  scratch_ptr += H * AD;
267 
268  float *k = scratch_ptr;
269  scratch_ptr += KV * AD;
270 
271  float *v = scratch_ptr;
272  scratch_ptr += KV * AD;
273 
274  float *attn_out = scratch_ptr;
275  scratch_ptr += H * AD;
276 
277  block_q8_0 *x_q8_scratch = (block_q8_0 *)scratch_ptr;
278 
279  const int q_size = H * AD;
280  const int k_size = KV * AD;
281  const int v_size = KV * AD;
282 
283  /* ========================================================================
284  * STEP 1: RMSNorm
285  * Correct signature: rmsnorm_forward(in, gamma, out, rstd, T, D, AD, eps)
286  * T=1 (single token), D=AE (full embed dim for norm)
287  * ======================================================================== */
288  rmsnorm_forward(input, ln_gamma, rmsnorm_out, rstd_scratch, 1, AE, AD, eps);
289 
290  /* ========================================================================
291  * STEP 2-4: Q, K, V projections (fused with quantization)
292  * Use scratch buffer for quantized input
293  * ======================================================================== */
294  gemv_q5_0_from_fp32(q, wq_q5_0, rmsnorm_out, bq, q_size, AE, x_q8_scratch);
295  gemv_q5_0_from_fp32(k, wk_q5_0, rmsnorm_out, bk, k_size, AE, x_q8_scratch);
296  gemv_q8_0_from_fp32(v, wv_q8_0, rmsnorm_out, bv, v_size, AE, x_q8_scratch);
297 
298  /* ========================================================================
299  * STEP 5: Apply RoPE
300  * ======================================================================== */
301  apply_rope_inline(q, k, rope_cos, rope_sin, pos, H, KV, AD);
302 
303  /* ========================================================================
304  * STEP 6: Store K and V to cache
305  * Cache layout: [KV, cache_capacity, AD]
306  * ======================================================================== */
307  const size_t kv_stride = (size_t)cache_capacity * AD;
308  for (int kv = 0; kv < KV; kv++) {
309  float *k_cache = &kv_cache_k[kv * kv_stride];
310  float *v_cache = &kv_cache_v[kv * kv_stride];
311  const float *k_src = &k[kv * AD];
312  const float *v_src = &v[kv * AD];
313  const int offset = pos * AD;
314  for (int d = 0; d < AD; d++) {
315  k_cache[offset + d] = k_src[d];
316  v_cache[offset + d] = v_src[d];
317  }
318  }
319 
320  /* ========================================================================
321  * STEP 7: Flash attention decode (GQA-aware variant)
322  * attention_forward_decode_head_major_gqa_flash handles H != KV correctly
323  * It maps each of H heads to one of KV KV heads via: kv_head = h * KV / H
324  * ======================================================================== */
 325  attention_forward_decode_head_major_gqa_flash(
 326  q, kv_cache_k, kv_cache_v,
 327  attn_out, H, KV, pos + 1, cache_capacity, AD, aligned_head_dim);
328 
329  /* ========================================================================
330  * STEP 8: O projection (Q5_0 weights) with bias and residual add
331  *
332  * attn_out layout: [H * AD] flattened
333  * wo_q5_0 layout: [AE, H*AD] - row e has H*AD input features
334  *
335  * O projection: output[e] = dot(wo[e], attn_out) + bias[e] + residual[e]
336  *
337  * Using vec_dot_q5_0_q8_0 for efficient quantized dot product.
338  * ======================================================================== */
339 
340  /* Quantize attention output to Q8_0 for GEMV */
341  quantize_row_q8_0(attn_out, x_q8_scratch, H * AD);
342 
343  const block_q5_0 *wo = (const block_q5_0 *)wo_q5_0;
344  const int blocks_per_row = (H * AD) / QK5_0;
345 
346  for (int e = 0; e < AE; e++) {
347  float dot;
348  vec_dot_q5_0_q8_0(H * AD, &dot, &wo[e * blocks_per_row], x_q8_scratch);
349  output[e] = dot + (bo ? bo[e] : 0.0f) + residual[e];
350  }
351 }

References apply_rope_inline(), attention_forward_decode_head_major_gqa_flash(), gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), QK5_0, quantize_row_q8_0(), rmsnorm_forward(), and vec_dot_q5_0_q8_0().

◆ mega_fused_attention_decode_q5_0_parallel_simd()

void mega_fused_attention_decode_q5_0_parallel_simd ( float *  output,
const float *  input,
const float *  residual,
const void *  wq_q5_0,
const void *  wk_q5_0,
const void *  wv_q8_0,
const void *  wo_q5_0,
const float *  ln_gamma,
const float *  bq,
const float *  bk,
const float *  bv,
const float *  bo,
float *  kv_cache_k,
float *  kv_cache_v,
const float *  rope_cos,
const float *  rope_sin,
int  pos,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
int  cache_capacity,
float  eps,
void *  scratch,
int  ith,
int  nth 
)

Parallel SIMD mega-fused attention decode kernel (threadpool-aware)

Parallelizes across attention heads using (ith, nth) pattern. Each thread processes a subset of heads.

IMPORTANT: Caller must ensure barrier sync between phases:

  Phase 1 (ith==0 only): RMSNorm, Q/K/V projection, RoPE, KV cache store
  -- BARRIER --
  Phase 2 (all threads): Attention for assigned heads
  -- BARRIER --
  Phase 3 (ith==0 only): O projection and residual add

Parameters
ith    Thread index (0 to nth-1)
nth    Total number of threads

Other parameters are the same as for the serial version.

Definition at line 367 of file mega_fused_attention_decode_q5_0.c.

396 {
397  const int H = num_heads;
398  const int KV = num_kv_heads;
399  const int AD = head_dim;
400  const int AE = aligned_embed_dim;
401  (void)embed_dim;
402 
403  /* Each thread handles a subset of heads */
404  const int heads_per_thread = (H + nth - 1) / nth;
405  const int h_start = ith * heads_per_thread;
406  const int h_end = (h_start + heads_per_thread < H) ? h_start + heads_per_thread : H;
407  const int my_heads = h_end - h_start;
408 
409  if (h_start >= H) return;
410 
411  /* Parse scratch buffer (shared across threads) */
412  float *scratch_ptr = (float *)scratch;
413 
414  float *rmsnorm_out = scratch_ptr;
415  scratch_ptr += AE;
416 
417  float *rstd_scratch = scratch_ptr;
418  scratch_ptr += AE;
419 
420  float *q = scratch_ptr;
421  scratch_ptr += H * AD;
422 
423  float *k = scratch_ptr;
424  scratch_ptr += KV * AD;
425 
426  float *v = scratch_ptr;
427  scratch_ptr += KV * AD;
428 
429  float *attn_out = scratch_ptr;
430  scratch_ptr += H * AD;
431 
432  block_q8_0 *x_q8_scratch = (block_q8_0 *)scratch_ptr;
433 
434  /* ========================================================================
435  * PHASE 1: Only thread 0 does RMSNorm and K/V projections
436  * These are shared across all heads.
437  * CALLER MUST BARRIER AFTER THIS PHASE.
438  * ======================================================================== */
439  if (ith == 0) {
440  rmsnorm_forward(input, ln_gamma, rmsnorm_out, rstd_scratch, 1, AE, AD, eps);
441 
442  gemv_q5_0_from_fp32(q, wq_q5_0, rmsnorm_out, bq, H * AD, AE, x_q8_scratch);
443  gemv_q5_0_from_fp32(k, wk_q5_0, rmsnorm_out, bk, KV * AD, AE, x_q8_scratch);
444  gemv_q8_0_from_fp32(v, wv_q8_0, rmsnorm_out, bv, KV * AD, AE, x_q8_scratch);
445 
446  apply_rope_inline(q, k, rope_cos, rope_sin, pos, H, KV, AD);
447 
448  /* Store K/V to cache */
449  const size_t kv_stride = (size_t)cache_capacity * AD;
450  for (int kv_idx = 0; kv_idx < KV; kv_idx++) {
451  float *k_cache = &kv_cache_k[kv_idx * kv_stride];
452  float *v_cache = &kv_cache_v[kv_idx * kv_stride];
453  const int offset = pos * AD;
454  for (int d = 0; d < AD; d++) {
455  k_cache[offset + d] = k[kv_idx * AD + d];
456  v_cache[offset + d] = v[kv_idx * AD + d];
457  }
458  }
459  }
460 
461  /* ========================================================================
462  * CALLER MUST BARRIER HERE
463  * All threads need to wait for thread 0 to finish projections
464  * ======================================================================== */
465 
466  /* ========================================================================
467  * PHASE 2: Each thread does attention for its heads only
468  * attention_forward_decode_head_major_gqa_flash expects:
469  * - q_token: pointer to start of Q for these heads
470  * - out_token: pointer to start of output for these heads
471  * - num_heads: number of heads THIS THREAD is processing
472  * ======================================================================== */
473  if (my_heads > 0) {
 474  attention_forward_decode_head_major_gqa_flash(
 475  &q[h_start * AD], /* Q for this thread's heads */
476  kv_cache_k, kv_cache_v,
477  &attn_out[h_start * AD], /* Output for this thread's heads */
478  my_heads, /* Only my_heads, not H */
479  KV, /* Still need all KV heads for GQA */
480  pos + 1, cache_capacity, AD, aligned_head_dim);
481  }
482 
483  /* ========================================================================
484  * CALLER MUST BARRIER HERE
485  * Thread 0 needs all threads to finish attention before O projection
486  * ======================================================================== */
487 
488  /* ========================================================================
489  * PHASE 3: Thread 0 does O projection and residual add
490  * ======================================================================== */
491  if (ith == 0) {
492  /* Quantize full attention output for O projection */
493  quantize_row_q8_0(attn_out, x_q8_scratch, H * AD);
494 
495  const block_q5_0 *wo = (const block_q5_0 *)wo_q5_0;
496  const int blocks_per_row = (H * AD) / QK5_0;
497 
498  for (int e = 0; e < AE; e++) {
499  float dot;
500  vec_dot_q5_0_q8_0(H * AD, &dot, &wo[e * blocks_per_row], x_q8_scratch);
501  output[e] = dot + (bo ? bo[e] : 0.0f) + residual[e];
502  }
503  }
504 }

References apply_rope_inline(), attention_forward_decode_head_major_gqa_flash(), gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), QK5_0, quantize_row_q8_0(), rmsnorm_forward(), and vec_dot_q5_0_q8_0().

◆ mega_fused_attention_decode_scratch_size()

int mega_fused_attention_decode_scratch_size ( int  AE,
int  H,
int  KV,
int  AD 
)

Calculate scratch buffer size needed for the kernel.

Parameters
AE    Aligned embedding dimension (multiple of 64)
H     Number of query heads
KV    Number of key/value heads
AD    Head dimension
Returns
Size in bytes needed for scratch buffer

Definition at line 176 of file mega_fused_attention_decode_q5_0.c.

176  {
177  /* Need: 1x AE for RMSNorm output
178  1x AE for RMSNorm rstd (avoid VLA)
179  1x H*AD for Q
180  1x KV*AD for K
181  1x KV*AD for V
182  1x H*AD for attention output
183  1x max(AE, H*AD)/QK8_0 * sizeof(block_q8_0) for GEMV scratch
184  */
185  int max_input_dim = (AE > H * AD) ? AE : H * AD;
186  int q8_blocks = (max_input_dim + QK8_0 - 1) / QK8_0;
187  return (int)(sizeof(float) * (AE + AE + H * AD + 2 * KV * AD + H * AD)
188  + q8_blocks * sizeof(block_q8_0));
189 }

References QK8_0.

◆ quantize_row_q8_0()

void quantize_row_q8_0 ( const float *  x,
void *  vy,
int  k 
)

Quantize FP32 to Q8_0 format (scalar reference)

Parameters
x     Input FP32 values
vy    Output Q8_0 blocks
k     Number of elements (must be multiple of 32)

Definition at line 59 of file gemm_kernels_q8_0.c.

60 {
61  block_q8_0 *y = (block_q8_0 *)vy;
62  const int nb = k / QK8_0; /* QK8_0 = 32 */
63 
64 #if defined(__AVX__)
65  const __m256 sign_bit = _mm256_set1_ps(-0.0f);
66  const __m256 v_half = _mm256_set1_ps(0.5f);
67  const __m256 v_min = _mm256_set1_ps(-127.0f);
68  const __m256 v_max = _mm256_set1_ps(127.0f);
69 
70  for (int i = 0; i < nb; i++) {
71  __m256 v0 = _mm256_loadu_ps(x + 0);
72  __m256 v1 = _mm256_loadu_ps(x + 8);
73  __m256 v2 = _mm256_loadu_ps(x + 16);
74  __m256 v3 = _mm256_loadu_ps(x + 24);
75  x += QK8_0;
76 
77  __m256 max_abs = _mm256_andnot_ps(sign_bit, v0);
78  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v1));
79  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v2));
80  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v3));
81 
82  __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max_abs, 1),
83  _mm256_castps256_ps128(max_abs));
84  max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
85  max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
86  const float max_scalar = _mm_cvtss_f32(max4);
87 
88  const float d = max_scalar / 127.0f;
89  const float id = max_scalar != 0.0f ? 127.0f / max_scalar : 0.0f;
90  y[i].d = CK_FP32_TO_FP16(d);
91 
92  const __m256 mul = _mm256_set1_ps(id);
93  v0 = _mm256_mul_ps(v0, mul);
94  v1 = _mm256_mul_ps(v1, mul);
95  v2 = _mm256_mul_ps(v2, mul);
96  v3 = _mm256_mul_ps(v3, mul);
97 
98  v0 = _mm256_min_ps(_mm256_max_ps(v0, v_min), v_max);
99  v1 = _mm256_min_ps(_mm256_max_ps(v1, v_min), v_max);
100  v2 = _mm256_min_ps(_mm256_max_ps(v2, v_min), v_max);
101  v3 = _mm256_min_ps(_mm256_max_ps(v3, v_min), v_max);
102 
103  /* Round half away from zero to match the scalar path */
104  v0 = _mm256_add_ps(v0, _mm256_or_ps(_mm256_and_ps(v0, sign_bit), v_half));
105  v1 = _mm256_add_ps(v1, _mm256_or_ps(_mm256_and_ps(v1, sign_bit), v_half));
106  v2 = _mm256_add_ps(v2, _mm256_or_ps(_mm256_and_ps(v2, sign_bit), v_half));
107  v3 = _mm256_add_ps(v3, _mm256_or_ps(_mm256_and_ps(v3, sign_bit), v_half));
108 
109  __m256i i0 = _mm256_cvttps_epi32(v0);
110  __m256i i1 = _mm256_cvttps_epi32(v1);
111  __m256i i2 = _mm256_cvttps_epi32(v2);
112  __m256i i3 = _mm256_cvttps_epi32(v3);
113 
114 #if defined(__AVX2__)
115  i0 = _mm256_packs_epi32(i0, i1);
116  i2 = _mm256_packs_epi32(i2, i3);
117  i0 = _mm256_packs_epi16(i0, i2);
118 
119  const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
120  i0 = _mm256_permutevar8x32_epi32(i0, perm);
121  _mm256_storeu_si256((__m256i *)y[i].qs, i0);
122 #else
123  __m128i ni0 = _mm256_castsi256_si128(i0);
124  __m128i ni1 = _mm256_extractf128_si256(i0, 1);
125  __m128i ni2 = _mm256_castsi256_si128(i1);
126  __m128i ni3 = _mm256_extractf128_si256(i1, 1);
127  __m128i ni4 = _mm256_castsi256_si128(i2);
128  __m128i ni5 = _mm256_extractf128_si256(i2, 1);
129  __m128i ni6 = _mm256_castsi256_si128(i3);
130  __m128i ni7 = _mm256_extractf128_si256(i3, 1);
131 
132  ni0 = _mm_packs_epi32(ni0, ni1);
133  ni2 = _mm_packs_epi32(ni2, ni3);
134  ni4 = _mm_packs_epi32(ni4, ni5);
135  ni6 = _mm_packs_epi32(ni6, ni7);
136 
137  ni0 = _mm_packs_epi16(ni0, ni2);
138  ni4 = _mm_packs_epi16(ni4, ni6);
139 
140  _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
141  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
142 #endif
143  }
144 #else
145  for (int i = 0; i < nb; i++) {
146  const float *xb = x + i * QK8_0;
147 
148  /* Find max absolute value in block */
149  float amax = 0.0f;
150  for (int j = 0; j < QK8_0; j++) {
151  float av = xb[j] >= 0 ? xb[j] : -xb[j];
152  if (av > amax) amax = av;
153  }
154 
155  /* Compute scale: d = max / 127 */
156  float d = amax / 127.0f;
157  float id = d != 0.0f ? 127.0f / amax : 0.0f;
158 
159  /* Store scale as FP16 */
160  y[i].d = CK_FP32_TO_FP16(d);
161 
162  /* Quantize values */
163  for (int j = 0; j < QK8_0; j++) {
164  float v = xb[j] * id;
165  /* Round to nearest int and clamp to [-127, 127] */
166  int q = (int)(v + (v >= 0 ? 0.5f : -0.5f));
167  if (q > 127) q = 127;
168  if (q < -127) q = -127;
169  y[i].qs[j] = (int8_t)q;
170  }
171  }
172 #endif
173 }

Referenced by gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().

◆ rmsnorm_forward()

void rmsnorm_forward ( const float *  input,
const float *  gamma,
float *  output,
float *  rstd_cache,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float  eps 
)

RMSNorm forward pass

Test:

test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens

test_rmsnorm.py::TestRMSNormForward::test_fp32_single

test_rmsnorm.py::TestRMSNormForward::test_perf_rolled

test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat

test_parity.py::test_rmsnorm_parity

RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)

After changes: make test && make llamacpp-parity-full

Definition at line 50 of file rmsnorm_kernels.c.

58 {
59  int T = tokens;
60  int D = d_model;
61  int aligned = aligned_embed_dim;
62 
63  for (int t = 0; t < T; ++t) {
64  const float *x = input + (size_t)t * aligned;
65  float *y = output + (size_t)t * aligned;
66 
67 #if defined(__AVX512F__)
68  // AVX-512: Process 16 floats at a time
69  __m512 sum_sq_vec = _mm512_setzero_ps();
70  int d = 0;
71 
72  // Vectorized sum of squares
73  for (; d + 16 <= D; d += 16) {
74  __m512 xv = _mm512_loadu_ps(&x[d]);
75  sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
76  }
77  float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
78 
79  // Handle remaining elements
80  for (; d < D; ++d) {
81  sum_sq += x[d] * x[d];
82  }
83 
84  float mean_sq = sum_sq / (float)D;
85  float rstd = 1.0f / sqrtf(mean_sq + eps);
86  if (rstd_cache) {
87  rstd_cache[t] = rstd;
88  }
89 
90  // Apply normalization and scale (vectorized)
91  __m512 rstd_vec = _mm512_set1_ps(rstd);
92  d = 0;
93  for (; d + 16 <= D; d += 16) {
94  __m512 xv = _mm512_loadu_ps(&x[d]);
95  __m512 gv = _mm512_loadu_ps(&gamma[d]);
96  __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
97  __m512 yv = _mm512_mul_ps(x_hat, gv);
98  _mm512_storeu_ps(&y[d], yv);
99  }
100  // Handle remaining elements
101  for (; d < D; ++d) {
102  y[d] = x[d] * rstd * gamma[d];
103  }
104 
105 #elif defined(__AVX__)
106  // AVX: Process 8 floats at a time
107  __m256 sum_sq_vec = _mm256_setzero_ps();
108  int d = 0;
109 
110  // Vectorized sum of squares (no FMA in AVX1, use mul + add)
111  for (; d + 8 <= D; d += 8) {
112  __m256 xv = _mm256_loadu_ps(&x[d]);
113  __m256 xv_sq = _mm256_mul_ps(xv, xv);
114  sum_sq_vec = _mm256_add_ps(sum_sq_vec, xv_sq);
115  }
116  float sum_sq = hsum256_ps_rmsnorm(sum_sq_vec);
117 
118  // Handle remaining elements
119  for (; d < D; ++d) {
120  sum_sq += x[d] * x[d];
121  }
122 
123  float mean_sq = sum_sq / (float)D;
124  float rstd = 1.0f / sqrtf(mean_sq + eps);
125  if (rstd_cache) {
126  rstd_cache[t] = rstd;
127  }
128 
129  // Apply normalization and scale (vectorized)
130  __m256 rstd_vec = _mm256_set1_ps(rstd);
131  d = 0;
132  for (; d + 8 <= D; d += 8) {
133  __m256 xv = _mm256_loadu_ps(&x[d]);
134  __m256 gv = _mm256_loadu_ps(&gamma[d]);
135  __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
136  __m256 yv = _mm256_mul_ps(x_hat, gv);
137  _mm256_storeu_ps(&y[d], yv);
138  }
139  // Handle remaining elements
140  for (; d < D; ++d) {
141  y[d] = x[d] * rstd * gamma[d];
142  }
143 
144 #else
145  // Scalar fallback
146  double sum_sq = 0.0;
147  for (int d = 0; d < D; ++d) {
148  double v = (double)x[d];
149  sum_sq += v * v;
150  }
151  double mean_sq = sum_sq / (double)D;
152  double r = sqrt(mean_sq + (double)eps);
153  float rstd = (float)(1.0 / r);
154  if (rstd_cache) {
155  rstd_cache[t] = rstd;
156  }
157 
158  // Apply normalization and scale
159  for (int d = 0; d < D; ++d) {
160  float x_hat = x[d] * rstd;
161  y[d] = x_hat * gamma[d];
162  }
163 #endif
164 
165  // Zero padding (if any)
166  for (int d = D; d < aligned; ++d) {
167  y[d] = 0.0f;
168  }
169  }
170 }

Referenced by mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().

◆ vec_dot_q5_0_q8_0()

void vec_dot_q5_0_q8_0 ( int  n,
float *  s,
const void *  vx,
const void *  vy 
)

Auto-dispatch quantized dot product Q5_0 x Q8_0.

Dispatch priority:

  1. AVX512 (best performance on modern Intel/AMD)
  2. AVX (256-bit float ops, works on Sandy/Ivy Bridge and newer)
  3. SSSE3 (128-bit fallback)
  4. Reference scalar (last resort)

Definition at line 1498 of file gemm_kernels_q5_0.c.

1499 {
1500 #if defined(__AVX512F__)
1501  vec_dot_q5_0_q8_0_avx512(n, s, vx, vy);
1502 #elif defined(__AVX__)
1503  /* AVX for 256-bit float ops (works on Ivy Bridge and newer) */
1504  vec_dot_q5_0_q8_0_avx(n, s, vx, vy);
1505 #elif defined(__SSSE3__)
1506  /* SSSE3 - most efficient on older CPUs */
1507  vec_dot_q5_0_q8_0_sse(n, s, vx, vy);
1508 #else
1509  vec_dot_q5_0_q8_0_ref(n, s, vx, vy);
1510 #endif
1511 }

Referenced by gemv_q5_0_from_fp32(), mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().

◆ vec_dot_q8_0_q8_0()

void vec_dot_q8_0_q8_0 ( int  n,
float *  s,
const void *  vx,
const void *  vy 
)

Auto-dispatch quantized dot product Q8_0 x Q8_0.

Definition at line 1013 of file gemm_kernels_q8_0.c.

1014 {
1015 #ifdef __AVX512F__
1016  vec_dot_q8_0_q8_0_avx512(n, s, vx, vy);
1017 #elif defined(__AVX__)
1018  vec_dot_q8_0_q8_0_avx(n, s, vx, vy);
1019 #elif defined(__SSE4_1__)
1020  vec_dot_q8_0_q8_0_sse(n, s, vx, vy);
1021 #else
1022  vec_dot_q8_0_q8_0_ref(n, s, vx, vy);
1023 #endif
1024 }

Referenced by gemv_q8_0_from_fp32().