← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mlp_kernels_bf16.c
Go to the documentation of this file.
1 /**
2  * @file mlp_kernels_bf16.c
3  * @brief Optimized BF16 MLP Kernels
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Uses direct BF16 GEMM instead of converting to FP32.
15  * Layout: input[T,D] -> fc1[T,4D] -> GELU -> fc2[T,D]
16  *
17  * All functions use caller-provided scratch buffers (no internal malloc).
18  */
19 
20 #include <stddef.h>
21 #include <stdint.h>
22 #include <math.h>
23 
24 #if defined(__AVX512F__)
25 #include <immintrin.h>
26 #endif
27 
28 #ifdef _OPENMP
29 #include <omp.h>
30 #endif
31 
32 #include "bf16_utils.h"
33 #include "ckernel_engine.h"
34 
35 // Suppress false positive warnings about uninitialized variables
36 #pragma GCC diagnostic push
37 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
38 
/* Forward declaration of optimized BF16 GEMM (defined in a separate TU).
 *
 * Computes C = A * B^T + bias, with:
 *   A    : [M, K] BF16 (row-major)
 *   B    : [N, K] BF16 (row-major; consumed transposed)
 *   bias : [N] FP32, added to every output row
 *   C    : [M, N] FP32
 *
 * NOTE(review): the layout above is inferred from the FC1 call site
 * ("[T, D] x [4D, D].T -> [T, 4D]") — confirm against the definition. */
extern void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B,
                              const float *bias, float *C,
                              int M, int N, int K);
43 
44 /* GELU activation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) */
45 static inline float gelu_scalar(float x)
46 {
47  const float c = 0.7978845608f; /* sqrt(2/pi) */
48  const float k = 0.044715f;
49  float x3 = x * x * x;
50  return 0.5f * x * (1.0f + tanhf(c * (x + k * x3)));
51 }
52 
#if defined(__AVX512F__)
/* Vectorized GELU using a rational (Pade) approximation of tanh:
 *   tanh(z) ~= z * (27 + z^2) / (27 + 9*z^2)
 * followed by a clamp to [-1, 1] since the approximation overshoots for
 * large |z|.
 *
 * NOTE(review): this differs slightly from the scalar gelu_scalar() path,
 * which uses exact tanhf(); tail elements handled by the scalar fallback in
 * the callers can therefore diverge from vector lanes by a small amount —
 * confirm this is within parity tolerance (make llamacpp-parity-full). */
