← Back to C-Kernel-Engine Docs Doxygen Source Documentation
qk_norm_kernels.c File Reference

Per-head RMSNorm on Q and K (Qwen3-style QK norm) More...

#include <stddef.h>

Go to the source code of this file.

Functions

void qk_norm_forward (float *q, float *k, const float *q_gamma, const float *k_gamma, int num_heads, int num_kv_heads, int num_tokens, int head_dim, float eps)
 
void rmsnorm_forward (const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
 

Detailed Description

Per-head RMSNorm on Q and K (Qwen3-style QK norm)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && python unittest/test_qk_norm.py

QK Norm normalizes each head's query/key vectors independently before RoPE is applied. This stabilizes the Q*K^T dot products that feed softmax, preventing attention collapse caused by large-magnitude vectors.

Why only Q and K, not V? V does not participate in the attention score computation (Q*K^T). The softmax saturation problem comes from large Q*K^T values, so only Q and K magnitudes matter. V is linearly combined after softmax weights are computed – normalizing it would change output scale but not fix attention stability.

Data layout after QKV projection (head-major):
Q: [num_heads, num_tokens, head_dim], contiguous.
K: [num_kv_heads, num_tokens, head_dim], contiguous.

We treat Q as [num_heads * num_tokens] rows of [head_dim] elements. rmsnorm_forward normalizes each row independently. The gamma weight [head_dim] is shared across all heads (Qwen3 design: one gamma per Q, one per K).

Definition in file qk_norm_kernels.c.

Function Documentation

◆ qk_norm_forward()

void qk_norm_forward ( float *  q,
float *  k,
const float *  q_gamma,
const float *  k_gamma,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
float  eps 
)

Per-head RMSNorm on Q and K.

Parameters
qQ scratch buffer [num_heads * num_tokens * head_dim], in-place
kK scratch buffer [num_kv_heads * num_tokens * head_dim], in-place
q_gammaQ norm gamma weights [head_dim]
k_gammaK norm gamma weights [head_dim]
num_headsNumber of query heads (e.g. 32 for Qwen3-8B)
num_kv_headsNumber of KV heads (e.g. 8 for Qwen3-8B with GQA)
num_tokensNumber of tokens (1 for decode, T for prefill)
head_dimDimension per head (e.g. 128)
epsRMSNorm epsilon (e.g. 1e-6)
Test:
unittest/test_qk_norm.py

Definition at line 61 of file qk_norm_kernels.c.

65 {
66  /* Q norm: [num_heads * num_tokens] rows of [head_dim]
67  * Each row is one head's vector for one token. */
68  rmsnorm_forward(q, q_gamma, q, NULL,
69  num_heads * num_tokens, head_dim, head_dim, eps);
70 
71  /* K norm: [num_kv_heads * num_tokens] rows of [head_dim]
72  * Same logic, fewer rows when using GQA. */
73  rmsnorm_forward(k, k_gamma, k, NULL,
74  num_kv_heads * num_tokens, head_dim, head_dim, eps);
75 }
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)

References rmsnorm_forward().

◆ rmsnorm_forward()

void rmsnorm_forward ( const float *  input,
const float *  gamma,
float *  output,
float *  rstd_cache,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float  eps 
)

RMSNorm forward pass: normalizes each of `tokens` rows of `d_model` elements in place-or-copy, optionally caching the reciprocal standard deviation per row.

Test:

test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens

test_rmsnorm.py::TestRMSNormForward::test_fp32_single

test_rmsnorm.py::TestRMSNormForward::test_perf_rolled

test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat

test_parity.py::test_rmsnorm_parity

RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)

After changes: make test && make llamacpp-parity-full

Definition at line 50 of file rmsnorm_kernels.c.

