← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_bf16.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_bf16.c
3  * @brief Optimized BF16 GEMM Kernels for AVX-512
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Layout:
15  * A: [M x K] row-major (BF16)
16  * B: [N x K] row-major, stored as [out x in] (BF16)
17  * C: [M x N] row-major (BF16 or FP32)
18  *
19  * Key optimizations:
20  * 1. AVX-512 BF16 instructions (VDPBF16PS) when available
21  * 2. Cache blocking for L1/L2 efficiency
22  * 3. Vectorized BF16<->FP32 conversion
23  * 4. Optional OpenMP parallelization (pragmas compile to no-ops without
23  *    -fopenmp; NOTE(review): this conflicts with kernel rule 2 above —
23  *    confirm whether the pragmas should move to the orchestrator layer)
24  */
25 
26 #include <stdint.h>
27 #include <string.h>
28 
29 #if defined(__AVX512F__)
30 #include <immintrin.h>
31 #endif
32 
33 #ifdef _OPENMP
34 #include <omp.h>
35 #endif
36 
37 #include "bf16_utils.h"
38 #include "ckernel_engine.h"
39 
40 /* Block sizes tuned for typical L1/L2 cache */
41 #define BLK_M 64
42 #define BLK_N 64
43 #define BLK_K 256
44 
/* Smaller of two ints; used to clamp block upper bounds to M/N/K. */
static inline int ck_min_i(int a, int b) { return b < a ? b : a; }
46 
47 /* ==========================================================================
48  * Reference Implementation (scalar, for correctness testing)
49  * Kept for debugging/validation but not called in normal operation.
50  * ========================================================================== */
__attribute__((unused))
static void gemm_bf16_scalar(const uint16_t *A,
                             const uint16_t *B,
                             const uint16_t *bias,
                             uint16_t *C,
                             int M, int N, int K)
{
    /* Reference path: C[i,j] = bf16(bias[j] + dot(A[i,:], B[j,:])).
     * Accumulation is in FP32, rounded to BF16 once per output element. */
    for (int row = 0; row < M; ++row) {
        const uint16_t *a_ptr = A + (size_t)row * (size_t)K;
        uint16_t *c_ptr = C + (size_t)row * (size_t)N;

        for (int col = 0; col < N; ++col) {
            /* B is stored [N x K] row-major, i.e. one output row per line. */
            const uint16_t *b_ptr = B + (size_t)col * (size_t)K;

            float acc = bias ? bf16_to_float(bias[col]) : 0.0f;
            for (int k = 0; k < K; ++k) {
                acc += bf16_to_float(a_ptr[k]) * bf16_to_float(b_ptr[k]);
            }
            c_ptr[col] = float_to_bf16(acc);
        }
    }
}
70 
71 #if defined(__AVX512F__)
72 
73 /* ==========================================================================
74  * AVX-512F: Vectorized BF16 conversion + FMA
75  * Works on all AVX-512 CPUs (no BF16 instruction required)
76  *
77  * BF16 conversion functions (bf16x16_to_fp32, fp32x16_to_bf16) are now
78  * provided by bf16_utils.h for consistency across all kernels.
79  * ========================================================================== */
80 
81 /* BF16 dot product: 16 pairs, accumulate to FP32 */
82 static inline __m512 bf16_dot16(__m256i a_bf16, __m256i b_bf16, __m512 acc)
83 {
84  __m512 a_fp32 = bf16x16_to_fp32(a_bf16);
85  __m512 b_fp32 = bf16x16_to_fp32(b_bf16);
86  return _mm512_fmadd_ps(a_fp32, b_fp32, acc);
87 }
88 
89 /* ==========================================================================
90  * AVX-512 Vectorized GEMM (using AVX-512F, works everywhere)
91  * C[M,N] = A[M,K] @ B[N,K].T
92  * ========================================================================== */
93 static void gemm_bf16_avx512(const uint16_t *A,
94  const uint16_t *B,
95  const uint16_t *bias,
96  uint16_t *C,
97  int M, int N, int K)
98 {
99  #pragma omp parallel for schedule(dynamic)
100  for (int i = 0; i < M; ++i) {
101  const uint16_t *a_row = A + (size_t)i * K;
102 
103  for (int j = 0; j < N; ++j) {
104  const uint16_t *b_row = B + (size_t)j * K;
105 
106  /* Initialize accumulator */
107  __m512 sum_vec = _mm512_setzero_ps();
108 
109  /* Vectorized inner loop: process 16 elements at a time */
110  int k = 0;
111  for (; k <= K - 16; k += 16) {
112  __m256i a_bf16 = _mm256_loadu_si256((const __m256i *)(a_row + k));
113  __m256i b_bf16 = _mm256_loadu_si256((const __m256i *)(b_row + k));
114  sum_vec = bf16_dot16(a_bf16, b_bf16, sum_vec);
115  }
116 
117  /* Horizontal sum */
118  float sum = _mm512_reduce_add_ps(sum_vec);
119 
120  /* Scalar tail */
121  for (; k < K; ++k) {
122  sum += bf16_to_float(a_row[k]) * bf16_to_float(b_row[k]);
123  }
124 
125  /* Add bias */
126  if (bias) {
127  sum += bf16_to_float(bias[j]);
128  }
129 
130  C[(size_t)i * N + j] = float_to_bf16(sum);
131  }
132  }
133 }
134 
135 /* ==========================================================================
136  * Cache-Blocked AVX-512 GEMM
137  * Better memory access pattern for large matrices
138  * ========================================================================== */
static void gemm_bf16_blocked_avx512(const uint16_t *A,
                                     const uint16_t *B,
                                     const uint16_t *bias,
                                     uint16_t *C,
                                     int M, int N, int K)
{
    /* Pass 1: seed every C element with its (BF16-rounded) bias, or zero.
     * Pass 2 below adds the accumulated dot products on top of this. */
    #pragma omp parallel for
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float b = bias ? bf16_to_float(bias[j]) : 0.0f;
            C[(size_t)i * N + j] = float_to_bf16(b);
        }
    }

    /* Pass 2, blocked GEMM: each collapsed (ii, jj) iteration owns one
     * disjoint BLK_M x BLK_N tile of C, so no two threads write the same
     * output element. */
    #pragma omp parallel for collapse(2) schedule(dynamic)
    for (int ii = 0; ii < M; ii += BLK_M) {
        for (int jj = 0; jj < N; jj += BLK_N) {
            int i_end = ck_min_i(ii + BLK_M, M);
            int j_end = ck_min_i(jj + BLK_N, N);

            /* Per-tile FP32 accumulator (64*64*4 = 16 KiB, stack, private
             * to the executing thread). Keeping the K-slice partials in
             * FP32 avoids rounding C to BF16 once per K-slice. */
            float acc[BLK_M][BLK_N];
            for (int i = 0; i < BLK_M; ++i) {
                for (int j = 0; j < BLK_N; ++j) {
                    acc[i][j] = 0.0f;
                }
            }

            /* K-dimension blocking: walk the reduction in BLK_K slices so
             * the touched A/B rows stay cache-resident. */
            for (int kk = 0; kk < K; kk += BLK_K) {
                int k_end = ck_min_i(kk + BLK_K, K);

                for (int i = ii; i < i_end; ++i) {
                    const uint16_t *a_row = A + (size_t)i * K;
                    int local_i = i - ii;

                    for (int j = jj; j < j_end; ++j) {
                        /* B is [N x K] row-major: row j holds output col j. */
                        const uint16_t *b_row = B + (size_t)j * K;
                        int local_j = j - jj;

                        __m512 sum_vec = _mm512_setzero_ps();

                        /* 16 BF16 pairs per step; slice tail done in scalar. */
                        int k = kk;
                        for (; k <= k_end - 16; k += 16) {
                            __m256i a_bf16 = _mm256_loadu_si256((const __m256i *)(a_row + k));
                            __m256i b_bf16 = _mm256_loadu_si256((const __m256i *)(b_row + k));
                            sum_vec = bf16_dot16(a_bf16, b_bf16, sum_vec);
                        }

                        float partial = _mm512_reduce_add_ps(sum_vec);
                        for (; k < k_end; ++k) {
                            partial += bf16_to_float(a_row[k]) * bf16_to_float(b_row[k]);
                        }

                        acc[local_i][local_j] += partial;
                    }
                }
            }

            /* Write-back: C already holds the bias from pass 1, so add the
             * tile's accumulated sums and round to BF16 once per element. */
            for (int i = ii; i < i_end; ++i) {
                for (int j = jj; j < j_end; ++j) {
                    float old_val = bf16_to_float(C[(size_t)i * N + j]);
                    float new_val = old_val + acc[i - ii][j - jj];
                    C[(size_t)i * N + j] = float_to_bf16(new_val);
                }
            }
        }
    }
}
211 
212 /*
213  * Native AVX-512 BF16 support (VDPBF16PS instruction)
214  * Only compiles on Ice Lake / Sapphire Rapids or newer
215  * Compile with: -mavx512bf16 (gcc/clang) or /arch:AVX512 (MSVC with recent SDK)
216  */
217 #if defined(__AVX512BF16__) && defined(__AVX512VL__)
218 
219 /* Load 32 BF16 values into __m512bh */
/* Reinterpret 32 packed BF16 values (64 bytes) as an __m512bh vector.
 * Unaligned load; the bits are unchanged, only the vector type is cast. */
