Mega-Fused Attention + MLP Block. More...
#include <stdint.h>
#include <stddef.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include "ckernel_quant.h"

Go to the source code of this file.
Functions
| void | attention_mlp_fused_fp32 (const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float *wo, const float *residual_1, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out) |
| void | attention_mlp_fused_q4k (const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const void *wo, const float *residual_1, const float *rms_weight, float eps, const void *w_gate, const void *w_up, const void *w_down, int embed_dim, int intermediate_dim, float *hidden_out) |
| void | attention_mlp_separate_fp32 (const float *q, const float *k_cache, const float *v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float *wo, const float *residual_1, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *attn_out_buf, float *hidden_after_attn_buf, float *normed_buf, float *gate_buf, float *up_buf, float *mlp_out_buf, float *hidden_out) |
| static float | compute_rms_scale_internal (const float *x, int n, float eps) |
| void | layer_fused_attn_mlp_qkv_q4k (const float *q, const float *k_cache, const float *v_cache, int seq_len, float attn_scale, const void *wo, const float *rms_weight_mlp, const void *w_gate, const void *w_up, const void *w_down, const float *rms_weight_attn, const void *wq_next, const void *wk_next, const void *wv_next, const float *residual_in, int embed_dim, int intermediate_dim, int num_heads, int num_kv_heads, int head_dim, float eps, float *q_next, float *k_next, float *v_next, float *hidden_out) |
| void | mlp_fused_fp32_v2 (const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out) |
| void | mlp_fused_fp32_v3 (const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, int embed_dim, int intermediate_dim, float *hidden_out) |
| void | mlp_separate_fp32 (const float *hidden_in, const float *rms_weight, float eps, const float *w_gate, const float *w_up, const float *w_down, float *normed_buf, float *gate_buf, float *up_buf, int embed_dim, int intermediate_dim, float *hidden_out) |
| static float | silu_scalar (float x) |
| static void | softmax_inplace (float *x, int n) |
Mega-Fused Attention + MLP Block.
After changes: make test && make llamacpp-parity-full
VIOLATION: Uses memcpy for layout conversion. TODO: Use strided access.
Part of C-Kernel-Engine v6.6 Fusion Kernels
FUSES THE ENTIRE BLOCK from Attention output to next layer input:
Attention(Q, K_cache, V_cache)
  │
  ▼
Output Projection (attn @ Wo)
  │
  ▼
