← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mlp_kernels_bf16.c File Reference

Optimized BF16 MLP Kernels. More...

#include <stddef.h>
#include <stdint.h>
#include <math.h>
#include "bf16_utils.h"
#include "ckernel_engine.h"

Go to the source code of this file.

Functions

static float gelu_scalar (float x)
 
void gemm_bf16_fp32out (const uint16_t *A, const uint16_t *B, const float *bias, float *C, int M, int N, int K)
 
void mlp_token_parallel_bf16 (const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
 
void mlp_token_parallel_bf16_fp32act (const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_input_f, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
 

Detailed Description

Optimized BF16 MLP Kernels.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Uses direct BF16 GEMM instead of converting to FP32. Layout: input[T,D] -> fc1[T,4D] -> GELU -> fc2[T,D]

All functions use caller-provided scratch buffers (no internal malloc).

Definition in file mlp_kernels_bf16.c.

Function Documentation

◆ gelu_scalar()

static float gelu_scalar ( float  x)
inlinestatic

Definition at line 45 of file mlp_kernels_bf16.c.

46 {
47  const float c = 0.7978845608f; /* sqrt(2/pi) */
48  const float k = 0.044715f;
49  float x3 = x * x * x;
50  return 0.5f * x * (1.0f + tanhf(c * (x + k * x3)));
51 }

Referenced by mlp_token_parallel_bf16(), and mlp_token_parallel_bf16_fp32act().

◆ gemm_bf16_fp32out()

/**
 * BF16 GEMM with FP32 accumulation and output:
 *   C[i][j] = sum_k A[i][k] * B[j][k] (+ bias[j] if bias != NULL)
 *
 * B is accessed row-major by j, i.e. B is expected transposed relative to a
 * conventional [K, N] layout: B is [N, K].
 *
 * @param A     [M, K] BF16 matrix (row-major)
 * @param B     [N, K] BF16 matrix (row-major; logically the transposed RHS)
 * @param bias  optional [N] FP32 bias added per output column; may be NULL
 * @param C     [M, N] FP32 output (row-major)
 * @param M,N,K dimensions; all must be > 0 (otherwise the call is a no-op)
 *
 * NOTE(review): the file's stated kernel rules say "NO OpenMP", yet this
 * function uses `#pragma omp parallel for` — confirm which is authoritative.
 */
void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B, const float *bias,
                       float *C, int M, int N, int K)
{
    /* Defensive no-op on NULL pointers or non-positive dimensions. */
    if (!A || !B || !C || M <= 0 || N <= 0 || K <= 0) {
        return;
    }

#if defined(__AVX512F__)
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < M; ++i) {
        const uint16_t *a_row = A + (size_t)i * K;

        for (int j = 0; j < N; ++j) {
            const uint16_t *b_row = B + (size_t)j * K;

            __m512 sum_vec = _mm512_setzero_ps();

            /* Main loop: 16 BF16 pairs per iteration via bf16_dot16. */
            int k = 0;
            for (; k <= K - 16; k += 16) {
                __m256i a_bf16 = _mm256_loadu_si256((const __m256i *)(a_row + k));
                __m256i b_bf16 = _mm256_loadu_si256((const __m256i *)(b_row + k));
                sum_vec = bf16_dot16(a_bf16, b_bf16, sum_vec);
            }

            float sum = _mm512_reduce_add_ps(sum_vec);

            /* Scalar tail for K not divisible by 16. */
            for (; k < K; ++k) {
                sum += bf16_to_float(a_row[k]) * bf16_to_float(b_row[k]);
            }

            if (bias) {
                sum += bias[j];
            }

            C[(size_t)i * N + j] = sum;
        }
    }
#else
    /* Portable scalar fallback; bias folded into the accumulator start. */
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float sum = bias ? bias[j] : 0.0f;
            for (int k = 0; k < K; ++k) {
                sum += bf16_to_float(A[(size_t)i * K + k]) *
                       bf16_to_float(B[(size_t)j * K + k]);
            }
            C[(size_t)i * N + j] = sum;
        }
    }
#endif
}
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38
#define C(color)
Definition: show_config.c:39

References bf16_to_float(), and C.

Referenced by mlp_token_parallel_bf16(), and mlp_token_parallel_bf16_fp32act().

◆ mlp_token_parallel_bf16()

