Mega-fused attention decode with Q5_0 weights - Header. More...
Go to the source code of this file.
Functions | |
| void | mega_fused_attention_decode_q5_0 (float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch) |
| Serial mega-fused attention decode kernel. More... | |
| void | mega_fused_attention_decode_q5_0_parallel_simd (float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch, int ith, int nth) |
| Parallel SIMD mega-fused attention decode kernel (threadpool-aware) More... | |
| int | mega_fused_attention_decode_scratch_size (int AE, int H, int KV, int AD) |
| Calculate scratch buffer size needed for the kernel. More... | |
Mega-fused attention decode with Q5_0 weights - Header.
This header declares the mega-fused attention decode kernel that combines 9 separate operations (RMSNorm, Q/K/V projections, RoPE, KV-cache store, attention, O projection, residual add) into a single fused kernel call.
Definition in file mega_fused_attention_decode_q5_0.h.
| void mega_fused_attention_decode_q5_0 | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const void * | wq_q5_0, | ||
| const void * | wk_q5_0, | ||
| const void * | wv_q8_0, | ||
| const void * | wo_q5_0, | ||
| const float * | ln_gamma, | ||
| const float * | bq, | ||
| const float * | bk, | ||
| const float * | bv, | ||
| const float * | bo, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | pos, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | cache_capacity, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Serial mega-fused attention decode kernel.
| output | Output [AE] (final result, after residual add) |
| input | Input activation [AE] |
| residual | Residual input for add [AE] |
| wq_q5_0 | Q projection weights [H*AD, AE] Q5_0 |
| wk_q5_0 | K projection weights [KV*AD, AE] Q5_0 |
| wv_q8_0 | V projection weights [KV*AD, AE] Q8_0 |
| wo_q5_0 | O projection weights [AE, H*AD] Q5_0 |
| ln_gamma | RMSNorm gamma [AE] |
| bq | Q bias [H*AD] or NULL |
| bk | K bias [KV*AD] or NULL |
| bv | V bias [KV*AD] or NULL |
| bo | O bias [AE] or NULL |
| kv_cache_k | K cache [KV, max_T, AD] |
| kv_cache_v | V cache [KV, max_T, AD] |
| rope_cos | RoPE cos [max_T, AD] (head dimension; was "D", undefined elsewhere in this doc) |
| rope_sin | RoPE sin [max_T, AD] |
| pos | Current position (0-indexed) |
| embed_dim | Original embedding dimension E |
| aligned_embed_dim | Aligned embedding dimension AE |
| num_heads | Number of query heads H |
| num_kv_heads | Number of key/value heads KV |
| head_dim | Head dimension AD |
| aligned_head_dim | Aligned head dimension AAD |
| cache_capacity | Maximum cache capacity max_T |
| eps | RMSNorm epsilon |
| scratch | Scratch buffer (>= scratch_size bytes) |
Definition at line 222 of file mega_fused_attention_decode_q5_0.c.
References apply_rope_inline(), attention_forward_decode_head_major_gqa_flash(), gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), QK5_0, quantize_row_q8_0(), rmsnorm_forward(), and vec_dot_q5_0_q8_0().
| void mega_fused_attention_decode_q5_0_parallel_simd | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const void * | wq_q5_0, | ||
| const void * | wk_q5_0, | ||
| const void * | wv_q8_0, | ||
| const void * | wo_q5_0, | ||
| const float * | ln_gamma, | ||
| const float * | bq, | ||
| const float * | bk, | ||
| const float * | bv, | ||
| const float * | bo, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | pos, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | cache_capacity, | ||
| float | eps, | ||
| void * | scratch, | ||
| int | ith, | ||
| int | nth | ||
| ) |
Parallel SIMD mega-fused attention decode kernel (threadpool-aware)
Parallelizes across attention heads using the (ith, nth) pattern. Each thread processes a subset of heads.
IMPORTANT: Caller must ensure barrier sync between phases: Phase 1 (ith==0 only): RMSNorm, Q/K/V projection, RoPE, KV cache store – BARRIER – Phase 2 (all threads): Attention for assigned heads – BARRIER – Phase 3 (ith==0 only): O projection and residual add
| ith | Thread index (0 to nth-1) |
| nth | Total number of threads (other parameters same as serial version) |
Definition at line 367 of file mega_fused_attention_decode_q5_0.c.
References apply_rope_inline(), attention_forward_decode_head_major_gqa_flash(), gemv_q5_0_from_fp32(), gemv_q8_0_from_fp32(), QK5_0, quantize_row_q8_0(), rmsnorm_forward(), and vec_dot_q5_0_q8_0().
| int mega_fused_attention_decode_scratch_size | ( | int | AE, |
| int | H, | ||
| int | KV, | ||
| int | AD | ||
| ) |
Calculate scratch buffer size needed for the kernel.
| AE | Aligned embedding dimension (multiple of 64) |
| H | Number of query heads |
| KV | Number of key/value heads |
| AD | Head dimension |
Definition at line 176 of file mega_fused_attention_decode_q5_0.c.
References QK8_0.