static inline __m512bh load_bf16x32(const uint16_t *ptr)
{
    return (__m512bh)_mm512_loadu_si512((const __m512i *)ptr);
}
224 
225 static void gemm_bf16_native(const uint16_t *A,
226  const uint16_t *B,
227  const uint16_t *bias,
228  uint16_t *C,
229  int M, int N, int K)
230 {
231  #pragma omp parallel for schedule(dynamic)
232  for (int i = 0; i < M; ++i) {
233  for (int j = 0; j < N; ++j) {
234  /* Initialize accumulator */
235  __m512 sum_vec = _mm512_setzero_ps();
236 
237  /* Native BF16 dot product: 32 pairs per instruction! */
238  int k = 0;
239  for (; k <= K - 32; k += 32) {
240  __m512bh a_vec = load_bf16x32(A + (size_t)i * K + k);
241  __m512bh b_vec = load_bf16x32(B + (size_t)j * K + k);
242  sum_vec = _mm512_dpbf16_ps(sum_vec, a_vec, b_vec);
243  }
244 
245  float sum = _mm512_reduce_add_ps(sum_vec);
246 
247  /* Scalar tail */
248  for (; k < K; ++k) {
249  sum += bf16_to_float(A[(size_t)i * K + k]) *
250  bf16_to_float(B[(size_t)j * K + k]);
251  }
252 
253  if (bias) {
254  sum += bf16_to_float(bias[j]);
255  }
256 
257  C[(size_t)i * N + j] = float_to_bf16(sum);
258  }
259  }
260 }
261 
262 #define HAVE_NATIVE_BF16 1
263 #else
264 #define HAVE_NATIVE_BF16 0
265 #endif /* __AVX512BF16__ && __AVX512VL__ */
266 
267 #endif /* __AVX512F__ */
268 
269 /* ==========================================================================
270  * Public API: Auto-dispatch to best available implementation
271  * ========================================================================== */
/**
 * @brief BF16 GEMM entry point: C[M,N] = A[M,K] @ B[N,K].T (+ bias).
 *
 * Dispatches to the best implementation compiled in:
 * native VDPBF16PS > AVX-512F (blocked for larger outputs) > scalar.
 * No-op on NULL pointers or non-positive dimensions. bias may be NULL.
 */
