← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rmsnorm_kernels_bf16.c
Go to the documentation of this file.
1 /**
2  * @file rmsnorm_kernels_bf16.c
3  * @brief RMSNorm kernels for BF16 tensors
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
15  */
16 
17 #include "bf16_utils.h"
18 #include "ckernel_engine.h"
19 
20 #include <math.h>
21 #include <stdint.h>
22 
/**
 * RMSNorm forward for BF16 activations; gamma stays in float for precision.
 *
 * Per token row t: y[d] = gamma[d] * x[d] / sqrt(mean_d(x^2) + eps).
 *
 * Memory layout: input and output are [tokens x aligned_embed_dim] arrays of
 * BF16 (uint16_t) values; only the first d_model entries of each row are
 * active.  Output padding (indices d_model..aligned_embed_dim-1) is zeroed so
 * downstream kernels can read full aligned rows without stale data.
 *
 * @param input             BF16 activations, row stride aligned_embed_dim
 * @param gamma             FP32 scale weights, length d_model
 * @param output            BF16 destination, row stride aligned_embed_dim
 * @param rstd_cache        optional (may be NULL); receives per-token
 *                          1/sqrt(mean(x^2)+eps) for rmsnorm_backward_bf16
 * @param tokens            number of token rows
 * @param d_model           active embedding width
 * @param aligned_embed_dim padded row stride (>= d_model)
 * @param eps               numerical-stability epsilon added to mean(x^2)
 */
void rmsnorm_forward_bf16(const uint16_t *input,
                          const float *gamma,
                          uint16_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps)
{
    int T = tokens;
    int D = d_model;
    int aligned = aligned_embed_dim;

    /* Null/degenerate guard for consistency with rmsnorm_backward_bf16.
     * Also avoids a 0/0 NaN mean when d_model == 0.  rstd_cache remains
     * legitimately optional and is checked per use below. */
    if (!input || !gamma || !output || D <= 0) {
        return;
    }

    for (int t = 0; t < T; ++t) {
        const uint16_t *x_bf16 = input + (size_t)t * aligned;
        float *rstd_ptr = rstd_cache ? (rstd_cache + t) : NULL;
        uint16_t *out_bf16 = output + (size_t)t * aligned;

#if defined(__AVX512F__)
        /* AVX-512 path: 16 floats per iteration.
         * NOTE(review): this path accumulates the sum of squares in fp32
         * while the scalar fallback accumulates in double, so the two paths
         * can differ in the last bits — confirm acceptable for parity tests. */
        __m512 sum_sq_vec = _mm512_setzero_ps();
        int d = 0;

        /* Vectorized sum of squares */
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
        }
        float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);

        /* Scalar tail (D not a multiple of 16) */
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            sum_sq += x * x;
        }

        float mean_sq = sum_sq / (float)D;
        float rstd = 1.0f / sqrtf(mean_sq + eps);
        if (rstd_ptr) {
            *rstd_ptr = rstd;
        }

        /* Apply normalization and scale (vectorized) */
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
            __m512 yv = _mm512_mul_ps(x_hat, gv);
            fp32_cvt_storeu_bf16(&out_bf16[d], yv);
        }
        /* Scalar tail */
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float y = x * rstd * gamma[d];
            out_bf16[d] = float_to_bf16(y);
        }

#else
        /* Scalar fallback: accumulate in double for accuracy. */
        double sum_sq = 0.0;
        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            sum_sq += (double)x * (double)x;
        }
        double mean_sq = sum_sq / (double)D;
        double r = sqrt(mean_sq + (double)eps);
        float rstd = (float)(1.0 / r);
        if (rstd_ptr) {
            *rstd_ptr = rstd;
        }

        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float y = x_hat * gamma[d];
            out_bf16[d] = float_to_bf16(y);
        }
#endif

        /* Zero padding so the aligned tail never carries stale data. */
        for (int d = D; d < aligned; ++d) {
            out_bf16[d] = 0;
        }
    }
}
111 
/**
 * RMSNorm backward for BF16 inputs/outputs; gradients accumulate in float.
 *
 * With x_hat = x * rstd and m = (1/D) * sum_j (dY_j * gamma_j * x_hat_j):
 *   dX[d]      = rstd * (dY[d] * gamma[d] - x_hat[d] * m)
 *   dGamma[d] += dY[d] * x_hat[d]          (accumulated across all tokens)
 *
 * Memory layout mirrors the forward pass: BF16 tensors are
 * [tokens x aligned_embed_dim] rows with the first d_model entries active.
 *
 * @param d_output          BF16 upstream gradient dY, row stride aligned_embed_dim
 * @param input             BF16 forward input x, row stride aligned_embed_dim
 * @param gamma             FP32 scale weights, length d_model
 * @param rstd_cache        per-token 1/rms written by rmsnorm_forward_bf16
 * @param d_input           BF16 destination for dX, row stride aligned_embed_dim
 * @param d_gamma           FP32 destination for dGamma (length d_model);
 *                          zeroed here, then accumulated over all tokens
 * @param tokens            number of token rows
 * @param d_model           active embedding width
 * @param aligned_embed_dim padded row stride (>= d_model)
 */
