Mega-fused attention decode with Q5_0 weights. More...
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <math.h>
#include "ckernel_quant.h"

Go to the source code of this file.
Functions | |
| static void | apply_rope_inline (float *q, float *k, const float *rope_cos, const float *rope_sin, int pos, int H, int KV, int AD) |
| void | attention_forward_decode_head_major_gqa_flash (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim) |
| static void | gemv_q5_0_from_fp32 (float *out, const void *W_q5_0, const float *x_fp32, const float *bias, int M, int K, block_q8_0 *x_q8_scratch) |
| static void | gemv_q8_0_from_fp32 (float *out, const void *W_q8_0, const float *x_fp32, const float *bias, int M, int K, block_q8_0 *x_q8_scratch) |
| void | mega_fused_attention_decode_q5_0 (float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch) |
| Serial mega-fused attention decode kernel. More... | |
| void | mega_fused_attention_decode_q5_0_parallel_simd (float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch, int ith, int nth) |
| Parallel SIMD mega-fused attention decode kernel (threadpool-aware) More... | |
| int | mega_fused_attention_decode_scratch_size (int AE, int H, int KV, int AD) |
| Calculate scratch buffer size needed for the kernel. More... | |
| void | quantize_row_q8_0 (const float *x, void *vy, int k) |
| Quantize FP32 to Q8_0 format (scalar reference) More... | |
| void | rmsnorm_forward (const float *input, const float *gamma, float *output, float *rstd, int T, int D, int AD, float eps) |
| void | vec_dot_q5_0_q8_0 (int n, float *s, const void *vx, const void *vy) |
| Auto-dispatch quantized dot product Q5_0 x Q8_0. More... | |
| void | vec_dot_q8_0_q8_0 (int n, float *s, const void *vx, const void *vy) |
| Auto-dispatch quantized dot product Q8_0 x Q8_0. More... | |
Mega-fused attention decode with Q5_0 weights.
STATUS: Serial kernel complete and correct. Parallel variant is a prototype that requires threadpool barrier support (not yet available). The non-fused decode path (ck_parallel_decode.h) already parallelizes each GEMV via row-splitting, so this fused kernel is not on the critical path. It can be enabled once threadpool parallelization is resolved (see PARALLELIZATION NOTES below).
FUSION: Combines 9 operations to minimize memory traffic. All intermediate data stays in scratch buffer (L1/L2 cache).
Operations fused:
PARALLELIZATION NOTES: The parallel_simd variant below documents the intended threading model but cannot run with the current threadpool (single dispatch, no mid-dispatch barrier). Three approaches were evaluated:
(A) Multi-dispatch (RECOMMENDED): Break into 3 ck_threadpool_dispatch() calls per layer. Dispatch 1: row-split the Q projection across threads; thread 0 also does RMSNorm, K/V projection, RoPE, and the KV cache store (small ops that fit within the Q-projection wall time). Dispatch 2: split attention across heads (h_start..h_end per thread). Dispatch 3: row-split the O projection across threads; thread 0 does the residual add after its rows. Cost: ~1us total for the 2 extra barrier round-trips (negligible vs. the ~100us GEMV). Intermediates stay in shared scratch, so the cache benefit is preserved.
(B) Redundant compute (single dispatch, no barrier): All threads redundantly compute RMSNorm + K/V proj + RoPE (~4us wasted per thread). Avoids barrier but wastes cycles on small ops. Only viable if Q/O proj dominate (true for short contexts).
(C) Skip fusion, use existing parallel GEMV: The non-fused decode path already parallelizes each GEMV call via ck_parallel_decode.h. For decode (M=1), intermediates are small (~3.5KB), so DRAM bandwidth savings from fusion are minimal. This is the current production path.
TESTING:
  make test-mega-fused-parity   # Numerical parity
  make test-mega-fused-speed    # Performance benchmark
