← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rmsnorm_kernels_bf16.c File Reference

RMSNorm kernels for BF16 tensors. More...

#include "bf16_utils.h"
#include "ckernel_engine.h"
#include <math.h>
#include <stdint.h>

Go to the source code of this file.

Functions

void rmsnorm_backward_bf16 (const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *rstd_cache, uint16_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
 
void rmsnorm_forward_bf16 (const uint16_t *input, const float *gamma, uint16_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
 

Detailed Description

RMSNorm kernels for BF16 tensors.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)

Definition in file rmsnorm_kernels_bf16.c.

Function Documentation

◆ rmsnorm_backward_bf16()

void rmsnorm_backward_bf16 ( const uint16_t *  d_output,
const uint16_t *  input,
const float *  gamma,
const float *  rstd_cache,
uint16_t *  d_input,
float *  d_gamma,
int  tokens,
int  d_model,
int  aligned_embed_dim 
)

Definition at line 113 of file rmsnorm_kernels_bf16.c.

122 {
123  int T = tokens;
124  int D = d_model;
125  int aligned = aligned_embed_dim;
126 
127  if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) {
128  return;
129  }
130 
131  // Zero parameter gradients
132 #if defined(__AVX512F__)
133  {
134  int d = 0;
135  for (; d + 16 <= D; d += 16) {
136  _mm512_storeu_ps(&d_gamma[d], _mm512_setzero_ps());
137  }
138  for (; d < D; ++d) {
139  d_gamma[d] = 0.0f;
140  }
141  }
142 #else
143  for (int d = 0; d < D; ++d) {
144  d_gamma[d] = 0.0f;
145  }
146 #endif
147 
148  for (int t = 0; t < T; ++t) {
149  const uint16_t *x_bf16 = input + (size_t)t * aligned;
150  const uint16_t *dY_bf16 = d_output + (size_t)t * aligned;
151  uint16_t *dX_bf16 = d_input + (size_t)t * aligned;
152  float rstd = rstd_cache[t];
153 
154 #if defined(__AVX512F__)
155  // Compute m = (1/D) * sum_j (dY_j * gamma_j * x_hat_j)
156  __m512 rstd_vec = _mm512_set1_ps(rstd);
157  __m512 sum_vec = _mm512_setzero_ps();
158  int d = 0;
159 
160  for (; d + 16 <= D; d += 16) {
161  __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
162  __m512 dyv = bf16_loadu_cvt_fp32(&dY_bf16[d]);
163  __m512 gv = _mm512_loadu_ps(&gamma[d]);
164  __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
165  // sum += dY * gamma * x_hat
166  __m512 prod = _mm512_mul_ps(dyv, gv);
167  sum_vec = _mm512_fmadd_ps(prod, x_hat, sum_vec);
168  }
169  float sum_dY_g_xhat = _mm512_reduce_add_ps(sum_vec);
170 
171  // Handle remaining elements
172  for (; d < D; ++d) {
173  float x = bf16_to_float(x_bf16[d]);
174  float x_hat = x * rstd;
175  float dy = bf16_to_float(dY_bf16[d]);
176  sum_dY_g_xhat += dy * gamma[d] * x_hat;
177  }
178  float m = sum_dY_g_xhat / (float)D;
179 
180  // Compute dX and accumulate dGamma (vectorized)
181  __m512 m_vec = _mm512_set1_ps(m);
182  d = 0;
183  for (; d + 16 <= D; d += 16) {
184  __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
185  __m512 dyv = bf16_loadu_cvt_fp32(&dY_bf16[d]);
186  __m512 gv = _mm512_loadu_ps(&gamma[d]);
187  __m512 dgv = _mm512_loadu_ps(&d_gamma[d]);
188 
189  __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
190 
191  // dX = rstd * (dY * gamma - x_hat * m)
192  __m512 dy_g = _mm512_mul_ps(dyv, gv);
193  __m512 xhat_m = _mm512_mul_ps(x_hat, m_vec);
194  __m512 diff = _mm512_sub_ps(dy_g, xhat_m);
195  __m512 dxv = _mm512_mul_ps(rstd_vec, diff);
196  fp32_cvt_storeu_bf16(&dX_bf16[d], dxv);
197 
198  // d_gamma += dY * x_hat
199  dgv = _mm512_fmadd_ps(dyv, x_hat, dgv);
200  _mm512_storeu_ps(&d_gamma[d], dgv);
201  }
202  // Handle remaining elements
203  for (; d < D; ++d) {
204  float x = bf16_to_float(x_bf16[d]);
205  float x_hat = x * rstd;
206  float dy = bf16_to_float(dY_bf16[d]);
207  float dx = rstd * (dy * gamma[d] - x_hat * m);
208  dX_bf16[d] = float_to_bf16(dx);
209  d_gamma[d] += dy * x_hat;
210  }
211 
212 #else
213  // Scalar fallback
214  double sum_dY_g_xhat = 0.0;
215  for (int d = 0; d < D; ++d) {
216  float x = bf16_to_float(x_bf16[d]);
217  float x_hat = x * rstd;
218  float dy = bf16_to_float(dY_bf16[d]);
219  sum_dY_g_xhat += (double)dy * (double)gamma[d] * (double)x_hat;
220  }
221  float m = (float)(sum_dY_g_xhat / (double)D);
222 
223  for (int d = 0; d < D; ++d) {
224  float x = bf16_to_float(x_bf16[d]);
225  float x_hat = x * rstd;
226  float dy = bf16_to_float(dY_bf16[d]);
227  float dx = rstd * (dy * gamma[d] - x_hat * m);
228  dX_bf16[d] = float_to_bf16(dx);
229  d_gamma[d] += dy * x_hat;
230  }
231 #endif
232 
233  // Zero padding gradients
234  for (int d = D; d < aligned; ++d) {
235  dX_bf16[d] = 0;
236  }
237  }
238 }
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38