/**
 * Optimized MLP forward pass (BF16 weights, FP32 intermediate activations):
 *   input[T, D] -> FC1 -> [T, 4D] -> GELU -> (FP32 -> BF16) -> FC2 -> output[T, D]
 *
 * No internal allocation — all scratch buffers are caller-provided:
 *   @param scratch_bias1_f   [4*D] floats   (FC1 bias converted to FP32)
 *   @param scratch_bias2_f   [D] floats     (FC2 bias converted to FP32)
 *   @param scratch_fc1_bf16  [T * 4*D] u16  (GELU output rounded to BF16)
 *
 * @param input       [T, D] BF16 activations
 * @param W_fc1       [4D, D] BF16 weights (row-major, transposed-RHS layout)
 * @param b_fc1       [4D] BF16 bias
 * @param W_fc2       [D, 4D] BF16 weights
 * @param b_fc2       [D] BF16 bias
 * @param fc1_output  [T, 4D] FP32 workspace; holds GELU activations on return
 * @param output      [T, D] FP32 result
 * @param num_threads unused here; threading is handled by OpenMP pragmas
 *
 * NOTE(review): the file's kernel rules say "NO OpenMP", yet this function
 * uses `#pragma omp parallel for` — confirm which is authoritative.
 */
void mlp_token_parallel_bf16(const uint16_t *input, const uint16_t *W_fc1,
                             const uint16_t *b_fc1, const uint16_t *W_fc2,
                             const uint16_t *b_fc2, float *fc1_output,
                             float *output, int T, int aligned_dim,
                             int num_threads, float *scratch_bias1_f,
                             float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
{
    if (!input || !W_fc1 || !b_fc1 || !W_fc2 || !b_fc2 || !fc1_output || !output) return;
    if (!scratch_bias1_f || !scratch_bias2_f || !scratch_fc1_bf16) return;

    (void)num_threads;
    const int D = aligned_dim;
    const int fourD = 4 * D;

    /* Convert biases to FP32 once; reused across all T tokens. */
    for (int i = 0; i < fourD; ++i) {
        scratch_bias1_f[i] = bf16_to_float(b_fc1[i]);
    }
    for (int i = 0; i < D; ++i) {
        scratch_bias2_f[i] = bf16_to_float(b_fc2[i]);
    }

    /* FC1: [T, D] x [4D, D].T -> [T, 4D] */
    gemm_bf16_fp32out(input, W_fc1, scratch_bias1_f, fc1_output, T, fourD, D);

    /* GELU activation, in place on fc1_output. */
#if defined(__AVX512F__)
    #pragma omp parallel for
    for (int t = 0; t < T; ++t) {
        float *row = fc1_output + (size_t)t * fourD;
        int j = 0;
        for (; j <= fourD - 16; j += 16) {
            __m512 x = _mm512_loadu_ps(row + j);
            _mm512_storeu_ps(row + j, gelu_avx512(x));
        }
        for (; j < fourD; ++j) {
            row[j] = gelu_scalar(row[j]);
        }
    }
#else
    for (int t = 0; t < T; ++t) {
        for (int j = 0; j < fourD; ++j) {
            /* (size_t) cast avoids signed int overflow for large T * 4D. */
            fc1_output[(size_t)t * fourD + j] = gelu_scalar(fc1_output[(size_t)t * fourD + j]);
        }
    }
#endif

    /* Convert FP32 activations to BF16 for the second BF16 GEMM. */
#if defined(__AVX512F__)
    #pragma omp parallel for
    for (int t = 0; t < T; ++t) {
        float *src = fc1_output + (size_t)t * fourD;
        uint16_t *dst = scratch_fc1_bf16 + (size_t)t * fourD;
        int j = 0;
        for (; j <= fourD - 16; j += 16) {
            /* Round-to-nearest-even BF16 truncation:
               add 0x7FFF (+1 if bit 16 set) before dropping the low 16 bits. */
            __m512 fp32 = _mm512_loadu_ps(src + j);
            __m512i as_int = _mm512_castps_si512(fp32);
            __m512i lsb = _mm512_srli_epi32(as_int, 16);
            lsb = _mm512_and_si512(lsb, _mm512_set1_epi32(1));
            __m512i rounding = _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), lsb);
            __m512i rounded = _mm512_add_epi32(as_int, rounding);
            __m512i shifted = _mm512_srli_epi32(rounded, 16);
            __m256i bf16 = _mm512_cvtepi32_epi16(shifted);
            _mm256_storeu_si256((__m256i *)(dst + j), bf16);
        }
        for (; j < fourD; ++j) {
            dst[j] = float_to_bf16(src[j]);
        }
    }
