Output projection from head-major attention (NO LAYOUT CONVERSION)
#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include "ckernel_quant.h"
#include "ckernel_dtype.h"
Functions
| void | ck_gemm_nt_head_major_q5_0 (const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim) |
| | Output projection from head-major attention (auto-dispatch) |
| void | ck_gemm_nt_head_major_q8_0 (const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim) |
| | Output projection from head-major attention (Q8_0 weights) |
| void | dequant_q5_0_row (const void *src, float *dst, size_t n_elements) |
| | Dequantize Q5_0 row (multiple blocks) |
| void | gemv_nt_q5_0_head_major_output (float *output, const float *attn_out, const void *wo, const float *bias, int tokens, int embed_dim, int num_heads, int head_dim) |
| | Output projection reading head-major attention output (Q5_0 weights) |
Output projection from head-major attention (NO LAYOUT CONVERSION)
After changes, verify with: make test && make llamacpp-parity-full
The standard mega_fused_attention_prefill has a bottleneck in its output path:

attn_out [num_heads, tokens, head_dim] (head-major) → flatten_head_major() (448 memcpy calls for 32 tokens × 14 heads) → token-major buffer → GEMM output projection
This kernel eliminates the flatten by reading head-major data directly with strided access. The output projection computes:
output[t, n] = bias[n] + sum_h wo[n, h*head_dim:(h+1)*head_dim] @ attn_out[h, t, :]
where wo is Q5_0 quantized [embed_dim, embed_dim] and attn_out is head-major.
Expected speedup: 1.5-2x by eliminating 448 small memcpy calls.
Definition in file gemm_head_major_output.c.
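The projection formula above can be sketched in plain C with unquantized float weights. The function name `out_proj_head_major_ref` and the float-weight form are illustrative only; the real kernel reads wo as Q5_0/Q8_0 blocks and dequantizes on the fly, but the indexing is the same:

```c
#include <assert.h>
#include <math.h>
#include <stddef.h>

/* Reference (unquantized) sketch of the projection this file computes:
 *   output[t, n] = bias[n] + sum_h wo[n, h*head_dim:(h+1)*head_dim] @ attn_out[h, t, :]
 * attn_out is head-major [num_heads, tokens, head_dim]; no flatten buffer needed. */
void out_proj_head_major_ref(const float *attn_out, const float *wo,
                             const float *bias, float *output,
                             int tokens, int embed_dim,
                             int num_heads, int head_dim)
{
    for (int t = 0; t < tokens; t++) {
        for (int n = 0; n < embed_dim; n++) {
            float acc = bias ? bias[n] : 0.0f;
            for (int h = 0; h < num_heads; h++) {
                /* Head-major: the row for (h, t) starts at (h*tokens + t)*head_dim. */
                const float *a = attn_out + ((size_t)h * tokens + t) * head_dim;
                /* Weight row n, segment belonging to head h. */
                const float *w = wo + (size_t)n * num_heads * head_dim
                                    + (size_t)h * head_dim;
                for (int d = 0; d < head_dim; d++)
                    acc += w[d] * a[d];
            }
            output[(size_t)t * embed_dim + n] = acc;
        }
    }
}
```

With wo set to the identity, the output row is simply the per-token concatenation of the head slices, which makes the strided indexing easy to check.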
| void ck_gemm_nt_head_major_q5_0 | ( | const float * | attn_out, |
| const void * | wo, | ||
| const float * | bias, | ||
| float * | output, | ||
| int | tokens, | ||
| int | embed_dim, | ||
| int | num_heads, | ||
| int | head_dim | ||
| ) |
Output projection from head-major attention (auto-dispatch)
This replaces flatten_head_major() + ck_gemm_nt_quant() with a single strided-access kernel that reads head-major attention output directly.
Definition at line 328 of file gemm_head_major_output.c.
References gemv_nt_q5_0_head_major_output().
Referenced by mega_fused_attention_prefill().
| void ck_gemm_nt_head_major_q8_0 | ( | const float * | attn_out, |
| const void * | wo, | ||
| const float * | bias, | ||
| float * | output, | ||
| int | tokens, | ||
| int | embed_dim, | ||
| int | num_heads, | ||
| int | head_dim | ||
| ) |
Output projection from head-major attention (Q8_0 weights)
Definition at line 353 of file gemm_head_major_output.c.
References CK_FP16_TO_FP32, block_q8_0::d, QK8_0, and block_q8_0::qs.
Referenced by mega_fused_attention_prefill().
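The Q8_0 inner loop can be sketched per block as below. `demo_block_q8_0` is an illustrative stand-in for `block_q8_0`: the real struct stores the scale as fp16 and decodes it with CK_FP16_TO_FP32; a float scale is used here for clarity:

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>

#define QK8_0 32   /* elements per Q8_0 block */

/* Simplified block: per-block scale plus 32 signed 8-bit values. */
typedef struct {
    float  d;            /* scale (fp16 in the real block_q8_0) */
    int8_t qs[QK8_0];    /* quantized values */
} demo_block_q8_0;

/* Dot product of one quantized weight segment with a float activation
 * segment, dequantizing on the fly as the Q8_0 kernel does. */
float dot_q8_0_block(const demo_block_q8_0 *b, const float *a)
{
    float acc = 0.0f;
    for (int i = 0; i < QK8_0; i++)
        acc += (float)b->qs[i] * a[i];
    return b->d * acc;   /* factor the scale out of the inner loop */
}
```

Factoring the scale out of the accumulation keeps the inner loop to an integer-to-float multiply-add per element.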
| void dequant_q5_0_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q5_0 row (multiple blocks)
Definition at line 196 of file dequant_kernels.c.
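One block of the Q5_0 dequantization can be sketched as below, assuming a ggml-style Q5_0 layout (4-bit low nibbles in qs, the fifth bit of each value packed into qh, values centered at 16). `demo_block_q5_0` is an illustrative stand-in for `block_q5_0`; the real struct stores d as fp16 (decoded via CK_FP16_TO_FP32):

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <string.h>

#define QK5_0 32   /* elements per Q5_0 block */

/* Simplified block, ggml-style layout assumed. */
typedef struct {
    float   d;               /* scale (fp16 in the real block_q5_0) */
    uint8_t qh[4];           /* high (5th) bit of each of the 32 values */
    uint8_t qs[QK5_0 / 2];   /* low 4 bits, two values per byte */
} demo_block_q5_0;

/* Dequantize one block: each value q is 5 bits (nibble + high bit from qh),
 * mapped to d * (q - 16).  Byte i holds values i and i + 16 of the block. */
void dequant_q5_0_block(const demo_block_q5_0 *b, float *dst)
{
    uint32_t qh;
    memcpy(&qh, b->qh, sizeof(qh));
    for (int i = 0; i < QK5_0 / 2; i++) {
        const uint8_t xh0 = (uint8_t)(((qh >> i) & 1u) << 4);
        const uint8_t xh1 = (uint8_t)(((qh >> (i + 16)) & 1u) << 4);
        const int x0 = ((b->qs[i] & 0x0F) | xh0) - 16;
        const int x1 = ((b->qs[i] >> 4)   | xh1) - 16;
        dst[i]             = (float)x0 * b->d;
        dst[i + QK5_0 / 2] = (float)x1 * b->d;
    }
}
```

dequant_q5_0_row would then apply this per block over n_elements / 32 consecutive blocks, which is why head_dim must be a multiple of 32.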
| void gemv_nt_q5_0_head_major_output | ( | float * | output, |
| const float * | attn_out, | ||
| const void * | wo, | ||
| const float * | bias, | ||
| int | tokens, | ||
| int | embed_dim, | ||
| int | num_heads, | ||
| int | head_dim | ||
| ) |
Output projection reading head-major attention output (Q5_0 weights)
| output | Output [tokens, embed_dim] (token-major, written contiguously) |
| attn_out | Attention output [num_heads, tokens, head_dim] (head-major, strided) |
| wo | Output weights in Q5_0 format [embed_dim, embed_dim] |
| bias | Optional bias [embed_dim] |
| tokens | Number of tokens |
| embed_dim | Output embedding dimension |
| num_heads | Number of attention heads |
| head_dim | Head dimension (must be multiple of 32 for Q5_0) |
Definition at line 62 of file gemm_head_major_output.c.
References CK_FP16_TO_FP32, block_q5_0::d, block_q5_0::qh, QK5_0, and block_q5_0::qs.
Referenced by ck_gemm_nt_head_major_q5_0().