void gemm_blocked_serial_bf16(const uint16_t *A,
                              const uint16_t *B,
                              const uint16_t *bias,
                              uint16_t *C,
                              int M, int N, int K)
{
    if (!A || !B || !C || M <= 0 || N <= 0 || K <= 0) {
        return;
    }

#if HAVE_NATIVE_BF16
    /* Native BF16 instructions available (Ice Lake / Sapphire Rapids+) */
    gemm_bf16_native(A, B, bias, C, M, N, K);
#elif defined(__AVX512F__)
    /* Use AVX-512F with software BF16 conversion. The (size_t) casts
     * matter: M * N as int is signed overflow (UB) for large shapes. */
    if ((size_t)M * (size_t)N > 4096) {
        gemm_bf16_blocked_avx512(A, B, bias, C, M, N, K);
    } else {
        gemm_bf16_avx512(A, B, bias, C, M, N, K);
    }
#else
    /* Scalar fallback */
    gemm_bf16_scalar(A, B, bias, C, M, N, K);
#endif
}
297 
298 /* ==========================================================================
299  * GEMM with FP32 output (useful for intermediate computations)
300  * ========================================================================== */
void gemm_bf16_fp32out(const uint16_t *A,
                       const uint16_t *B,
                       const float *bias,
                       float *C,
                       int M, int N, int K)
{
    /* Same contract as gemm_blocked_serial_bf16, but C and bias are FP32
     * (no output rounding) — useful for intermediate computations. */
    if (!A || !B || !C || M <= 0 || N <= 0 || K <= 0) {
        return;
    }

#if defined(__AVX512F__)
    #pragma omp parallel for schedule(dynamic)
    for (int row = 0; row < M; ++row) {
        const uint16_t *a_ptr = A + (size_t)row * K;

        for (int col = 0; col < N; ++col) {
            const uint16_t *b_ptr = B + (size_t)col * K;

            /* 16 BF16 pairs per step into FP32 lanes. */
            __m512 acc_vec = _mm512_setzero_ps();
            int k;
            for (k = 0; k + 16 <= K; k += 16) {
                __m256i a_vec = _mm256_loadu_si256((const __m256i *)(a_ptr + k));
                __m256i b_vec = _mm256_loadu_si256((const __m256i *)(b_ptr + k));
                acc_vec = bf16_dot16(a_vec, b_vec, acc_vec);
            }

            /* Lane reduction, scalar tail, bias last (matches original order). */
            float total = _mm512_reduce_add_ps(acc_vec);
            for (; k < K; ++k) {
                total += bf16_to_float(a_ptr[k]) * bf16_to_float(b_ptr[k]);
            }

            if (bias) {
                total += bias[col];
            }

            C[(size_t)row * N + col] = total;
        }
    }
#else
    /* Scalar fallback: accumulate in FP32, bias seeds the sum. */
    for (int row = 0; row < M; ++row) {
        const uint16_t *a_ptr = A + (size_t)row * K;
        for (int col = 0; col < N; ++col) {
            const uint16_t *b_ptr = B + (size_t)col * K;
            float total = bias ? bias[col] : 0.0f;
            for (int k = 0; k < K; ++k) {
                total += bf16_to_float(a_ptr[k]) * bf16_to_float(b_ptr[k]);
            }
            C[(size_t)row * N + col] = total;
        }
    }
#endif
}
354 
355 /* ==========================================================================
356  * Backward kernels for training
357  * ========================================================================== */
358 
359 /* gemm_nn_bf16: C = A @ B (no transpose), for dL/dX computation */
/**
 * gemm_nn_bf16: C[M,N] = A[M,K] @ B[K,N] (+ bias), no transpose.
 * Used for dL/dX in the backward pass. A is [M x K], B is [K x N],
 * both row-major BF16; bias may be NULL. No-op on bad arguments.
 *
 * FIX: the previous AVX-512 path stored C back to BF16 after every k step,
 * rounding each output element K times and diverging from the scalar
 * fallback (which rounds once). This version keeps a per-16-column FP32
 * accumulator in a register across the whole K loop and rounds each output
 * exactly once, matching the scalar path's rounding policy. It also drops
 * the read-modify-write C traffic inside the k loop.
 */