#else
    for (size_t i = 0; i < (size_t)T * fourD; ++i) {
        scratch_fc1_bf16[i] = float_to_bf16(fc1_output[i]);
    }
#endif

    /* FC2: [T, 4D] x [D, 4D].T -> [T, D], FP32 output. */
    gemm_bf16_fp32out(scratch_fc1_bf16, W_fc2, scratch_bias2_f, output, T, D, fourD);
}
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float gelu_scalar(float x)
void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B, const float *bias, float *C, int M, int N, int K)

References bf16_to_float(), float_to_bf16(), gelu_scalar(), and gemm_bf16_fp32out().

◆ mlp_token_parallel_bf16_fp32act()

/**
 * MLP forward pass variant with an FP32 copy of the input available:
 *   input[T, D] -> FC1 -> [T, 4D] -> GELU -> (FP32 -> BF16) -> FC2 -> output[T, D]
 *
 * Caller-provided scratch buffers (no internal allocation):
 *   @param scratch_input_f   [T * D] floats (FP32 copy of input)
 *   @param scratch_bias1_f   [4*D] floats
 *   @param scratch_bias2_f   [D] floats
 *   @param scratch_fc1_bf16  [T * 4*D] uint16_t (BF16)
 *
 * @param fc1_output  [T, 4D] FP32 workspace; holds GELU activations on return
 * @param output      [T, D] FP32 result
 * @param num_threads unused here; threading is handled by OpenMP pragmas
 *
 * NOTE(review): scratch_input_f is filled from `input` but never read again in
 * this function — the FC1 GEMM consumes the BF16 `input` directly. Either the
 * conversion is dead work or a caller relies on the filled buffer; confirm
 * before removing it. Also: the file's kernel rules say "NO OpenMP", yet this
 * function uses `#pragma omp parallel for`.
 */
void mlp_token_parallel_bf16_fp32act(const uint16_t *input, const uint16_t *W_fc1,
                                     const uint16_t *b_fc1, const uint16_t *W_fc2,
                                     const uint16_t *b_fc2, float *fc1_output,
                                     float *output, int T, int aligned_dim,
                                     int num_threads, float *scratch_input_f,
                                     float *scratch_bias1_f, float *scratch_bias2_f,
                                     uint16_t *scratch_fc1_bf16)
{
    if (!input || !W_fc1 || !b_fc1 || !W_fc2 || !b_fc2 || !fc1_output || !output) return;
    if (!scratch_input_f || !scratch_bias1_f || !scratch_bias2_f || !scratch_fc1_bf16) return;

    (void)num_threads;
    const int D = aligned_dim;
    const int fourD = 4 * D;

    /* Convert input and biases to FP32. */
    bf16_tensor_to_float(input, scratch_input_f, (size_t)T * D);
    bf16_tensor_to_float(b_fc1, scratch_bias1_f, fourD);
    bf16_tensor_to_float(b_fc2, scratch_bias2_f, D);

    /* FC1: [T, D] x [4D, D].T -> [T, 4D] (reads the BF16 input directly). */
    gemm_bf16_fp32out(input, W_fc1, scratch_bias1_f, fc1_output, T, fourD, D);

    /* GELU activation, in place on fc1_output. */
#if defined(__AVX512F__)
    #pragma omp parallel for
    for (int t = 0; t < T; ++t) {
        float *row = fc1_output + (size_t)t * fourD;
        int j = 0;
        for (; j <= fourD - 16; j += 16) {
            __m512 x = _mm512_loadu_ps(row + j);
            _mm512_storeu_ps(row + j, gelu_avx512(x));
        }
        for (; j < fourD; ++j) {
            row[j] = gelu_scalar(row[j]);
        }
    }
#else
    for (size_t i = 0; i < (size_t)T * fourD; ++i) {
        fc1_output[i] = gelu_scalar(fc1_output[i]);
    }
#endif

    /* Round activations to BF16, then FC2: [T, 4D] x [D, 4D].T -> [T, D]. */
    float_tensor_to_bf16(fc1_output, scratch_fc1_bf16, (size_t)T * fourD);
    gemm_bf16_fp32out(scratch_fc1_bf16, W_fc2, scratch_bias2_f, output, T, D, fourD);
}
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
Definition: bf16_utils.h:271
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
Definition: bf16_utils.h:250

References bf16_tensor_to_float(), float_tensor_to_bf16(), gelu_scalar(), and gemm_bf16_fp32out().