References bf16_to_float(), and float_to_bf16().

◆ rmsnorm_forward_bf16()

/**
 * Forward pass of RMSNorm for BF16 activations.
 *
 * Per token row t:
 *   rstd_t = 1 / sqrt(mean(x^2) + eps)
 *   y_j    = gamma[j] * x_j * rstd_t
 *
 * @param input             [tokens x aligned_embed_dim] input activations, BF16.
 * @param gamma             [d_model] scale parameters, FP32.
 * @param output            [tokens x aligned_embed_dim] normalized output, BF16;
 *                          padding columns [d_model, aligned_embed_dim) are zeroed.
 * @param rstd_cache        optional [tokens] buffer; when non-NULL, receives the
 *                          per-token reciprocal std for reuse by the backward pass.
 * @param tokens            number of token rows.
 * @param d_model           logical embedding width.
 * @param aligned_embed_dim padded row stride.
 * @param eps               numerical-stability epsilon added to mean(x^2).
 *
 * No allocation, no side effects beyond the listed output buffers
 * (per CK-ENGINE kernel rules).
 */
void rmsnorm_forward_bf16(const uint16_t *input, const float *gamma,
                          uint16_t *output, float *rstd_cache,
                          int tokens, int d_model, int aligned_embed_dim,
                          float eps)
{
    int T = tokens;
    int D = d_model;
    int aligned = aligned_embed_dim;

    /* Guard mandatory buffers, mirroring the backward kernel's checks.
       rstd_cache stays optional (NULL means "don't cache", e.g. inference). */
    if (!input || !gamma || !output) {
        return;
    }

    for (int t = 0; t < T; ++t) {
        const uint16_t *x_bf16 = input + (size_t)t * aligned;
        float *rstd_ptr = rstd_cache ? (rstd_cache + t) : NULL;
        uint16_t *out_bf16 = output + (size_t)t * aligned;

#if defined(__AVX512F__)
        /* AVX-512: process 16 lanes per iteration. */
        __m512 sum_sq_vec = _mm512_setzero_ps();
        int d = 0;

        /* Vectorized sum of squares. */
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
        }
        float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);

        /* Scalar tail of the reduction. */
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            sum_sq += x * x;
        }

        float mean_sq = sum_sq / (float)D;
        float rstd = 1.0f / sqrtf(mean_sq + eps);
        if (rstd_ptr) {
            *rstd_ptr = rstd;
        }

        /* Apply normalization and scale (vectorized). */
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
            __m512 yv = _mm512_mul_ps(x_hat, gv);
            fp32_cvt_storeu_bf16(&out_bf16[d], yv);
        }
        /* Scalar tail of the normalization. */
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float y = x * rstd * gamma[d];
            out_bf16[d] = float_to_bf16(y);
        }

#else
        /* Scalar fallback: double accumulation for the sum of squares
           to limit rounding error on long rows. */
        double sum_sq = 0.0;
        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            sum_sq += (double)x * (double)x;
        }
        double mean_sq = sum_sq / (double)D;
        double r = sqrt(mean_sq + (double)eps);
        float rstd = (float)(1.0 / r);
        if (rstd_ptr) {
            *rstd_ptr = rstd;
        }

        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float y = x_hat * gamma[d];
            out_bf16[d] = float_to_bf16(y);
        }
#endif

        /* Keep the padded tail of each output row at exactly zero. */
        for (int d = D; d < aligned; ++d) {
            out_bf16[d] = 0;
        }
    }
}

References bf16_to_float(), and float_to_bf16().