58 {
59  int T = tokens;
60  int D = d_model;
61  int aligned = aligned_embed_dim;
62 
63  for (int t = 0; t < T; ++t) {
64  const float *x = input + (size_t)t * aligned;
65  float *y = output + (size_t)t * aligned;
66 
67 #if defined(__AVX512F__)
68  // AVX-512: Process 16 floats at a time
69  __m512 sum_sq_vec = _mm512_setzero_ps();
70  int d = 0;
71 
72  // Vectorized sum of squares
73  for (; d + 16 <= D; d += 16) {
74  __m512 xv = _mm512_loadu_ps(&x[d]);
75  sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
76  }
77  float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
78 
79  // Handle remaining elements
80  for (; d < D; ++d) {
81  sum_sq += x[d] * x[d];
82  }
83 
84  float mean_sq = sum_sq / (float)D;
85  float rstd = 1.0f / sqrtf(mean_sq + eps);
86  if (rstd_cache) {
87  rstd_cache[t] = rstd;
88  }
89 
90  // Apply normalization and scale (vectorized)
91  __m512 rstd_vec = _mm512_set1_ps(rstd);
92  d = 0;
93  for (; d + 16 <= D; d += 16) {
94  __m512 xv = _mm512_loadu_ps(&x[d]);
95  __m512 gv = _mm512_loadu_ps(&gamma[d]);
96  __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
97  __m512 yv = _mm512_mul_ps(x_hat, gv);
98  _mm512_storeu_ps(&y[d], yv);
99  }
100  // Handle remaining elements
101  for (; d < D; ++d) {
102  y[d] = x[d] * rstd * gamma[d];
103  }
104 
105 #elif defined(__AVX__)
106  // AVX: Process 8 floats at a time
107  __m256 sum_sq_vec = _mm256_setzero_ps();
108  int d = 0;
109 
110  // Vectorized sum of squares (no FMA in AVX1, use mul + add)
111  for (; d + 8 <= D; d += 8) {
112  __m256 xv = _mm256_loadu_ps(&x[d]);
113  __m256 xv_sq = _mm256_mul_ps(xv, xv);
114  sum_sq_vec = _mm256_add_ps(sum_sq_vec, xv_sq);
115  }
116  float sum_sq = hsum256_ps_rmsnorm(sum_sq_vec);
117 
118  // Handle remaining elements
119  for (; d < D; ++d) {
120  sum_sq += x[d] * x[d];
121  }
122 
123  float mean_sq = sum_sq / (float)D;
124  float rstd = 1.0f / sqrtf(mean_sq + eps);
125  if (rstd_cache) {
126  rstd_cache[t] = rstd;
127  }
128 
129  // Apply normalization and scale (vectorized)
130  __m256 rstd_vec = _mm256_set1_ps(rstd);
131  d = 0;
132  for (; d + 8 <= D; d += 8) {
133  __m256 xv = _mm256_loadu_ps(&x[d]);
134  __m256 gv = _mm256_loadu_ps(&gamma[d]);
135  __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
136  __m256 yv = _mm256_mul_ps(x_hat, gv);
137  _mm256_storeu_ps(&y[d], yv);
138  }
139  // Handle remaining elements
140  for (; d < D; ++d) {
141  y[d] = x[d] * rstd * gamma[d];
142  }
143 
144 #else
145  // Scalar fallback
146  double sum_sq = 0.0;
147  for (int d = 0; d < D; ++d) {
148  double v = (double)x[d];
149  sum_sq += v * v;
150  }
151  double mean_sq = sum_sq / (double)D;
152  double r = sqrt(mean_sq + (double)eps);
153  float rstd = (float)(1.0 / r);
154  if (rstd_cache) {
155  rstd_cache[t] = rstd;
156  }
157 
158  // Apply normalization and scale
159  for (int d = 0; d < D; ++d) {
160  float x_hat = x[d] * rstd;
161  y[d] = x_hat * gamma[d];
162  }
163 #endif
164 
165  // Zero padding (if any)
166  for (int d = D; d < aligned; ++d) {
167  y[d] = 0.0f;
168  }
169  }
170 }

Referenced by qk_norm_forward().