Per-head RMSNorm on Q and K (Qwen3-style QK norm)
#include <stddef.h>
Functions

void qk_norm_forward(float *q, float *k, const float *q_gamma, const float *k_gamma, int num_heads, int num_kv_heads, int num_tokens, int head_dim, float eps)

void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
Per-head RMSNorm on Q and K (Qwen3-style QK norm)
After changes: make test && python unittest/test_qk_norm.py
QK norm normalizes each head's query/key vectors independently before RoPE. This stabilizes the Q*K^T dot products that feed softmax, preventing attention collapse from large-magnitude vectors.
Why only Q and K, not V? V does not participate in the attention score computation (Q*K^T). The softmax saturation problem comes from large Q*K^T values, so only Q and K magnitudes matter. V is linearly combined after the softmax weights are computed; normalizing it would change the output scale but would not fix attention stability.
Data layout after QKV projection (head-major):
  Q: [num_heads, num_tokens, head_dim] contiguous
  K: [num_kv_heads, num_tokens, head_dim] contiguous
We treat Q as [num_heads * num_tokens] rows of [head_dim] elements. rmsnorm_forward normalizes each row independently. The gamma weight [head_dim] is shared across all heads (Qwen3 design: one gamma per Q, one per K).
Definition in file qk_norm_kernels.c.
void qk_norm_forward(float *q,
                     float *k,
                     const float *q_gamma,
                     const float *k_gamma,
                     int num_heads,
                     int num_kv_heads,
                     int num_tokens,
                     int head_dim,
                     float eps)
Per-head RMSNorm on Q and K.
| q | Q scratch buffer [num_heads * num_tokens * head_dim], in-place |
| k | K scratch buffer [num_kv_heads * num_tokens * head_dim], in-place |
| q_gamma | Q norm gamma weights [head_dim] |
| k_gamma | K norm gamma weights [head_dim] |
| num_heads | Number of query heads (e.g. 32 for Qwen3-8B) |
| num_kv_heads | Number of KV heads (e.g. 8 for Qwen3-8B with GQA) |
| num_tokens | Number of tokens (1 for decode, T for prefill) |
| head_dim | Dimension per head (e.g. 128) |
| eps | RMSNorm epsilon (e.g. 1e-6) |
Definition at line 61 of file qk_norm_kernels.c.
References rmsnorm_forward().
void rmsnorm_forward(const float *input,
                     const float *gamma,
                     float *output,
                     float *rstd_cache,
                     int tokens,
                     int d_model,
                     int aligned_embed_dim,
                     float eps)
RMSNorm forward pass.
test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens
test_rmsnorm.py::TestRMSNormForward::test_fp32_single
test_rmsnorm.py::TestRMSNormForward::test_perf_rolled
test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat
test_parity.py::test_rmsnorm_parity
RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
After changes: make test && make llamacpp-parity-full
Definition at line 50 of file rmsnorm_kernels.c.
Referenced by qk_norm_forward().