NON-FUSED version writes these buffers to DRAM:
FUSED version: ALL intermediates stay in L1/L2, ZERO DRAM writes
EXPECTED SPEEDUP: 2-3x for this block
Definition in file attention_mlp_fused.c.
| void attention_mlp_fused_fp32 | ( | const float * | q, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| int | seq_len, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| float | attn_scale, | ||
| const float * | wo, | ||
| const float * | residual_1, | ||
| const float * | rms_weight, | ||
| float | eps, | ||
| const float * | w_gate, | ||
| const float * | w_up, | ||
| const float * | w_down, | ||
| int | embed_dim, | ||
| int | intermediate_dim, | ||
| float * | hidden_out | ||
| ) |
Definition at line 175 of file attention_mlp_fused.c.
References compute_rms_scale_internal(), score, silu_scalar(), and softmax_inplace().
| void attention_mlp_fused_q4k | ( | const float * | q, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| int | seq_len, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| float | attn_scale, | ||
| const void * | wo, | ||
| const float * | residual_1, | ||
| const float * | rms_weight, | ||
| float | eps, | ||
| const void * | w_gate, | ||
| const void * | w_up, | ||
| const void * | w_down, | ||
| int | embed_dim, | ||
| int | intermediate_dim, | ||
| float * | hidden_out | ||
| ) |
Definition at line 742 of file attention_mlp_fused.c.
References compute_rms_scale_internal(), gemv_q4_k(), score, silu_scalar(), and softmax_inplace().
| void attention_mlp_separate_fp32 | ( | const float * | q, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| int | seq_len, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| float | attn_scale, | ||
| const float * | wo, | ||
| const float * | residual_1, | ||
| const float * | rms_weight, | ||
| float | eps, | ||
| const float * | w_gate, | ||
| const float * | w_up, | ||
| const float * | w_down, | ||
| int | embed_dim, | ||
| int | intermediate_dim, | ||
| float * | attn_out_buf, | ||
| float * | hidden_after_attn_buf, | ||
| float * | normed_buf, | ||
| float * | gate_buf, | ||
| float * | up_buf, | ||
| float * | mlp_out_buf, | ||
| float * | hidden_out | ||
| ) |
Definition at line 1084 of file attention_mlp_fused.c.
References compute_rms_scale_internal(), score, silu_scalar(), and softmax_inplace().
|
inlinestatic |
Definition at line 76 of file attention_mlp_fused.c.
Referenced by attention_mlp_fused_fp32(), attention_mlp_fused_q4k(), attention_mlp_separate_fp32(), mlp_fused_fp32_v2(), mlp_fused_fp32_v3(), and mlp_separate_fp32().
| void layer_fused_attn_mlp_qkv_q4k | ( | const float * | q, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| int | seq_len, | ||
| float | attn_scale, | ||
| const void * | wo, | ||
| const float * | rms_weight_mlp, | ||
| const void * | w_gate, | ||
| const void * | w_up, | ||
| const void * | w_down, | ||
| const float * | rms_weight_attn, | ||
| const void * | wq_next, | ||
| const void * | wk_next, | ||
| const void * | wv_next, | ||
| const float * | residual_in, | ||
| int | embed_dim, | ||
| int | intermediate_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| float | eps, | ||
| float * | q_next, | ||
| float * | k_next, | ||
| float * | v_next, | ||
| float * | hidden_out | ||
| ) |
| void mlp_fused_fp32_v2 | ( | const float * | hidden_in, |
| const float * | rms_weight, | ||
| float | eps, | ||
| const float * | w_gate, | ||
| const float * | w_up, | ||
| const float * | w_down, | ||
| int | embed_dim, | ||
| int | intermediate_dim, | ||
| float * | hidden_out | ||
| ) |
Definition at line 407 of file attention_mlp_fused.c.
References compute_rms_scale_internal(), and silu_scalar().
| void mlp_fused_fp32_v3 | ( | const float * | hidden_in, |
| const float * | rms_weight, | ||
| float | eps, | ||
| const float * | w_gate, | ||
| const float * | w_up, | ||
| const float * | w_down, | ||
| int | embed_dim, | ||
| int | intermediate_dim, | ||
| float * | hidden_out | ||
| ) |
Definition at line 564 of file attention_mlp_fused.c.
References compute_rms_scale_internal(), and silu_scalar().
| void mlp_separate_fp32 | ( | const float * | hidden_in, |
| const float * | rms_weight, | ||
| float | eps, | ||
| const float * | w_gate, | ||
| const float * | w_up, | ||
| const float * | w_down, | ||
| float * | normed_buf, | ||
| float * | gate_buf, | ||
| float * | up_buf, | ||
| int | embed_dim, | ||
| int | intermediate_dim, | ||
| float * | hidden_out | ||
| ) |
Definition at line 679 of file attention_mlp_fused.c.
References compute_rms_scale_internal(), and silu_scalar().
|
inlinestatic |
Definition at line 109 of file attention_mlp_fused.c.
Referenced by attention_mlp_fused_fp32(), attention_mlp_fused_q4k(), attention_mlp_separate_fp32(), mlp_fused_fp32_v2(), mlp_fused_fp32_v3(), and mlp_separate_fp32().
|
static |
Definition at line 150 of file attention_mlp_fused.c.
Referenced by attention_mlp_fused_fp32(), attention_mlp_fused_q4k(), and attention_mlp_separate_fp32().