void gemm_nn_bf16(const uint16_t *A,
                  const uint16_t *B,
                  const uint16_t *bias,
                  uint16_t *C,
                  int M, int N, int K)
{
    if (!A || !B || !C || M <= 0 || N <= 0 || K <= 0) {
        return;
    }

#if defined(__AVX512F__)
    #pragma omp parallel for
    for (int i = 0; i < M; ++i) {
        const uint16_t *a_row = A + (size_t)i * K;
        uint16_t *c_row = C + (size_t)i * N;

        /* 16-column tiles: accumulator lives in FP32 for the whole K loop. */
        int j = 0;
        for (; j <= N - 16; j += 16) {
            __m512 acc = bias
                ? bf16x16_to_fp32(_mm256_loadu_si256((const __m256i *)(bias + j)))
                : _mm512_setzero_ps();

            /* C[i, j..j+15] += A[i,k] * B[k, j..j+15] over all k. */
            for (int k = 0; k < K; ++k) {
                __m512 a_broadcast = _mm512_set1_ps(bf16_to_float(a_row[k]));
                __m512 b_fp32 = bf16x16_to_fp32(
                    _mm256_loadu_si256((const __m256i *)(B + (size_t)k * N + j)));
                acc = _mm512_fmadd_ps(a_broadcast, b_fp32, acc);
            }

            /* Single BF16 rounding per output element. */
            _mm256_storeu_si256((__m256i *)(c_row + j), fp32x16_to_bf16(acc));
        }

        /* Column tail (N % 16): same single-rounding policy in scalar. */
        for (; j < N; ++j) {
            float sum = bias ? bf16_to_float(bias[j]) : 0.0f;
            for (int k = 0; k < K; ++k) {
                sum += bf16_to_float(a_row[k]) *
                       bf16_to_float(B[(size_t)k * N + j]);
            }
            c_row[j] = float_to_bf16(sum);
        }
    }
#else
    /* Scalar fallback: FP32 accumulation, one rounding per element. */
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float sum = bias ? bf16_to_float(bias[j]) : 0.0f;
            for (int k = 0; k < K; ++k) {
                sum += bf16_to_float(A[(size_t)i * K + k]) *
                       bf16_to_float(B[(size_t)k * N + j]);
            }
            C[(size_t)i * N + j] = float_to_bf16(sum);
        }
    }