void rmsnorm_backward_bf16(const uint16_t *d_output,
                           const uint16_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           uint16_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim)
{
    int T = tokens;
    int D = d_model;
    int aligned = aligned_embed_dim;

    /* All pointers are required here; rstd_cache must have been filled by
     * the forward pass. */
    if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) {
        return;
    }

    /* Zero parameter gradients before accumulating over tokens. */
#if defined(__AVX512F__)
    {
        int d = 0;
        for (; d + 16 <= D; d += 16) {
            _mm512_storeu_ps(&d_gamma[d], _mm512_setzero_ps());
        }
        for (; d < D; ++d) {
            d_gamma[d] = 0.0f;
        }
    }
#else
    for (int d = 0; d < D; ++d) {
        d_gamma[d] = 0.0f;
    }
#endif

    for (int t = 0; t < T; ++t) {
        const uint16_t *x_bf16 = input + (size_t)t * aligned;
        const uint16_t *dY_bf16 = d_output + (size_t)t * aligned;
        uint16_t *dX_bf16 = d_input + (size_t)t * aligned;
        float rstd = rstd_cache[t];

#if defined(__AVX512F__)
        /* Pass 1: compute m = (1/D) * sum_j (dY_j * gamma_j * x_hat_j).
         * NOTE(review): this path reduces in fp32 while the scalar fallback
         * uses double — the two can differ in the last bits. */
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        __m512 sum_vec = _mm512_setzero_ps();
        int d = 0;

        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            __m512 dyv = bf16_loadu_cvt_fp32(&dY_bf16[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
            /* sum += dY * gamma * x_hat */
            __m512 prod = _mm512_mul_ps(dyv, gv);
            sum_vec = _mm512_fmadd_ps(prod, x_hat, sum_vec);
        }
        float sum_dY_g_xhat = _mm512_reduce_add_ps(sum_vec);

        /* Scalar tail (D not a multiple of 16) */
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            sum_dY_g_xhat += dy * gamma[d] * x_hat;
        }
        float m = sum_dY_g_xhat / (float)D;

        /* Pass 2: compute dX and accumulate dGamma (vectorized). */
        __m512 m_vec = _mm512_set1_ps(m);
        d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            __m512 dyv = bf16_loadu_cvt_fp32(&dY_bf16[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 dgv = _mm512_loadu_ps(&d_gamma[d]);

            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);

            /* dX = rstd * (dY * gamma - x_hat * m) */
            __m512 dy_g = _mm512_mul_ps(dyv, gv);
            __m512 xhat_m = _mm512_mul_ps(x_hat, m_vec);
            __m512 diff = _mm512_sub_ps(dy_g, xhat_m);
            __m512 dxv = _mm512_mul_ps(rstd_vec, diff);
            fp32_cvt_storeu_bf16(&dX_bf16[d], dxv);

            /* d_gamma += dY * x_hat (read-modify-write in fp32) */
            dgv = _mm512_fmadd_ps(dyv, x_hat, dgv);
            _mm512_storeu_ps(&d_gamma[d], dgv);
        }
        /* Scalar tail */
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            float dx = rstd * (dy * gamma[d] - x_hat * m);
            dX_bf16[d] = float_to_bf16(dx);
            d_gamma[d] += dy * x_hat;
        }

#else
        /* Scalar fallback: reduce in double for accuracy. */
        double sum_dY_g_xhat = 0.0;
        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            sum_dY_g_xhat += (double)dy * (double)gamma[d] * (double)x_hat;
        }
        float m = (float)(sum_dY_g_xhat / (double)D);

        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            float dx = rstd * (dy * gamma[d] - x_hat * m);
            dX_bf16[d] = float_to_bf16(dx);
            d_gamma[d] += dy * x_hat;
        }
#endif

        /* Zero padding gradients so aligned tails never carry stale data. */
        for (int d = D; d < aligned; ++d) {
            dX_bf16[d] = 0;
        }
    }
}
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38
void rmsnorm_backward_bf16(const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *rstd_cache, uint16_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
void rmsnorm_forward_bf16(const uint16_t *input, const float *gamma, uint16_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)