static inline __m512 gelu_avx512(__m512 x)
{
    const __m512 c = _mm512_set1_ps(0.7978845608f); /* sqrt(2/pi) */
    const __m512 k = _mm512_set1_ps(0.044715f);     /* cubic coefficient */
    const __m512 half = _mm512_set1_ps(0.5f);
    const __m512 one = _mm512_set1_ps(1.0f);

    __m512 x2 = _mm512_mul_ps(x, x);
    __m512 x3 = _mm512_mul_ps(x2, x);

    /* inner = c * (x + k * x^3), the tanh argument */
    __m512 inner = _mm512_fmadd_ps(k, x3, x);
    inner = _mm512_mul_ps(c, inner);

    /* Rational approximation: inner * (27 + inner^2) / (27 + 9*inner^2) */
    __m512 inner2 = _mm512_mul_ps(inner, inner);
    __m512 num = _mm512_add_ps(_mm512_set1_ps(27.0f), inner2);
    __m512 den = _mm512_fmadd_ps(_mm512_set1_ps(9.0f), inner2, _mm512_set1_ps(27.0f));
    __m512 tanh_approx = _mm512_mul_ps(inner, _mm512_div_ps(num, den));

    /* Clamp to tanh's range; the rational form exceeds +-1 for large |inner| */
    tanh_approx = _mm512_min_ps(tanh_approx, one);
    tanh_approx = _mm512_max_ps(tanh_approx, _mm512_set1_ps(-1.0f));

    /* 0.5 * x * (1 + tanh_approx) */
    __m512 result = _mm512_add_ps(one, tanh_approx);
    result = _mm512_mul_ps(half, _mm512_mul_ps(x, result));

    return result;
}
#endif
82 
83 /**
84  * Optimized MLP Forward (BF16 weights, FP32 activations)
85  *
86  * Caller-provided scratch buffers:
87  * scratch_bias1_f: [4*D] floats
88  * scratch_bias2_f: [D] floats
89  * scratch_fc1_bf16: [T * 4*D] uint16_t (BF16)
90  */
91 void mlp_token_parallel_bf16(const uint16_t *input,
92  const uint16_t *W_fc1,
93  const uint16_t *b_fc1,
94  const uint16_t *W_fc2,
95  const uint16_t *b_fc2,
96  float *fc1_output,
97  float *output,
98  int T,
99  int aligned_dim,
100  int num_threads,
101  float *scratch_bias1_f,
102  float *scratch_bias2_f,
103  uint16_t *scratch_fc1_bf16)
104 {
105  if (!input || !W_fc1 || !b_fc1 || !W_fc2 || !b_fc2 || !fc1_output || !output) return;
106  if (!scratch_bias1_f || !scratch_bias2_f || !scratch_fc1_bf16) return;
107 
108  (void)num_threads;
109  const int D = aligned_dim;
110  const int fourD = 4 * D;
111 
112  /* Convert biases to FP32 */
113  for (int i = 0; i < fourD; ++i) {
114  scratch_bias1_f[i] = bf16_to_float(b_fc1[i]);
115  }
116  for (int i = 0; i < D; ++i) {
117  scratch_bias2_f[i] = bf16_to_float(b_fc2[i]);
118  }
119 
120  /* FC1: [T, D] x [4D, D].T -> [T, 4D] */
121  gemm_bf16_fp32out(input, W_fc1, scratch_bias1_f, fc1_output, T, fourD, D);
122 
123  /* GELU activation */
124 #if defined(__AVX512F__)
125  #pragma omp parallel for
126  for (int t = 0; t < T; ++t) {
127  float *row = fc1_output + (size_t)t * fourD;
128  int j = 0;
129  for (; j <= fourD - 16; j += 16) {
130  __m512 x = _mm512_loadu_ps(row + j);
131  _mm512_storeu_ps(row + j, gelu_avx512(x));
132  }
133  for (; j < fourD; ++j) {
134  row[j] = gelu_scalar(row[j]);
135  }
136  }
137 #else
138  for (int t = 0; t < T; ++t) {
139  for (int j = 0; j < fourD; ++j) {
140  fc1_output[t * fourD + j] = gelu_scalar(fc1_output[t * fourD + j]);
141  }
142  }
143 #endif
144 
145  /* Convert FP32 activations to BF16 */
146 #if defined(__AVX512F__)
147  #pragma omp parallel for
148  for (int t = 0; t < T; ++t) {
149  float *src = fc1_output + (size_t)t * fourD;
150  uint16_t *dst = scratch_fc1_bf16 + (size_t)t * fourD;
151  int j = 0;
152  for (; j <= fourD - 16; j += 16) {
153  __m512 fp32 = _mm512_loadu_ps(src + j);
154  __m512i as_int = _mm512_castps_si512(fp32);
155  __m512i lsb = _mm512_srli_epi32(as_int, 16);
156  lsb = _mm512_and_si512(lsb, _mm512_set1_epi32(1));
157  __m512i rounding = _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), lsb);
158  __m512i rounded = _mm512_add_epi32(as_int, rounding);
159  __m512i shifted = _mm512_srli_epi32(rounded, 16);
160  __m256i bf16 = _mm512_cvtepi32_epi16(shifted);
161  _mm256_storeu_si256((__m256i *)(dst + j), bf16);
162  }
163  for (; j < fourD; ++j) {
164  dst[j] = float_to_bf16(src[j]);
165  }
166  }
167 #else
168  for (size_t i = 0; i < (size_t)T * fourD; ++i) {
169  scratch_fc1_bf16[i] = float_to_bf16(fc1_output[i]);
170  }
171 #endif
172 
173  /* FC2: BF16 GEMM with FP32 output */
174  gemm_bf16_fp32out(scratch_fc1_bf16, W_fc2, scratch_bias2_f, output, T, D, fourD);
175 }
176 
177 /**
178  * Alternative: Fully FP32 activations throughout
179  *
180  * Caller-provided scratch buffers:
181  * scratch_input_f: [T * D] floats
182  * scratch_bias1_f: [4*D] floats
183  * scratch_bias2_f: [D] floats
184  * scratch_fc1_bf16: [T * 4*D] uint16_t (BF16)
185  */
186 void mlp_token_parallel_bf16_fp32act(const uint16_t *input,
187  const uint16_t *W_fc1,
188  const uint16_t *b_fc1,
189  const uint16_t *W_fc2,
190  const uint16_t *b_fc2,
191  float *fc1_output,
192  float *output,
193  int T,
194  int aligned_dim,
195  int num_threads,
196  float *scratch_input_f,
197  float *scratch_bias1_f,
198  float *scratch_bias2_f,
199  uint16_t *scratch_fc1_bf16)
200 {
201  if (!input || !W_fc1 || !b_fc1 || !W_fc2 || !b_fc2 || !fc1_output || !output) return;
202  if (!scratch_input_f || !scratch_bias1_f || !scratch_bias2_f || !scratch_fc1_bf16) return;
203 
204  (void)num_threads;
205  const int D = aligned_dim;
206  const int fourD = 4 * D;
207 
208  /* Convert input and biases to FP32 */
209  bf16_tensor_to_float(input, scratch_input_f, (size_t)T * D);
210  bf16_tensor_to_float(b_fc1, scratch_bias1_f, fourD);
211  bf16_tensor_to_float(b_fc2, scratch_bias2_f, D);
212 
213  /* FC1 */
214  gemm_bf16_fp32out(input, W_fc1, scratch_bias1_f, fc1_output, T, fourD, D);
215 
216  /* GELU */
217 #if defined(__AVX512F__)
218  #pragma omp parallel for
219  for (int t = 0; t < T; ++t) {
220  float *row = fc1_output + (size_t)t * fourD;
221  int j = 0;
222  for (; j <= fourD - 16; j += 16) {
223  __m512 x = _mm512_loadu_ps(row + j);
224  _mm512_storeu_ps(row + j, gelu_avx512(x));
225  }
226  for (; j < fourD; ++j) {
227  row[j] = gelu_scalar(row[j]);
228  }
229  }
230 #else
231  for (size_t i = 0; i < (size_t)T * fourD; ++i) {
232  fc1_output[i] = gelu_scalar(fc1_output[i]);
233  }
234 #endif
235 
236  /* Convert fc1_output to BF16 for FC2 */
237  float_tensor_to_bf16(fc1_output, scratch_fc1_bf16, (size_t)T * fourD);
238  gemm_bf16_fp32out(scratch_fc1_bf16, W_fc2, scratch_bias2_f, output, T, D, fourD);
239 }
240 
241 #pragma GCC diagnostic pop
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
Definition: bf16_utils.h:271
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
Definition: bf16_utils.h:250
static float gelu_scalar(float x)
void mlp_token_parallel_bf16(const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
void mlp_token_parallel_bf16_fp32act(const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_input_f, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B, const float *bias, float *C, int M, int N, int K)
#define C(color)
Definition: show_config.c:39