/* Doxygen page chrome (extraction artifact from the rendered docs):
 * ← Back to C-Kernel-Engine Docs — Doxygen Source Documentation
 * qk_norm_kernels.c — go to the documentation of this file.
 */
1 /**
2  * @file qk_norm_kernels.c
3  * @brief Per-head RMSNorm on Q and K (Qwen3-style QK norm)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && python unittest/test_qk_norm.py
13  *
14  * QK Norm normalizes each head's query/key vectors independently before RoPE.
15  * This stabilizes Q*K^T dot products before softmax, preventing attention
16  * collapse from large magnitude vectors.
17  *
18  * Why only Q and K, not V?
19  * V does not participate in the attention score computation (Q*K^T).
20  * The softmax saturation problem comes from large Q*K^T values, so only
21  * Q and K magnitudes matter. V is linearly combined after softmax weights
22  * are computed -- normalizing it would change output scale but not fix
23  * attention stability.
24  *
25  * Data layout after QKV projection (head-major):
26  * Q: [num_heads, num_tokens, head_dim] contiguous
27  * K: [num_kv_heads, num_tokens, head_dim] contiguous
28  *
29  * We treat Q as [num_heads * num_tokens] rows of [head_dim] elements.
30  * rmsnorm_forward normalizes each row independently. The gamma weight [head_dim]
31  * is shared across all heads (Qwen3 design: one gamma per Q, one per K).
32  */
33 
#include <stddef.h> /* NULL */

/* Forward declaration; the canonical prototype lives in ckernel_engine.h.
 * Normalizes each of `tokens` rows of `d_model` floats independently
 * (RMSNorm), scaling by `gamma` [d_model]. Writing output == input is the
 * in-place pattern used below. `rstd_cache` may be NULL when the reciprocal
 * std is not needed (forward-only use). `aligned_embed_dim` is presumably
 * the row stride in floats (>= d_model) — here always equal to d_model,
 * i.e. densely packed rows; TODO confirm against ckernel_engine.h. */
void rmsnorm_forward(const float *input,
                     const float *gamma,
                     float *output,
                     float *rstd_cache,
                     int tokens,
                     int d_model,
                     int aligned_embed_dim,
                     float eps);
/**
 * Apply per-head RMSNorm to Q and K activations in place (Qwen3 QK norm).
 *
 * Both buffers are head-major, so every (head, token) pair occupies a
 * contiguous run of head_dim floats. Per-head normalization therefore
 * reduces to a rowwise RMSNorm over num_heads * num_tokens (respectively
 * num_kv_heads * num_tokens) rows, with a single gamma vector shared by
 * all heads of each tensor (one gamma for Q, one for K).
 *
 * @param q            Q scratch buffer [num_heads * num_tokens * head_dim], overwritten
 * @param k            K scratch buffer [num_kv_heads * num_tokens * head_dim], overwritten
 * @param q_gamma      Q norm gamma weights [head_dim]
 * @param k_gamma      K norm gamma weights [head_dim]
 * @param num_heads    Number of query heads (e.g. 32 for Qwen3-8B)
 * @param num_kv_heads Number of KV heads (e.g. 8 for Qwen3-8B with GQA)
 * @param num_tokens   Number of tokens (1 for decode, T for prefill)
 * @param head_dim     Dimension per head (e.g. 128)
 * @param eps          RMSNorm epsilon (e.g. 1e-6)
 *
 * @test unittest/test_qk_norm.py
 */
void qk_norm_forward(float *q, float *k,
                     const float *q_gamma, const float *k_gamma,
                     int num_heads, int num_kv_heads,
                     int num_tokens, int head_dim, float eps)
{
    /* One row per (head, token) pair; K has fewer rows under GQA. */
    const int q_rows = num_heads * num_tokens;
    const int k_rows = num_kv_heads * num_tokens;

    /* In-place rowwise RMSNorm; no rstd cache needed (forward only). */
    rmsnorm_forward(q, q_gamma, q, NULL, q_rows, head_dim, head_dim, eps);
    rmsnorm_forward(k, k_gamma, k, NULL, k_rows, head_dim, head_dim, eps);
}
/* Doxygen member-signature summary (extraction artifact from the rendered docs):
 * void qk_norm_forward(float *q, float *k, const float *q_gamma, const float *k_gamma, int num_heads, int num_kv_heads, int num_tokens, int head_dim, float eps)
 * void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
 */