Mega-fused prefill attention kernel. More...
#include "ckernel_engine.h" #include "ckernel_orchestration.h" #include "ckernel_quant.h" #include <math.h> #include <stdlib.h> #include <string.h> #include <stdio.h>
Go to the source code of this file.
Functions | |
| static size_t | align_up_size (size_t value, size_t align) |
| static int | ck_q8_0_outproj_enabled (void) |
| static void | flatten_head_major (const float *attn_out, float *dst, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| void | mega_fused_attention_prefill (float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch) |
| Mega-fused attention for prefill mode (multiple tokens) More... | |
| size_t | mega_fused_attention_prefill_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| Get scratch buffer size for mega_fused_attention_prefill. More... | |
| static void | out_proj_head_major_q5_0_q8_0 (const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| static void | quantize_attn_out_head_major_q8_0 (const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim) |
Mega-fused prefill attention kernel.
After changes: make test && make llamacpp-parity-full
RMSNorm → QKV → RoPE → Flash Attention → OutProj + Residual. Writes K/V directly into the KV cache layout (stride = cache_capacity).
Uses ck_gemm_nt_head_major_*() to read head-major attention output directly with strided access, eliminating the flatten_head_major() memcpy bottleneck (448 memcpy calls for 32 tokens × 14 heads).
python3 scripts/bench_mega_fused_attention_prefill.py --q8-outproj --seq-lens 32,64 --iters 3 --warmup 1
Definition in file mega_fused_attention_prefill.c.
|
static |
Definition at line 39 of file mega_fused_attention_prefill.c.
Referenced by mega_fused_attention_prefill(), and mega_fused_attention_prefill_scratch_size().
|
static |
Definition at line 63 of file mega_fused_attention_prefill.c.
Referenced by mega_fused_attention_prefill().
|
static |
Definition at line 43 of file mega_fused_attention_prefill.c.
Referenced by mega_fused_attention_prefill().
| void mega_fused_attention_prefill | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const float * | ln1_gamma, | ||
| const void * | wq, | ||
| const float * | bq, | ||
| CKDataType | wq_dt, | ||
| const void * | wk, | ||
| const float * | bk, | ||
| CKDataType | wk_dt, | ||
| const void * | wv, | ||
| const float * | bv, | ||
| CKDataType | wv_dt, | ||
| const void * | wo, | ||
| const float * | bo, | ||
| CKDataType | wo_dt, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | start_pos, | ||
| int | tokens, | ||
| int | cache_capacity, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Mega-fused attention for prefill mode (multiple tokens)
| output | Output [tokens, aligned_embed_dim] (includes residual add) |
| input | Input [tokens, aligned_embed_dim] |
| residual | Residual input [tokens, aligned_embed_dim] (or NULL) |
| ln1_gamma | RMSNorm gamma [embed_dim] |
| wq | Q weights [num_heads * aligned_head_dim * aligned_embed_dim] |
| bq | Q bias [num_heads * aligned_head_dim] (or NULL) |
| wk | K weights [num_kv_heads * aligned_head_dim * aligned_embed_dim] |
| bk | K bias [num_kv_heads * aligned_head_dim] (or NULL) |
| wv | V weights [num_kv_heads * aligned_head_dim * aligned_embed_dim] |
| bv | V bias [num_kv_heads * aligned_head_dim] (or NULL) |
| wo | Output projection weights [num_heads * aligned_embed_dim * aligned_head_dim] |
| bo | Output bias [aligned_embed_dim] (or NULL) |
| kv_cache_k | KV cache for K [num_kv_heads * cache_capacity * aligned_head_dim] |
| kv_cache_v | KV cache for V [num_kv_heads * cache_capacity * aligned_head_dim] |
| rope_cos | RoPE cos [max_seq, head_dim/2] |
| rope_sin | RoPE sin [max_seq, head_dim/2] |
| start_pos | Starting position in KV cache |
| tokens | Number of tokens to process |
| cache_capacity | KV cache capacity (stride in tokens) |
| embed_dim | Model hidden dimension (unpadded) |
| aligned_embed_dim | Aligned hidden dimension |
| num_heads | Number of attention heads |
| num_kv_heads | Number of KV heads |
| head_dim | Head dimension (unpadded) |
| aligned_head_dim | Aligned head dimension |
| eps | RMSNorm epsilon |
Definition at line 160 of file mega_fused_attention_prefill.c.
References align_up_size(), attention_flash_decode(), attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_FP32, CK_DT_Q5_0, CK_DT_Q8_0, ck_gemm_nt_head_major_q5_0(), ck_gemm_nt_head_major_q8_0(), ck_gemm_nt_quant(), ck_q8_0_outproj_enabled(), ck_residual_add_token_major(), flatten_head_major(), fused_rmsnorm_qkv_prefill_head_major(), fused_rmsnorm_qkv_prefill_head_major_quant(), fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(), out_proj_head_major_q5_0_q8_0(), QK5_0, QK8_0, quantize_attn_out_head_major_q8_0(), and rope_forward_qk_strided().
| size_t mega_fused_attention_prefill_scratch_size | ( | int | tokens, |
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim | ||
| ) |
Get scratch buffer size for mega_fused_attention_prefill.
Definition at line 139 of file mega_fused_attention_prefill.c.
References align_up_size(), and fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size().
|
static |
Definition at line 104 of file mega_fused_attention_prefill.c.
References CK_DT_Q8_0, ck_dtype_row_bytes(), QK5_0, and vec_dot_q5_0_q8_0().
Referenced by mega_fused_attention_prefill().
|
static |
Definition at line 84 of file mega_fused_attention_prefill.c.
References CK_DT_Q8_0, ck_dtype_row_bytes(), and quantize_row_q8_0().
Referenced by mega_fused_attention_prefill().