← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_bf16.c File Reference

Optimized BF16 GEMM Kernels for AVX-512. More...

#include <stdint.h>
#include <string.h>
#include "bf16_utils.h"
#include "ckernel_engine.h"

Go to the source code of this file.

Macros

#define BLK_K   256
 
#define BLK_M   64
 
#define BLK_N   64
 

Functions

 __attribute__ ((unused))
 
static int ck_min_i (int a, int b)
 
void gemm_bf16_fp32out (const uint16_t *A, const uint16_t *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_blocked_serial_bf16 (const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
 
void gemm_nn_bf16 (const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
 
void gemm_tn_bf16 (const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
 

Detailed Description

Optimized BF16 GEMM Kernels for AVX-512.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Layout:
  A: [M x K] row-major (BF16)
  B: [N x K] row-major, stored as [out x in] (BF16)
  C: [M x N] row-major (BF16 or FP32, depending on the kernel)

Key optimizations:

  1. AVX-512 BF16 instructions (VDPBF16PS) when available
  2. Cache blocking for L1/L2 efficiency
  3. Vectorized BF16<->FP32 conversion
  4. OpenMP parallelization (NOTE: this contradicts kernel rule 2 above, "NO OpenMP" — the `#pragma omp` directives in this file are no-ops unless compiled with -fopenmp; confirm which policy is intended)

Definition in file gemm_kernels_bf16.c.

Macro Definition Documentation

◆ BLK_K

#define BLK_K   256

Definition at line 43 of file gemm_kernels_bf16.c.

◆ BLK_M

#define BLK_M   64

Definition at line 41 of file gemm_kernels_bf16.c.

◆ BLK_N

#define BLK_N   64

Definition at line 42 of file gemm_kernels_bf16.c.

Function Documentation

◆ gemm_bf16_scalar() — rendered by Doxygen as "__attribute__()"

NOTE(review): Doxygen mis-parsed this function's signature because of its
__attribute__((unused)) marker, so only the attribute is shown here. Judging
by the body below and the dispatcher's scalar fallback, this is most likely
static void gemm_bf16_scalar(const uint16_t *A, const uint16_t *B,
const uint16_t *bias, uint16_t *C, int M, int N, int K) — confirm against the
actual source file.

Definition at line 51 of file gemm_kernels_bf16.c.

57 {
58  for (int i = 0; i < M; ++i) {
59  for (int j = 0; j < N; ++j) {
60  float sum = bias ? bf16_to_float(bias[j]) : 0.0f;
61  const size_t a_row = (size_t)i * (size_t)K;
62  const size_t b_row = (size_t)j * (size_t)K;
63  for (int k = 0; k < K; ++k) {
64  sum += bf16_to_float(A[a_row + k]) * bf16_to_float(B[b_row + k]);
65  }
66  C[(size_t)i * (size_t)N + j] = float_to_bf16(sum);
67  }
68  }
69 }
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38
#define C(color)
Definition: show_config.c:39

References bf16_to_float(), C, and float_to_bf16().

◆ ck_min_i()

/* Return the smaller of two ints; used to clamp cache-blocking loop bounds. */
static inline int ck_min_i(int a, int b)
{
    if (b < a) {
        return b;
    }
    return a;
}

◆ gemm_bf16_fp32out()

/**
 * C = A * B^T (+ bias), BF16 inputs, FP32 output.
 *
 * A:    [M x K] row-major BF16
 * B:    [N x K] row-major BF16 (stored as [out x in])
 * bias: [N] FP32, may be NULL
 * C:    [M x N] row-major FP32
 *
 * No-op on NULL pointers or non-positive dimensions.
 *
 * NOTE(review): the `#pragma omp` below conflicts with the file's stated
 * kernel rule 2 ("NO OpenMP — parallelization at orchestrator/codegen
 * layer"); it is inert unless built with -fopenmp. Confirm intended policy.
 */
void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B,
                       const float *bias, float *C,
                       int M, int N, int K)
{
    if (A == NULL || B == NULL || C == NULL || M <= 0 || N <= 0 || K <= 0) {
        return;
    }

#if defined(__AVX512F__)
    #pragma omp parallel for schedule(dynamic)
    for (int row = 0; row < M; ++row) {
        const uint16_t *ap = A + (size_t)row * K;

        for (int col = 0; col < N; ++col) {
            const uint16_t *bp = B + (size_t)col * K;

            /* 16-wide BF16 dot-product accumulation in FP32 lanes. */
            __m512 acc = _mm512_setzero_ps();
            int k = 0;
            for (; k + 16 <= K; k += 16) {
                __m256i av = _mm256_loadu_si256((const __m256i *)(ap + k));
                __m256i bv = _mm256_loadu_si256((const __m256i *)(bp + k));
                acc = bf16_dot16(av, bv, acc);
            }

            float dot = _mm512_reduce_add_ps(acc);

            /* Scalar tail for the K % 16 leftovers. */
            for (; k < K; ++k) {
                dot += bf16_to_float(ap[k]) * bf16_to_float(bp[k]);
            }

            if (bias != NULL) {
                dot += bias[col];
            }

            C[(size_t)row * N + col] = dot;
        }
    }
#else
    /* Portable scalar path: same math, no intrinsics. */
    for (int row = 0; row < M; ++row) {
        for (int col = 0; col < N; ++col) {
            float dot = (bias != NULL) ? bias[col] : 0.0f;
            for (int k = 0; k < K; ++k) {
                dot += bf16_to_float(A[(size_t)row * K + k]) *
                       bf16_to_float(B[(size_t)col * K + k]);
            }
            C[(size_t)row * N + col] = dot;
        }
    }
#endif
}

References bf16_to_float(), and C.

Referenced by mlp_token_parallel_bf16(), and mlp_token_parallel_bf16_fp32act().

◆ gemm_blocked_serial_bf16()

/**
 * Serial blocked BF16 GEMM dispatcher: C = A * B^T (+ bias), all BF16.
 *
 * A:    [M x K] row-major BF16
 * B:    [N x K] row-major BF16 (stored as [out x in])
 * bias: [N] BF16, may be NULL
 * C:    [M x N] row-major BF16
 *
 * Selects the best available kernel for the build target. No-op on NULL
 * pointers or non-positive dimensions.
 *
 * Fix: the small/large heuristic previously evaluated `M * N` in int, which
 * is signed-overflow UB for large shapes (e.g. M = N = 65536); the product
 * is now computed in size_t.
 */
void gemm_blocked_serial_bf16(const uint16_t *A, const uint16_t *B,
                              const uint16_t *bias, uint16_t *C,
                              int M, int N, int K)
{
    if (!A || !B || !C || M <= 0 || N <= 0 || K <= 0) {
        return;
    }

#if HAVE_NATIVE_BF16
    /* Native BF16 dot-product instructions (Cooper Lake / Sapphire Rapids+;
     * the earlier "Ice Lake" comment was wrong — Ice Lake lacks AVX512-BF16). */
    gemm_bf16_native(A, B, bias, C, M, N, K);
#elif defined(__AVX512F__)
    /* AVX-512F with software BF16<->FP32 conversion. Cache blocking only
     * pays off once the output matrix is reasonably large. */
    if ((size_t)M * (size_t)N > 4096) {
        gemm_bf16_blocked_avx512(A, B, bias, C, M, N, K);
    } else {
        gemm_bf16_avx512(A, B, bias, C, M, N, K);
    }
#else
    /* Portable scalar fallback. */
    gemm_bf16_scalar(A, B, bias, C, M, N, K);
#endif
}

References C.

◆ gemm_nn_bf16()

/**
 * NN-layout BF16 GEMM: C[M x N] = A[M x K] * B[K x N] (+ bias), all BF16.
 *
 * A, B, C are row-major; bias is [N] BF16 and may be NULL. No-op on NULL
 * pointers or non-positive dimensions.
 *
 * NOTE(review): the `#pragma omp` below conflicts with kernel rule 2
 * ("NO OpenMP"); it is inert unless built with -fopenmp. Confirm policy.
 */
void gemm_nn_bf16(const uint16_t *A, const uint16_t *B,
                  const uint16_t *bias, uint16_t *C,
                  int M, int N, int K)
{
    if (!A || !B || !C || M <= 0 || N <= 0 || K <= 0) {
        return;
    }

#if defined(__AVX512F__)
    #pragma omp parallel for
    for (int row = 0; row < M; ++row) {
        uint16_t *c_row = C + (size_t)row * N;

        /* Seed the output row with bias (or zero), 16 lanes at a time. */
        int col = 0;
        for (; col + 16 <= N; col += 16) {
            __m512 seed = bias
                ? bf16x16_to_fp32(
                      _mm256_loadu_si256((const __m256i *)(bias + col)))
                : _mm512_setzero_ps();
            _mm256_storeu_si256((__m256i *)(c_row + col),
                                fp32x16_to_bf16(seed));
        }
        for (; col < N; ++col) {
            float seed = bias ? bf16_to_float(bias[col]) : 0.0f;
            c_row[col] = float_to_bf16(seed);
        }

        /* Rank-1 updates: C[row,:] += A[row,k] * B[k,:].
         * NOTE(review): the row is re-rounded to BF16 after every k step, so
         * precision differs from the scalar path below, which accumulates
         * each dot product fully in FP32 — confirm this is acceptable. */
        for (int k = 0; k < K; ++k) {
            const float a_scal = bf16_to_float(A[(size_t)row * K + k]);
            const __m512 a_bcast = _mm512_set1_ps(a_scal);
            const uint16_t *b_row = B + (size_t)k * N;

            col = 0;
            for (; col + 16 <= N; col += 16) {
                __m512 b_f = bf16x16_to_fp32(
                    _mm256_loadu_si256((const __m256i *)(b_row + col)));
                __m512 c_f = bf16x16_to_fp32(
                    _mm256_loadu_si256((const __m256i *)(c_row + col)));
                c_f = _mm512_fmadd_ps(a_bcast, b_f, c_f);
                _mm256_storeu_si256((__m256i *)(c_row + col),
                                    fp32x16_to_bf16(c_f));
            }
            for (; col < N; ++col) {
                float acc = bf16_to_float(c_row[col]);
                acc += a_scal * bf16_to_float(b_row[col]);
                c_row[col] = float_to_bf16(acc);
            }
        }
    }
#else
    /* Scalar fallback: full FP32 accumulation per output element. */
    for (int row = 0; row < M; ++row) {
        for (int col = 0; col < N; ++col) {
            float acc = bias ? bf16_to_float(bias[col]) : 0.0f;
            for (int k = 0; k < K; ++k) {
                acc += bf16_to_float(A[(size_t)row * K + k]) *
                       bf16_to_float(B[(size_t)k * N + col]);
            }
            C[(size_t)row * N + col] = float_to_bf16(acc);
        }
    }
#endif
}

References bf16_to_float(), C, and float_to_bf16().

◆ gemm_tn_bf16()

/**
 * TN-layout BF16 GEMM: C[M x N] = A^T * B (+ bias), all BF16.
 *
 * A is stored [K x M] (so A[k*M + i] == A^T[i,k]); B is [K x N]; C is
 * [M x N]; all row-major. bias is [N] BF16 and may be NULL. No-op on NULL
 * pointers or non-positive dimensions.
 *
 * Fix: the AVX-512 path previously initialized C with the bias rounded to
 * BF16 in a first pass, then read it back and rounded again after
 * accumulation. That double rounding made its results diverge from the
 * scalar fallback. The bias is now kept in FP32 and each output element is
 * rounded to BF16 exactly once, in a single pass (also removing one whole
 * parallel sweep over C).
 *
 * NOTE(review): the `#pragma omp` below conflicts with kernel rule 2
 * ("NO OpenMP"); it is inert unless built with -fopenmp. Confirm policy.
 */
void gemm_tn_bf16(const uint16_t *A, const uint16_t *B,
                  const uint16_t *bias, uint16_t *C,
                  int M, int N, int K)
{
    if (!A || !B || !C || M <= 0 || N <= 0 || K <= 0) {
        return;
    }

#if defined(__AVX512F__)
    #pragma omp parallel for
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            __m512 sum_vec = _mm512_setzero_ps();

            int k = 0;
            for (; k <= K - 16; k += 16) {
                /* Both A (stride M) and B (stride N) are strided along k, so
                 * the 16 lanes are filled element by element.
                 * NOTE(review): this lane-by-lane fill is likely slower than
                 * a plain scalar loop; consider _mm512_i32gather_ps or a
                 * packing step if this path matters for performance. */
                __m512 a_fp32 = _mm512_setzero_ps();
                __m512 b_fp32 = _mm512_setzero_ps();
                for (int kk = 0; kk < 16; ++kk) {
                    const __mmask16 lane = (__mmask16)(1u << kk);
                    float a_val = bf16_to_float(A[(size_t)(k + kk) * M + i]);
                    float b_val = bf16_to_float(B[(size_t)(k + kk) * N + j]);
                    a_fp32 = _mm512_mask_mov_ps(a_fp32, lane,
                                                _mm512_set1_ps(a_val));
                    b_fp32 = _mm512_mask_mov_ps(b_fp32, lane,
                                                _mm512_set1_ps(b_val));
                }
                sum_vec = _mm512_fmadd_ps(a_fp32, b_fp32, sum_vec);
            }

            /* Bias stays in FP32 until the final store. */
            float sum = bias ? bf16_to_float(bias[j]) : 0.0f;
            sum += _mm512_reduce_add_ps(sum_vec);

            for (; k < K; ++k) {
                sum += bf16_to_float(A[(size_t)k * M + i]) *
                       bf16_to_float(B[(size_t)k * N + j]);
            }

            /* Exactly one BF16 rounding per output element. */
            C[(size_t)i * N + j] = float_to_bf16(sum);
        }
    }
#else
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float sum = bias ? bf16_to_float(bias[j]) : 0.0f;
            for (int k = 0; k < K; ++k) {
                sum += bf16_to_float(A[(size_t)k * M + i]) *
                       bf16_to_float(B[(size_t)k * N + j]);
            }
            C[(size_t)i * N + j] = float_to_bf16(sum);
        }
    }
#endif
}

References bf16_to_float(), C, and float_to_bf16().