Mega-fused post-attention block for prefill. More...
#include "ckernel_engine.h"
#include "ckernel_quant.h"
#include <math.h>
#include <stddef.h>
#include <stdint.h>
Go to the source code of this file.
Macros | |
| #define | OUTPROJ_TILE_N 8 |
| #define | OUTPROJ_TILE_N 8 |
Functions | |
| static size_t | align_up_size (size_t value, size_t align) |
| void | mega_fused_outproj_mlp_prefill (float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, CKDataType wo_dt, const void *w1, const float *b1, CKDataType w1_dt, const void *w2, const float *b2, CKDataType w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch) |
| Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill. More... | |
| size_t | mega_fused_outproj_mlp_prefill_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim) |
| Get scratch buffer size for mega_fused_outproj_mlp_prefill. More... | |
| static void | out_proj_head_major_q5_0_q8_0 (const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| static void | out_proj_head_major_q8_0_q8_0 (const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |
| static void | quantize_attn_out_head_major_q8_0 (const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim) |
Mega-fused post-attention block for prefill.
OutProj → Residual → RMSNorm2 → MLP → Residual
Plan summary:
1) Quantize head-major attn_out to Q8_0
2) Out-proj with Q5_0/Q8_0 weights → h1 (post-attn) in scratch
3) Add residual (input) into h1
4) RMSNorm2(h1) → ln2_out (scratch)
5) Fused MLP (quant W1/W2) → output
6) Add h1 residual into output
Goal: avoid DRAM writes between attention out-proj and MLP output. All intermediates live in scratch buffers from the bump allocator.
Definition in file mega_fused_outproj_mlp_prefill.c.
| #define OUTPROJ_TILE_N 8 |
| #define OUTPROJ_TILE_N 8 |
| static size_t align_up_size (size_t value, size_t align) |
Definition at line 30 of file mega_fused_outproj_mlp_prefill.c.
Referenced by mega_fused_outproj_mlp_prefill(), and mega_fused_outproj_mlp_prefill_scratch_size().
| void mega_fused_outproj_mlp_prefill | ( | float * | output, |
| const float * | attn_out, | ||
| const float * | residual, | ||
| const float * | ln2_gamma, | ||
| const void * | wo, | ||
| const float * | bo, | ||
| CKDataType | wo_dt, | ||
| const void * | w1, | ||
| const float * | b1, | ||
| CKDataType | w1_dt, | ||
| const void * | w2, | ||
| const float * | b2, | ||
| CKDataType | w2_dt, | ||
| int | tokens, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim, | ||
| int | intermediate_dim, | ||
| int | aligned_intermediate_dim, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill.
Uses head-major attention output and quantized out-proj (Q5_0/Q8_0 weights).
Definition at line 184 of file mega_fused_outproj_mlp_prefill.c.
References add_inplace_f32(), align_up_size(), CK_DT_Q4_K, CK_DT_Q5_0, CK_DT_Q6_K, CK_DT_Q8_0, ck_dtype_row_bytes(), fused_mlp_swiglu_prefill_w1w2_quant(), out_proj_head_major_q5_0_q8_0(), out_proj_head_major_q8_0_q8_0(), QK_K, quantize_attn_out_head_major_q8_0(), and rmsnorm_forward().
| size_t mega_fused_outproj_mlp_prefill_scratch_size | ( | int | tokens, |
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim, | ||
| int | aligned_intermediate_dim | ||
| ) |
Get scratch buffer size for mega_fused_outproj_mlp_prefill.
Definition at line 159 of file mega_fused_outproj_mlp_prefill.c.
References align_up_size(), CK_DT_Q8_0, ck_dtype_row_bytes(), and fused_mlp_swiglu_prefill_w1w2_quant_scratch_size().
Referenced by ck_test_outproj_mlp_fused_q5_0().
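| static void out_proj_head_major_q5_0_q8_0 (const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |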
Definition at line 57 of file mega_fused_outproj_mlp_prefill.c.
References CK_DT_Q8_0, ck_dtype_row_bytes(), OUTPROJ_TILE_N, QK5_0, and vec_dot_q5_0_q8_0().
Referenced by mega_fused_outproj_mlp_prefill().
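| static void out_proj_head_major_q8_0_q8_0 (const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim) |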
Definition at line 108 of file mega_fused_outproj_mlp_prefill.c.
References CK_DT_Q8_0, ck_dtype_row_bytes(), OUTPROJ_TILE_N, QK8_0, and vec_dot_q8_0_q8_0().
Referenced by mega_fused_outproj_mlp_prefill().
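| static void quantize_attn_out_head_major_q8_0 (const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim) |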
Definition at line 37 of file mega_fused_outproj_mlp_prefill.c.
References CK_DT_Q8_0, ck_dtype_row_bytes(), and quantize_row_q8_0().
Referenced by mega_fused_outproj_mlp_prefill().