29 #if defined(__AVX512F__)
30 #include <immintrin.h>
/* Return the smaller of two ints (used to clamp tile bounds to M/N/K). */
static inline int ck_min_i(int a, int b)
{
    if (b < a) {
        return b;
    }
    return a;
}
/*
 * Scalar (no-SIMD) reference GEMM over bf16 inputs.
 * NOTE(review): this view of the file is elided — the remaining parameters
 * (presumably B, bias, C, M, N, K to match the sibling gemm_* kernels) and
 * the inner multiply/accumulate body are not visible here; confirm against
 * the full file before relying on this description.
 */
52 static void gemm_bf16_scalar(
const uint16_t *A,
/* Walk the M x N output row-major. */
58 for (
int i = 0; i < M; ++i) {
59 for (
int j = 0; j < N; ++j) {
/* size_t row offsets avoid int overflow for large i*K / j*K. */
61 const size_t a_row = (size_t)i * (
size_t)K;
/* b_row = j * K: B is indexed by output column times K, i.e. B appears
 * to be stored transposed (N x K) — TODO confirm against callers. */
62 const size_t b_row = (size_t)j * (
size_t)K;
/* Dot product over the shared K dimension (body elided from this view). */
63 for (
int k = 0; k < K; ++k) {
71 #if defined(__AVX512F__)
/*
 * Fused multiply-add of 16 bf16 pairs into an fp32 accumulator:
 * widens both 16-lane bf16 vectors to fp32 (via bf16x16_to_fp32, defined
 * elsewhere in this file) and returns acc + a * b, lane-wise.
 */
82 static inline __m512 bf16_dot16(__m256i a_bf16, __m256i b_bf16, __m512 acc)
84 __m512 a_fp32 = bf16x16_to_fp32(a_bf16);
85 __m512 b_fp32 = bf16x16_to_fp32(b_bf16);
86 return _mm512_fmadd_ps(a_fp32, b_fp32, acc);
/*
 * AVX-512F GEMM: bf16 inputs widened to fp32, 16 elements per step.
 * Rows of the output are distributed across OpenMP threads; dynamic
 * scheduling balances uneven per-row cost.
 * NOTE(review): signature tail, scalar K-remainder handling, bias add and
 * the store of `sum` into C are elided from this view.
 */
93 static void gemm_bf16_avx512(
const uint16_t *A,
99 #pragma omp parallel for schedule(dynamic)
100 for (
int i = 0; i < M; ++i) {
101 const uint16_t *a_row = A + (size_t)i * K;
103 for (
int j = 0; j < N; ++j) {
/* b_row = B + j*K: B laid out transposed (N x K) so both operands
 * stream contiguously — TODO confirm layout with callers. */
104 const uint16_t *b_row = B + (size_t)j * K;
107 __m512 sum_vec = _mm512_setzero_ps();
/* Main vector loop: 16 bf16 lanes per iteration. `k <= K - 16` is safe
 * because K is a signed int (no unsigned underflow when K < 16). */
111 for (; k <= K - 16; k += 16) {
112 __m256i a_bf16 = _mm256_loadu_si256((
const __m256i *)(a_row + k));
113 __m256i b_bf16 = _mm256_loadu_si256((
const __m256i *)(b_row + k));
114 sum_vec = bf16_dot16(a_bf16, b_bf16, sum_vec);
/* Horizontal reduce of the 16 partial sums into one fp32 scalar. */
118 float sum = _mm512_reduce_add_ps(sum_vec);
/*
 * Cache-blocked AVX-512F GEMM. Two phases are visible in this (elided)
 * view: first C is initialized (presumably with bias) across the whole
 * M x N output, then BLK_M x BLK_N output tiles are processed in
 * parallel, accumulating K in BLK_K slices through a local `acc` tile
 * before a read-modify-write back into C.
 */
139 static void gemm_bf16_blocked_avx512(
const uint16_t *A,
141 const uint16_t *bias,
/* Phase 1: initialize every output element (body elided). */
146 #pragma omp parallel for
147 for (
int i = 0; i < M; ++i) {
148 for (
int j = 0; j < N; ++j) {
/* Phase 2: one OpenMP task per (ii, jj) output tile; collapse(2)
 * flattens both tile loops into a single parallel iteration space. */
155 #pragma omp parallel for collapse(2) schedule(dynamic)
156 for (
int ii = 0; ii < M; ii +=
BLK_M) {
157 for (
int jj = 0; jj < N; jj +=
BLK_N) {
/* Zero the thread-local accumulator tile (declaration elided;
 * presumably float acc[BLK_M][BLK_N] on the stack). */
163 for (
int i = 0; i <
BLK_M; ++i) {
164 for (
int j = 0; j <
BLK_N; ++j) {
/* K is consumed in BLK_K slices so A/B tile data stays cache-hot.
 * i_end / j_end / k_end are the clamped tile bounds (elided;
 * presumably computed with ck_min_i). */
170 for (
int kk = 0; kk < K; kk +=
BLK_K) {
173 for (
int i = ii; i < i_end; ++i) {
174 const uint16_t *a_row = A + (size_t)i * K;
175 int local_i = i - ii;
177 for (
int j = jj; j < j_end; ++j) {
/* B indexed as row j * K: transposed (N x K) layout — TODO confirm. */
178 const uint16_t *b_row = B + (size_t)j * K;
179 int local_j = j - jj;
181 __m512 sum_vec = _mm512_setzero_ps();
/* Vector body of the k-slice: 16 bf16 lanes per step. */
184 for (; k <= k_end - 16; k += 16) {
185 __m256i a_bf16 = _mm256_loadu_si256((
const __m256i *)(a_row + k));
186 __m256i b_bf16 = _mm256_loadu_si256((
const __m256i *)(b_row + k));
187 sum_vec = bf16_dot16(a_bf16, b_bf16, sum_vec);
190 float partial = _mm512_reduce_add_ps(sum_vec);
/* Scalar remainder of the k-slice (body elided). */
191 for (; k < k_end; ++k) {
195 acc[local_i][local_j] += partial;
/* Write the finished tile back: load old C value, add the tile
 * accumulator, store (load/convert/store lines elided). */
201 for (
int i = ii; i < i_end; ++i) {
202 for (
int j = jj; j < j_end; ++j) {
204 float new_val = old_val + acc[i - ii][j - jj];
217 #if defined(__AVX512BF16__) && defined(__AVX512VL__)
/*
 * Load 32 consecutive bf16 values as a __m512bh vector (unaligned load;
 * the cast only reinterprets the 512-bit register, no conversion).
 * Caller must guarantee at least 32 readable elements at ptr.
 */
220 static inline __m512bh load_bf16x32(
const uint16_t *ptr)
222 return (__m512bh)_mm512_loadu_si512((
const __m512i *)ptr);
/*
 * Native AVX512-BF16 GEMM: uses the hardware dot-product instruction
 * (_mm512_dpbf16_ps) on 32 bf16 pairs per step — no manual widening.
 * Guarded by __AVX512BF16__ && __AVX512VL__ at the enclosing #if.
 * NOTE(review): remainder loop, bias handling and the store of `sum`
 * into C are elided from this view.
 */
225 static void gemm_bf16_native(
const uint16_t *A,
227 const uint16_t *bias,
231 #pragma omp parallel for schedule(dynamic)
232 for (
int i = 0; i < M; ++i) {
233 for (
int j = 0; j < N; ++j) {
235 __m512 sum_vec = _mm512_setzero_ps();
/* 32 bf16 elements per iteration; B indexed by j*K (transposed
 * N x K layout — TODO confirm). */
239 for (; k <= K - 32; k += 32) {
240 __m512bh a_vec = load_bf16x32(A + (
size_t)i * K + k);
241 __m512bh b_vec = load_bf16x32(B + (
size_t)j * K + k);
/* dpbf16: per pair of bf16 lanes, multiply and accumulate into the
 * corresponding fp32 lane of sum_vec. */
242 sum_vec = _mm512_dpbf16_ps(sum_vec, a_vec, b_vec);
245 float sum = _mm512_reduce_add_ps(sum_vec);
262 #define HAVE_NATIVE_BF16 1
264 #define HAVE_NATIVE_BF16 0
/*
 * Public dispatcher (opening line elided from this view; presumably
 * gemm_tn_bf16 per the trailer outline — confirm): validates arguments,
 * then selects the best kernel available at compile time:
 *   AVX512-BF16 native > blocked AVX-512F > plain AVX-512F > scalar.
 */
274 const uint16_t *bias,
/* Reject NULL matrix pointers and non-positive dimensions up front. */
278 if (!A || !B || !
C || M <= 0 || N <= 0 || K <= 0) {
284 gemm_bf16_native(A, B, bias,
C, M, N, K);
285 #elif defined(__AVX512F__)
/* Within the AVX-512F branch, a runtime/size condition (elided) picks
 * the blocked kernel over the simple one. */
288 gemm_bf16_blocked_avx512(A, B, bias,
C, M, N, K);
290 gemm_bf16_avx512(A, B, bias,
C, M, N, K);
/* Portable fallback when no AVX-512 is available. */
294 gemm_bf16_scalar(A, B, bias,
C, M, N, K);
/*
 * GEMM variant writing fp32 output: C[i*N+j] receives the raw float sum
 * with no bf16 narrowing (presumably gemm_bf16_fp32out per the trailer
 * outline — its opening line is elided from this view; bias is float*
 * in that signature).
 */
/* Argument validation, same contract as the other entry points. */
307 if (!A || !B || !
C || M <= 0 || N <= 0 || K <= 0) {
/* AVX-512F fast path. */
311 #if defined(__AVX512F__)
312 #pragma omp parallel for schedule(dynamic)
313 for (
int i = 0; i < M; ++i) {
314 const uint16_t *a_row = A + (size_t)i * K;
316 for (
int j = 0; j < N; ++j) {
/* B indexed by j*K — transposed (N x K) layout, TODO confirm. */
317 const uint16_t *b_row = B + (size_t)j * K;
319 __m512 sum_vec = _mm512_setzero_ps();
/* 16 bf16 lanes per iteration; scalar remainder elided. */
322 for (; k <= K - 16; k += 16) {
323 __m256i a_bf16 = _mm256_loadu_si256((
const __m256i *)(a_row + k));
324 __m256i b_bf16 = _mm256_loadu_si256((
const __m256i *)(b_row + k));
325 sum_vec = bf16_dot16(a_bf16, b_bf16, sum_vec);
328 float sum = _mm512_reduce_add_ps(sum_vec);
/* Store at full fp32 precision — no bf16 round-trip. */
338 C[(size_t)i * N + j] = sum;
/* Scalar fallback (non-AVX-512 builds). */
342 for (
int i = 0; i < M; ++i) {
343 for (
int j = 0; j < N; ++j) {
/* NULL bias is allowed and treated as zero. */
344 float sum = bias ? bias[j] : 0.0f;
/* Inner dot product body elided from this view. */
345 for (
int k = 0; k < K; ++k) {
349 C[(size_t)i * N + j] = sum;
/*
 * Non-transposed-B GEMM (B indexed as k*N + j, i.e. row-major K x N;
 * presumably gemm_nn_bf16 per the trailer outline — opening line elided).
 * Strategy: initialize C with bias, then accumulate rank-1 updates
 * C += a[i][k] * B[k][:] one k at a time, keeping C in bf16 throughout.
 * NOTE(review): C is narrowed to bf16 after EVERY k step, so rounding
 * error accumulates over K — unlike the fp32-accumulator kernels above.
 * Confirm this precision trade-off is intentional.
 */
362 const uint16_t *bias,
366 if (!A || !B || !
C || M <= 0 || N <= 0 || K <= 0) {
370 #if defined(__AVX512F__)
371 #pragma omp parallel for
372 for (
int i = 0; i < M; ++i) {
/* Initialize row i of C: broadcast bias (or zeros) in 16-wide chunks;
 * the scalar j-remainder is elided from this view. */
375 for (; j <= N - 16; j += 16) {
376 __m512 b_vec = bias ? bf16x16_to_fp32(_mm256_loadu_si256((
const __m256i *)(bias + j)))
377 : _mm512_setzero_ps();
378 __m256i out = fp32x16_to_bf16(b_vec);
379 _mm256_storeu_si256((__m256i *)(
C + (
size_t)i * N + j), out);
/* Rank-1 update per k: broadcast A[i][k] and FMA it against B row k. */
387 for (
int k = 0; k < K; ++k) {
389 __m512 a_broadcast = _mm512_set1_ps(a_val);
392 for (; j <= N - 16; j += 16) {
/* Load 16 contiguous elements of B row k (non-transposed layout). */
393 __m256i b_bf16 = _mm256_loadu_si256((
const __m256i *)(B + (
size_t)k * N + j));
394 __m512 b_fp32 = bf16x16_to_fp32(b_bf16);
/* Read-modify-write of C through a bf16 widen/narrow round trip. */
396 __m256i c_bf16 = _mm256_loadu_si256((
const __m256i *)(
C + (
size_t)i * N + j));
397 __m512 c_fp32 = bf16x16_to_fp32(c_bf16);
399 c_fp32 = _mm512_fmadd_ps(a_broadcast, b_fp32, c_fp32);
401 __m256i c_out = fp32x16_to_bf16(c_fp32);
402 _mm256_storeu_si256((__m256i *)(
C + (
size_t)i * N + j), c_out);
/* Scalar fallback (non-AVX-512 builds); inner body elided. */
413 for (
int i = 0; i < M; ++i) {
414 for (
int j = 0; j < N; ++j) {
416 for (
int k = 0; k < K; ++k) {
/*
 * Strided-operand GEMM (presumably gemm_tn_bf16 per the trailer outline —
 * opening line elided). Because at least one operand is not contiguous
 * along k, 16-element vectors are assembled one lane at a time with
 * masked moves instead of a single load.
 * NOTE(review): building each vector via 16 x _mm512_mask_mov_ps +
 * _mm512_set1_ps is effectively scalar work per lane and likely slower
 * than a gather or a plain scalar loop — candidate for optimization.
 */
429 const uint16_t *bias,
433 if (!A || !B || !
C || M <= 0 || N <= 0 || K <= 0) {
441 #if defined(__AVX512F__)
/* Phase 1: initialize C (presumably with bias; body elided). */
443 #pragma omp parallel for
444 for (
int i = 0; i < M; ++i) {
445 for (
int j = 0; j < N; ++j) {
/* Phase 2: compute each output element. */
452 #pragma omp parallel for
453 for (
int i = 0; i < M; ++i) {
454 for (
int j = 0; j < N; ++j) {
455 __m512 sum_vec = _mm512_setzero_ps();
458 for (; k <= K - 16; k += 16) {
/* Gather 16 strided A elements into lanes 0..15; `val` is the
 * widened element (load elided). Mask 1 << kk selects lane kk. */
460 __m512 a_fp32 = _mm512_setzero_ps();
461 for (
int kk = 0; kk < 16; ++kk) {
463 a_fp32 = _mm512_mask_mov_ps(a_fp32, 1 << kk, _mm512_set1_ps(val));
/* Same per-lane assembly for the B operand. */
467 __m512 b_fp32 = _mm512_setzero_ps();
468 for (
int kk = 0; kk < 16; ++kk) {
470 b_fp32 = _mm512_mask_mov_ps(b_fp32, 1 << kk, _mm512_set1_ps(val));
473 sum_vec = _mm512_fmadd_ps(a_fp32, b_fp32, sum_vec);
476 float sum = _mm512_reduce_add_ps(sum_vec);
/* Scalar fallback (non-AVX-512 builds); inner body elided. */
488 for (
int i = 0; i < M; ++i) {
489 for (
int j = 0; j < N; ++j) {
491 for (
int k = 0; k < K; ++k) {
static uint16_t float_to_bf16(float f)
static float bf16_to_float(uint16_t v)
void gemm_tn_bf16(const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B, const float *bias, float *C, int M, int N, int K)
void gemm_blocked_serial_bf16(const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
static int ck_min_i(int a, int b)
void gemm_nn_bf16(const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)