Mega-fused prefill attention kernel with Q8_0 out-proj. More...
#include "ckernel_engine.h"
#include "ckernel_orchestration.h"
#include "ckernel_quant.h"
#include <math.h>
#include <string.h>
Go to the source code of this file.
Functions | |
| static size_t | align_up_size (size_t value, size_t align) |
| void | mega_fused_attention_prefill_q8_0 (float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch) |
| Mega-fused prefill attention kernel (Q8_0 out-proj) More... | |
| size_t | mega_fused_attention_prefill_q8_0_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| Get scratch buffer size for mega_fused_attention_prefill_q8_0. More... | |
| static void | out_proj_head_major_q8_0_q8_0 (const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| static void | quantize_attn_out_head_major_q8_0 (const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim) |
Mega-fused prefill attention kernel with Q8_0 out-proj.
RMSNorm → QKV → RoPE → Flash Attention → Q8_0 OutProj + Residual. Writes K/V directly into the KV cache layout (stride = cache_capacity).
Definition in file mega_fused_attention_prefill_q8_0.c.
|
static |
Definition at line 24 of file mega_fused_attention_prefill_q8_0.c.
Referenced by mega_fused_attention_prefill_q8_0(), and mega_fused_attention_prefill_q8_0_scratch_size().
| void mega_fused_attention_prefill_q8_0 | ( | float * | output, |
| const float * | input, | ||
| const float * | residual, | ||
| const float * | ln1_gamma, | ||
| const void * | wq, | ||
| const float * | bq, | ||
| CKDataType | wq_dt, | ||
| const void * | wk, | ||
| const float * | bk, | ||
| CKDataType | wk_dt, | ||
| const void * | wv, | ||
| const float * | bv, | ||
| CKDataType | wv_dt, | ||
| const void * | wo, | ||
| const float * | bo, | ||
| CKDataType | wo_dt, | ||
| float * | kv_cache_k, | ||
| float * | kv_cache_v, | ||
| const float * | rope_cos, | ||
| const float * | rope_sin, | ||
| int | start_pos, | ||
| int | tokens, | ||
| int | cache_capacity, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Mega-fused prefill attention kernel (Q8_0 out-proj)
Same layout and scratch requirements as mega_fused_attention_prefill.
Definition at line 105 of file mega_fused_attention_prefill_q8_0.c.
References align_up_size(), attention_flash_decode(), attention_forward_causal_head_major_gqa_flash_strided(), CK_DT_FP32, CK_DT_Q8_0, ck_residual_add_token_major(), fused_rmsnorm_qkv_prefill_head_major(), fused_rmsnorm_qkv_prefill_head_major_quant(), fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(), out_proj_head_major_q8_0_q8_0(), QK8_0, quantize_attn_out_head_major_q8_0(), and rope_forward_qk_strided().
| size_t mega_fused_attention_prefill_q8_0_scratch_size | ( | int | tokens, |
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim | ||
| ) |
Get scratch buffer size for mega_fused_attention_prefill_q8_0.
Definition at line 84 of file mega_fused_attention_prefill_q8_0.c.
References align_up_size(), and fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size().
|
static |
Definition at line 49 of file mega_fused_attention_prefill_q8_0.c.
References CK_DT_Q8_0, ck_dtype_row_bytes(), QK8_0, and vec_dot_q8_0_q8_0().
Referenced by mega_fused_attention_prefill_q8_0().
|
static |
Definition at line 29 of file mega_fused_attention_prefill_q8_0.c.
References CK_DT_Q8_0, ck_dtype_row_bytes(), and quantize_row_q8_0().
Referenced by mega_fused_attention_prefill_q8_0().