#endif
}
425 
426 /* gemm_tn_bf16: C = A.T @ B, for dL/dW computation */
/**
 * gemm_tn_bf16: C[M,N] = A.T @ B (+ bias), for dL/dW in the backward pass.
 * A is [K x M] (read transposed), B is [K x N], C is [M x N]; all row-major
 * BF16. bias may be NULL. No-op on bad arguments.
 *
 * FIX 1: the old "gather" built each vector with 16 _mm512_mask_mov_ps
 * blends per operand (32 intrinsic calls per 16 elements). We stage the
 * converted elements in small stack buffers and issue one vector load per
 * operand instead — identical lane contents and FP order, far less work.
 * FIX 2: the old code made a full bias-init pass over C and then a
 * read-modify-write pass. Since bias values are BF16, the store/reload
 * round-trip was exact, so adding bias directly to the sum here is
 * bit-identical (FP add is commutative) and saves a whole pass over C.
 */
void gemm_tn_bf16(const uint16_t *A,
                  const uint16_t *B,
                  const uint16_t *bias,
                  uint16_t *C,
                  int M, int N, int K)
{
    if (!A || !B || !C || M <= 0 || N <= 0 || K <= 0) {
        return;
    }

#if defined(__AVX512F__)
    #pragma omp parallel for
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            /* C[i,j] = sum_k A[k,i] * B[k,j]; both operands are
             * column-strided, so the inner accesses cannot be contiguous
             * vector loads. */
            __m512 sum_vec = _mm512_setzero_ps();

            int k = 0;
            for (; k <= K - 16; k += 16) {
                /* Stage 16 converted elements per operand, then one
                 * unaligned FP32 load each. */
                float a_tmp[16], b_tmp[16];
                for (int kk = 0; kk < 16; ++kk) {
                    a_tmp[kk] = bf16_to_float(A[(size_t)(k + kk) * M + i]);
                    b_tmp[kk] = bf16_to_float(B[(size_t)(k + kk) * N + j]);
                }
                sum_vec = _mm512_fmadd_ps(_mm512_loadu_ps(a_tmp),
                                          _mm512_loadu_ps(b_tmp),
                                          sum_vec);
            }

            float sum = _mm512_reduce_add_ps(sum_vec);

            /* K % 16 tail in scalar. */
            for (; k < K; ++k) {
                sum += bf16_to_float(A[(size_t)k * M + i]) *
                       bf16_to_float(B[(size_t)k * N + j]);
            }

            if (bias) {
                sum += bf16_to_float(bias[j]);
            }

            C[(size_t)i * N + j] = float_to_bf16(sum);
        }
    }
#else
    /* Scalar fallback: FP32 accumulation seeded with bias. */
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float sum = bias ? bf16_to_float(bias[j]) : 0.0f;
            for (int k = 0; k < K; ++k) {
                sum += bf16_to_float(A[(size_t)k * M + i]) *
                       bf16_to_float(B[(size_t)k * N + j]);
            }
            C[(size_t)i * N + j] = float_to_bf16(sum);
        }
    }
#endif
}
500 
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38
#define BLK_M
void gemm_tn_bf16(const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
__attribute__((unused))
#define BLK_N
void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B, const float *bias, float *C, int M, int N, int K)
void gemm_blocked_serial_bf16(const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
static int ck_min_i(int a, int b)
#define BLK_K
void gemm_nn_bf16(const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
#define C(color)
Definition: show_config.c:39