Definition in file mega_fused_attention_decode_q5_0.c.
|
inline static |
Definition at line 135 of file mega_fused_attention_decode_q5_0.c.
Referenced by mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().
| void attention_forward_decode_head_major_gqa_flash | ( | const float * | q_token, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| float * | out_token, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | kv_tokens, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
Flash attention decode (single token attends to KV cache)
test_flash_attention.py::TestFlashAttention::test_flash_decode
test_kv_cache_attention.py::TestKVCacheAttention::test_flash_decode
test_fused_attention_decode.py::TestFusedAttentionDecode::test_flash_decode
test_attention.py::TestAttentionForward::test_flash_decode
Single query token attends to kv_tokens in KV cache. Uses true flash attention from attention_flash_true.c.
After changes: make test && make llamacpp-parity-full
Definition at line 1467 of file attention_kernels.c.
Referenced by mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().
|
inline static |
Definition at line 84 of file mega_fused_attention_decode_q5_0.c.
References QK5_0, quantize_row_q8_0(), and vec_dot_q5_0_q8_0().
Referenced by mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().
|
inline static |
Definition at line 108 of file mega_fused_attention_decode_q5_0.c.
References QK8_0, quantize_row_q8_0(), and vec_dot_q8_0_q8_0().
Referenced by mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().
| void mega_fused_attention_decode_q5_0 | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const void * | wq_q5_0, | ||
| const void * | wk_q5_0, | ||
| const void * | wv_q8_0, | ||
| const void * | wo_q5_0, | ||
| const float * | ln_gamma, | ||
| const float * | bq, | ||
| const float * | bk, | ||
| const float * | bv, | ||
| const float * | bo, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | pos, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | cache_capacity, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Serial mega-fused attention decode kernel.
| output | Output [AE] (final result, after residual add) |
| input | Input activation [AE] |
| residual | Residual input for add [AE] |
| wq_q5_0 | Q projection weights [H*AD, AE] Q5_0 |
| wk_q5_0 | K projection weights [KV*AD, AE] Q5_0 |
| wv_q8_0 | V projection weights [KV*AD, AE] Q8_0 |
| wo_q5_0 | O projection weights [AE, H*AD] Q5_0 |
| ln_gamma | RMSNorm gamma [AE] |
| bq | Q bias [H*AD] or NULL |
| bk | K bias [KV*AD] or NULL |
| bv | V bias [KV*AD] or NULL |
| bo | O bias [AE] or NULL |
| kv_cache_k | K cache [KV, max_T, AD] |
| kv_cache_v | V cache [KV, max_T, AD] |
| rope_cos | RoPE cos [max_T, D] |
| rope_sin | RoPE sin [max_T, D] |
| pos | Current position (0-indexed) |
| embed_dim | Original embedding dimension E |
| aligned_embed_dim | Aligned embedding dimension AE |
| num_heads | Number of query heads H |
| num_kv_heads | Number of key/value heads KV |
| head_dim | Head dimension D |
| aligned_head_dim | Aligned head dimension AD |
| cache_capacity | Maximum cache capacity max_T |
| eps | RMSNorm epsilon |
| scratch | Scratch buffer (>= scratch_size bytes) |
Definition at line 222 of file mega_fused_attention_decode_q5_0.c.
References apply_rope_inline(), attention_forward_decode_head_major_gqa_flash(), gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), QK5_0, quantize_row_q8_0(), rmsnorm_forward(), and vec_dot_q5_0_q8_0().
| void mega_fused_attention_decode_q5_0_parallel_simd | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const void * | wq_q5_0, | ||
| const void * | wk_q5_0, | ||
| const void * | wv_q8_0, | ||
| const void * | wo_q5_0, | ||
| const float * | ln_gamma, | ||
| const float * | bq, | ||
| const float * | bk, | ||
| const float * | bv, | ||
| const float * | bo, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | pos, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | cache_capacity, | ||
| float | eps, | ||
| void * | scratch, | ||
| int | ith, | ||
| int | nth | ||
| ) |
Parallel SIMD mega-fused attention decode kernel (threadpool-aware)
Parallelizes across attention heads using (ith, nth) pattern. Each thread processes a subset of heads.
IMPORTANT: Caller must ensure barrier sync between phases: Phase 1 (ith==0 only): RMSNorm, Q/K/V projection, RoPE, KV cache store – BARRIER – Phase 2 (all threads): Attention for assigned heads – BARRIER – Phase 3 (ith==0 only): O projection and residual add
| ith | Thread index (0 to nth-1) |
| nth | Total number of threads (other parameters same as serial version) |
Definition at line 367 of file mega_fused_attention_decode_q5_0.c.
References apply_rope_inline(), attention_forward_decode_head_major_gqa_flash(), gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), QK5_0, quantize_row_q8_0(), rmsnorm_forward(), and vec_dot_q5_0_q8_0().
| int mega_fused_attention_decode_scratch_size | ( | int | AE, |
| int | H, | ||
| int | KV, | ||
| int | AD | ||
| ) |
Calculate scratch buffer size needed for the kernel.
| AE | Aligned embedding dimension (multiple of 64) |
| H | Number of query heads |
| KV | Number of key/value heads |
| AD | Head dimension |
Definition at line 176 of file mega_fused_attention_decode_q5_0.c.
References QK8_0.
| void quantize_row_q8_0 | ( | const float * | x, |
| void * | vy, | ||
| int | k | ||
| ) |
Quantize FP32 to Q8_0 format (scalar reference)
| x | Input FP32 values |
| vy | Output Q8_0 blocks |
| k | Number of elements (must be multiple of 32) |
Definition at line 59 of file gemm_kernels_q8_0.c.
Referenced by gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().
| void rmsnorm_forward | ( | const float * | input, |
| const float * | gamma, | ||
| float * | output, | ||
| float * | rstd_cache, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float | eps | ||
| ) |
RMSNorm forward pass
test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens
test_rmsnorm.py::TestRMSNormForward::test_fp32_single
test_rmsnorm.py::TestRMSNormForward::test_perf_rolled
test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat
test_parity.py::test_rmsnorm_parity
RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
After changes: make test && make llamacpp-parity-full
Definition at line 50 of file rmsnorm_kernels.c.
Referenced by mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().
| void vec_dot_q5_0_q8_0 | ( | int | n, |
| float * | s, | ||
| const void * | vx, | ||
| const void * | vy | ||
| ) |
Auto-dispatch quantized dot product Q5_0 x Q8_0.
Dispatch priority:
Definition at line 1498 of file gemm_kernels_q5_0.c.
Referenced by gemv_q5_0_from_fp32(), mega_fused_attention_decode_q5_0(), and mega_fused_attention_decode_q5_0_parallel_simd().
| void vec_dot_q8_0_q8_0 | ( | int | n, |
| float * | s, | ||
| const void * | vx, | ||
| const void * | vy | ||
| ) |
Auto-dispatch quantized dot product Q8_0 x Q8_0.
Definition at line 1013 of file gemm_kernels_q8_0.c.
Referenced by gemv_q8_0_from_fp32().