#ifndef MEGA_FUSED_ATTENTION_DECODE_Q5_0_H
#define MEGA_FUSED_ATTENTION_DECODE_Q5_0_H
const float *residual,
const float *ln_gamma,
const float *rope_cos,
const float *rope_sin,
int aligned_embed_dim,
const float *residual,
const float *ln_gamma,
const float *rope_cos,
const float *rope_sin,
int aligned_embed_dim,
int aligned_head_dim,
void mega_fused_attention_decode_q5_0_parallel_simd(float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch, int ith, int nth);
/* Parallel SIMD mega-fused attention decode kernel (threadpool-aware). */
void mega_fused_attention_decode_q5_0(float *output, const float *input, const float *residual, const void *wq_q5_0, const void *wk_q5_0, const void *wv_q8_0, const void *wo_q5_0, const float *ln_gamma, const float *bq, const float *bk, const float *bv, const float *bo, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void *scratch);
/* Serial mega-fused attention decode kernel. */
int mega_fused_attention_decode_scratch_size(int AE, int H, int KV, int AD);
/* Calculate scratch buffer size needed for the kernel. */