← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mega_fused_attention_decode_q5_0.h
Go to the documentation of this file.
1 /**
2  * @file mega_fused_attention_decode_q5_0.h
3  * @brief Mega-fused attention decode with Q5_0 weights - Header
4  *
5  * This header declares the mega-fused attention decode kernel that combines
6  * 9 separate operations into a single fused kernel call:
7  * 1. RMSNorm
8  * 2. Q projection (Q5_0) with bias
9  * 3. K projection (Q5_0) with bias
10  * 4. V projection (Q8_0) with bias
11  * 5. RoPE application
12  * 6. KV cache store
13  * 7. Flash attention decode (GQA-aware)
14  * 8. O projection (Q5_0) with bias
15  * 9. Residual add
16  */
17 
18 #ifndef MEGA_FUSED_ATTENTION_DECODE_Q5_0_H
19 #define MEGA_FUSED_ATTENTION_DECODE_Q5_0_H
20 
21 #ifdef __cplusplus
22 extern "C" {
23 #endif
24 
/**
 * @brief Calculate scratch buffer size needed for the kernel
 *
 * @param AE Aligned embedding dimension (multiple of 64)
 * @param H Number of query heads
 * @param KV Number of key/value heads
 * @param AD Head dimension
 * @return Size in bytes needed for scratch buffer
 *
 * NOTE(review): the int return type can overflow for very large
 * (AE, H, KV, AD) combinations; callers should sanity-check inputs.
 * Kept as int for ABI compatibility with existing callers.
 */
int mega_fused_attention_decode_scratch_size(int AE, int H, int KV, int AD);
35 
/**
 * @brief Serial mega-fused attention decode kernel
 *
 * Runs the full 9-op fused decode step (RMSNorm, Q/K/V projection,
 * RoPE, KV cache store, flash attention decode, O projection,
 * residual add) in a single call.
 *
 * @param output Output [AE] (final result, after residual add)
 * @param input Input activation [AE]
 * @param residual Residual input for add [AE]
 * @param wq_q5_0 Q projection weights [H*AD, AE] Q5_0
 * @param wk_q5_0 K projection weights [KV*AD, AE] Q5_0
 * @param wv_q8_0 V projection weights [KV*AD, AE] Q8_0
 * @param wo_q5_0 O projection weights [AE, H*AD] Q5_0
 * @param ln_gamma RMSNorm gamma [AE]
 * @param bq Q bias [H*AD] or NULL
 * @param bk K bias [KV*AD] or NULL
 * @param bv V bias [KV*AD] or NULL
 * @param bo O bias [AE] or NULL
 * @param kv_cache_k K cache [KV, max_T, AD]
 * @param kv_cache_v V cache [KV, max_T, AD]
 * @param rope_cos RoPE cos [max_T, D] (D presumably head_dim AD or
 *                 AD/2 depending on the RoPE layout — TODO confirm)
 * @param rope_sin RoPE sin [max_T, D] (same caveat as rope_cos)
 * @param pos Current position (0-indexed)
 * @param embed_dim Original embedding dimension E
 * @param aligned_embed_dim Aligned embedding dimension AE
 * @param num_heads Number of query heads H
 * @param num_kv_heads Number of key/value heads KV
 * @param head_dim Head dimension AD
 * @param aligned_head_dim Aligned head dimension AAD
 * @param cache_capacity Maximum cache capacity max_T
 * @param eps RMSNorm epsilon
 * @param scratch Scratch buffer (>= scratch_size bytes)
 */
void mega_fused_attention_decode_q5_0(
    float *output,
    const float *input,
    const float *residual,
    const void *wq_q5_0,
    const void *wk_q5_0,
    const void *wv_q8_0,
    const void *wo_q5_0,
    const float *ln_gamma,
    const float *bq,
    const float *bk,
    const float *bv,
    const float *bo,
    float *kv_cache_k,
    float *kv_cache_v,
    const float *rope_cos,
    const float *rope_sin,
    int pos,
    int embed_dim,
    int aligned_embed_dim,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    int aligned_head_dim,
    int cache_capacity,
    float eps,
    void *scratch);
93 
/**
 * @brief Parallel SIMD mega-fused attention decode kernel (threadpool-aware)
 *
 * Parallelizes across attention heads using the (ith, nth) pattern.
 * Each thread processes a subset of heads.
 *
 * IMPORTANT: Caller must ensure barrier sync between phases:
 * Phase 1 (ith==0 only): RMSNorm, Q/K/V projection, RoPE, KV cache store
 * -- BARRIER --
 * Phase 2 (all threads): Attention for assigned heads
 * -- BARRIER --
 * Phase 3 (ith==0 only): O projection and residual add
 *
 * @param ith Thread index (0 to nth-1)
 * @param nth Total number of threads
 * (other parameters same as serial mega_fused_attention_decode_q5_0)
 */
void mega_fused_attention_decode_q5_0_parallel_simd(
    float *output,
    const float *input,
    const float *residual,
    const void *wq_q5_0,
    const void *wk_q5_0,
    const void *wv_q8_0,
    const void *wo_q5_0,
    const float *ln_gamma,
    const float *bq,
    const float *bk,
    const float *bv,
    const float *bo,
    float *kv_cache_k,
    float *kv_cache_v,
    const float *rope_cos,
    const float *rope_sin,
    int pos,
    int embed_dim,
    int aligned_embed_dim,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    int aligned_head_dim,
    int cache_capacity,
    float eps,
    void *scratch,
    int ith,
    int nth);
140 
141 #ifdef __cplusplus
142 }
143 #endif
144 
145 #endif /* MEGA_FUSED_ATTENTION_DECODE_Q5_0_H */
void mega_fused_attention_decode_q5_0_parallel_simd(float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch, int ith, int nth)
Parallel SIMD mega-fused attention decode kernel (threadpool-aware)
void mega_fused_attention_decode_q5_0(float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch)
Serial mega-fused attention decode kernel.
int mega_fused_attention_decode_scratch_size(int AE, int H, int KV, int AD)
Calculate scratch buffer size needed for the kernel.