#define GEMM_BACKEND "MKL"
#elif defined(USE_ONEDNN)
#define GEMM_BACKEND "oneDNN"
#else
#define GEMM_BACKEND "Native"
#endif

#if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
#include <immintrin.h>
                B_transposed ? CblasTrans : CblasNoTrans,   /* op(B) */
                B, B_transposed ? K : N,                    /* ldb: K when B is stored N x K */
#elif defined(USE_ONEDNN)
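
/*
 * oneDNN backend: one process-wide CPU engine and stream, created lazily.
 * As written, onednn_init() is not thread-safe; guard the first call if
 * GEMM may be entered concurrently.
 */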
static dnnl_engine_t g_engine = NULL;
static dnnl_stream_t g_stream = NULL;

static void onednn_init(void) {
    if (g_engine) return;
    dnnl_engine_create(&g_engine, dnnl_cpu, 0);
    dnnl_stream_create(&g_stream, g_engine, dnnl_stream_default_flags);
}
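
/*
 * Standard oneDNN C-API matmul sequence: describe A (MxK), B (KxN) and
 * C (MxN) with row-major strides, create a matmul primitive, bind the user
 * buffers as dnnl_memory objects, execute on the stream, and wait. A fresh
 * primitive is built on every call; caching it per (M, N, K) would avoid
 * repeating the descriptor and primitive setup.
 */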
    dnnl_memory_desc_t a_md, b_md, c_md;
    dnnl_dims_t a_dims = {M, K};
    dnnl_dims_t b_dims = {K, N};
    dnnl_dims_t c_dims = {M, N};
    dnnl_dims_t a_strides = {K, 1};
    dnnl_dims_t b_strides = {N, 1};
    dnnl_dims_t c_strides = {N, 1};

    dnnl_memory_desc_create_with_strides(&a_md, 2, a_dims, dnnl_f32, a_strides);
    dnnl_memory_desc_create_with_strides(&b_md, 2, b_dims, dnnl_f32, b_strides);
    dnnl_memory_desc_create_with_strides(&c_md, 2, c_dims, dnnl_f32, c_strides);

    dnnl_primitive_desc_t matmul_pd;
    dnnl_matmul_primitive_desc_create(&matmul_pd, g_engine, a_md, b_md, NULL, c_md, NULL);  /* no bias, default attrs */

    dnnl_primitive_t matmul;
    dnnl_primitive_create(&matmul, matmul_pd);
    dnnl_memory_t a_mem, b_mem, c_mem;
    dnnl_memory_create(&a_mem, a_md, g_engine, (void*)A);
    dnnl_memory_create(&b_mem, b_md, g_engine, (void*)B);
    dnnl_memory_create(&c_mem, c_md, g_engine, (void*)C);
    dnnl_exec_arg_t args[3] = {
        {DNNL_ARG_SRC, a_mem},
        {DNNL_ARG_WEIGHTS, b_mem},
        {DNNL_ARG_DST, c_mem}
    };

    dnnl_primitive_execute(matmul, g_stream, 3, args);
    dnnl_stream_wait(g_stream);
    dnnl_primitive_destroy(matmul);
    dnnl_primitive_desc_destroy(matmul_pd);
    dnnl_memory_destroy(a_mem);
    dnnl_memory_destroy(b_mem);
    dnnl_memory_destroy(c_mem);
    dnnl_memory_desc_destroy(a_md);
    dnnl_memory_desc_destroy(b_md);
    dnnl_memory_desc_destroy(c_md);
#if defined(__AVX512F__)
#elif defined(__FMA__)
#elif defined(__AVX__)
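
/*
 * MR_FIXED x NR_FIXED is the register tile one microkernel call computes:
 * 6x32 with AVX-512, 6x16 with AVX2/FMA, 4x16 with plain AVX (see the
 * kernels below).
 */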
#define MR (MR_FIXED)
#define NR (NR_FIXED)
#define MC (get_gemm_params()->MC)
#define NC (get_gemm_params()->NC)
#define KC (get_gemm_params()->KC)
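
/*
 * MR/NR are compile-time register-tile sizes; MC/NC/KC are the cache
 * blocking sizes (A-row, B-column, and K-depth block per pass) and come
 * from get_gemm_params() so they can be tuned to the cache hierarchy at
 * runtime.
 */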
#if defined(__AVX512F__)
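
/*
 * AVX-512 microkernel: accumulates a 6x32 tile of C in 12 zmm registers
 * (6 rows x two 16-float vectors). Each step broadcasts one element per
 * row of A with _mm512_set1_ps and FMAs it against a 32-wide row of B.
 * first_k selects C = A*B (zeroed accumulators) vs. C += A*B (load C).
 */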
static inline void gemm_microkernel_6x32_avx512(
    int K,
    const float * __restrict__ A, int lda,
    const float * __restrict__ B, int ldb,
    float * __restrict__ C, int ldc,
    int first_k)
{
    __m512 c0_lo, c0_hi, c1_lo, c1_hi, c2_lo, c2_hi;
    __m512 c3_lo, c3_hi, c4_lo, c4_hi, c5_lo, c5_hi;

    if (first_k) {
        c0_lo = _mm512_setzero_ps(); c0_hi = _mm512_setzero_ps();
        c1_lo = _mm512_setzero_ps(); c1_hi = _mm512_setzero_ps();
        c2_lo = _mm512_setzero_ps(); c2_hi = _mm512_setzero_ps();
        c3_lo = _mm512_setzero_ps(); c3_hi = _mm512_setzero_ps();
        c4_lo = _mm512_setzero_ps(); c4_hi = _mm512_setzero_ps();
        c5_lo = _mm512_setzero_ps(); c5_hi = _mm512_setzero_ps();
    } else {
        c0_lo = _mm512_loadu_ps(&C[0 * ldc]); c0_hi = _mm512_loadu_ps(&C[0 * ldc + 16]);
        c1_lo = _mm512_loadu_ps(&C[1 * ldc]); c1_hi = _mm512_loadu_ps(&C[1 * ldc + 16]);
        c2_lo = _mm512_loadu_ps(&C[2 * ldc]); c2_hi = _mm512_loadu_ps(&C[2 * ldc + 16]);
        c3_lo = _mm512_loadu_ps(&C[3 * ldc]); c3_hi = _mm512_loadu_ps(&C[3 * ldc + 16]);
        c4_lo = _mm512_loadu_ps(&C[4 * ldc]); c4_hi = _mm512_loadu_ps(&C[4 * ldc + 16]);
        c5_lo = _mm512_loadu_ps(&C[5 * ldc]); c5_hi = _mm512_loadu_ps(&C[5 * ldc + 16]);
    }
    _mm_prefetch((const char*)&B[0], _MM_HINT_T0);
    _mm_prefetch((const char*)&B[64], _MM_HINT_T0);

    int k = 0;
    /* Main loop: 8-way unrolled, with software-pipelined B loads. */
    for (; k <= K - 8; k += 8) {
        _mm_prefetch((const char*)&B[(k + 16) * ldb], _MM_HINT_T0);
        _mm_prefetch((const char*)&B[(k + 16) * ldb + 64], _MM_HINT_T0);
        _mm_prefetch((const char*)&B[(k + 32) * ldb], _MM_HINT_T1);

        __m512 b_lo_next = _mm512_loadu_ps(&B[k * ldb]);
        __m512 b_hi_next = _mm512_loadu_ps(&B[k * ldb + 16]);

#define AVX512_ITER(koff) { \
    __m512 b_lo = b_lo_next; \
    __m512 b_hi = b_hi_next; \
    b_lo_next = _mm512_loadu_ps(&B[(k + (koff) + 1) * ldb]); \
    b_hi_next = _mm512_loadu_ps(&B[(k + (koff) + 1) * ldb + 16]); \
    __m512 a0 = _mm512_set1_ps(A[0 * lda + k + (koff)]); \
    __m512 a1 = _mm512_set1_ps(A[1 * lda + k + (koff)]); \
    __m512 a2 = _mm512_set1_ps(A[2 * lda + k + (koff)]); \
    __m512 a3 = _mm512_set1_ps(A[3 * lda + k + (koff)]); \
    __m512 a4 = _mm512_set1_ps(A[4 * lda + k + (koff)]); \
    __m512 a5 = _mm512_set1_ps(A[5 * lda + k + (koff)]); \
    c0_lo = _mm512_fmadd_ps(a0, b_lo, c0_lo); c0_hi = _mm512_fmadd_ps(a0, b_hi, c0_hi); \
    c1_lo = _mm512_fmadd_ps(a1, b_lo, c1_lo); c1_hi = _mm512_fmadd_ps(a1, b_hi, c1_hi); \
    c2_lo = _mm512_fmadd_ps(a2, b_lo, c2_lo); c2_hi = _mm512_fmadd_ps(a2, b_hi, c2_hi); \
    c3_lo = _mm512_fmadd_ps(a3, b_lo, c3_lo); c3_hi = _mm512_fmadd_ps(a3, b_hi, c3_hi); \
    c4_lo = _mm512_fmadd_ps(a4, b_lo, c4_lo); c4_hi = _mm512_fmadd_ps(a4, b_hi, c4_hi); \
    c5_lo = _mm512_fmadd_ps(a5, b_lo, c5_lo); c5_hi = _mm512_fmadd_ps(a5, b_hi, c5_hi); \
}

        AVX512_ITER(0); AVX512_ITER(1); AVX512_ITER(2); AVX512_ITER(3);
        AVX512_ITER(4); AVX512_ITER(5); AVX512_ITER(6); AVX512_ITER(7);
#undef AVX512_ITER
    }
    /* 4-way unrolled tail. */
    for (; k <= K - 4; k += 4) {
#define AVX512_ITER4(koff) { \
    __m512 b_lo = _mm512_loadu_ps(&B[(k + koff) * ldb]); \
    __m512 b_hi = _mm512_loadu_ps(&B[(k + koff) * ldb + 16]); \
    __m512 a0 = _mm512_set1_ps(A[0 * lda + k + koff]); \
    __m512 a1 = _mm512_set1_ps(A[1 * lda + k + koff]); \
    __m512 a2 = _mm512_set1_ps(A[2 * lda + k + koff]); \
    __m512 a3 = _mm512_set1_ps(A[3 * lda + k + koff]); \
    __m512 a4 = _mm512_set1_ps(A[4 * lda + k + koff]); \
    __m512 a5 = _mm512_set1_ps(A[5 * lda + k + koff]); \
    c0_lo = _mm512_fmadd_ps(a0, b_lo, c0_lo); c0_hi = _mm512_fmadd_ps(a0, b_hi, c0_hi); \
    c1_lo = _mm512_fmadd_ps(a1, b_lo, c1_lo); c1_hi = _mm512_fmadd_ps(a1, b_hi, c1_hi); \
    c2_lo = _mm512_fmadd_ps(a2, b_lo, c2_lo); c2_hi = _mm512_fmadd_ps(a2, b_hi, c2_hi); \
    c3_lo = _mm512_fmadd_ps(a3, b_lo, c3_lo); c3_hi = _mm512_fmadd_ps(a3, b_hi, c3_hi); \
    c4_lo = _mm512_fmadd_ps(a4, b_lo, c4_lo); c4_hi = _mm512_fmadd_ps(a4, b_hi, c4_hi); \
    c5_lo = _mm512_fmadd_ps(a5, b_lo, c5_lo); c5_hi = _mm512_fmadd_ps(a5, b_hi, c5_hi); \
}

        AVX512_ITER4(0); AVX512_ITER4(1); AVX512_ITER4(2); AVX512_ITER4(3);
#undef AVX512_ITER4
    }

    /* Scalar-k remainder. */
    for (; k < K; k++) {
        __m512 b_lo = _mm512_loadu_ps(&B[k * ldb]);
        __m512 b_hi = _mm512_loadu_ps(&B[k * ldb + 16]);

        c0_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[0 * lda + k]), b_lo, c0_lo);
        c0_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[0 * lda + k]), b_hi, c0_hi);
        c1_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[1 * lda + k]), b_lo, c1_lo);
        c1_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[1 * lda + k]), b_hi, c1_hi);
        c2_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[2 * lda + k]), b_lo, c2_lo);
        c2_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[2 * lda + k]), b_hi, c2_hi);
        c3_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[3 * lda + k]), b_lo, c3_lo);
        c3_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[3 * lda + k]), b_hi, c3_hi);
        c4_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[4 * lda + k]), b_lo, c4_lo);
        c4_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[4 * lda + k]), b_hi, c4_hi);
        c5_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[5 * lda + k]), b_lo, c5_lo);
        c5_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[5 * lda + k]), b_hi, c5_hi);
    }
    _mm512_storeu_ps(&C[0 * ldc], c0_lo); _mm512_storeu_ps(&C[0 * ldc + 16], c0_hi);
    _mm512_storeu_ps(&C[1 * ldc], c1_lo); _mm512_storeu_ps(&C[1 * ldc + 16], c1_hi);
    _mm512_storeu_ps(&C[2 * ldc], c2_lo); _mm512_storeu_ps(&C[2 * ldc + 16], c2_hi);
    _mm512_storeu_ps(&C[3 * ldc], c3_lo); _mm512_storeu_ps(&C[3 * ldc + 16], c3_hi);
    _mm512_storeu_ps(&C[4 * ldc], c4_lo); _mm512_storeu_ps(&C[4 * ldc + 16], c4_hi);
    _mm512_storeu_ps(&C[5 * ldc], c5_lo); _mm512_storeu_ps(&C[5 * ldc + 16], c5_hi);
}
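
/*
 * Same 6x32 tile, but reading packed panels: Ap is an MR x K row-major
 * panel and Bp a K x NR panel, so B rows are contiguous and aligned
 * (_mm512_load_ps) and both streams walk memory with unit stride.
 */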
static inline void gemm_microkernel_6x32_packed_avx512(
    int K,
    const float * __restrict__ Ap,
    const float * __restrict__ Bp,
    float * __restrict__ C, int ldc,
    int first_k)
{
    __m512 c0_lo, c0_hi, c1_lo, c1_hi, c2_lo, c2_hi;
    __m512 c3_lo, c3_hi, c4_lo, c4_hi, c5_lo, c5_hi;

    if (first_k) {
        c0_lo = _mm512_setzero_ps(); c0_hi = _mm512_setzero_ps();
        c1_lo = _mm512_setzero_ps(); c1_hi = _mm512_setzero_ps();
        c2_lo = _mm512_setzero_ps(); c2_hi = _mm512_setzero_ps();
        c3_lo = _mm512_setzero_ps(); c3_hi = _mm512_setzero_ps();
        c4_lo = _mm512_setzero_ps(); c4_hi = _mm512_setzero_ps();
        c5_lo = _mm512_setzero_ps(); c5_hi = _mm512_setzero_ps();
    } else {
        c0_lo = _mm512_loadu_ps(&C[0 * ldc]); c0_hi = _mm512_loadu_ps(&C[0 * ldc + 16]);
        c1_lo = _mm512_loadu_ps(&C[1 * ldc]); c1_hi = _mm512_loadu_ps(&C[1 * ldc + 16]);
        c2_lo = _mm512_loadu_ps(&C[2 * ldc]); c2_hi = _mm512_loadu_ps(&C[2 * ldc + 16]);
        c3_lo = _mm512_loadu_ps(&C[3 * ldc]); c3_hi = _mm512_loadu_ps(&C[3 * ldc + 16]);
        c4_lo = _mm512_loadu_ps(&C[4 * ldc]); c4_hi = _mm512_loadu_ps(&C[4 * ldc + 16]);
        c5_lo = _mm512_loadu_ps(&C[5 * ldc]); c5_hi = _mm512_loadu_ps(&C[5 * ldc + 16]);
    }
    _mm_prefetch((const char*)Bp, _MM_HINT_T0);
    _mm_prefetch((const char*)(Bp + 16), _MM_HINT_T0);

    int k = 0;
    for (; k <= K - 4; k += 4) {
        _mm_prefetch((const char*)(Bp + (k + 8) * NR), _MM_HINT_T0);
        _mm_prefetch((const char*)(Bp + (k + 8) * NR + 16), _MM_HINT_T0);

#define PACKED_ITER(koff) { \
    __m512 b_lo = _mm512_load_ps(&Bp[(k + koff) * NR]); \
    __m512 b_hi = _mm512_load_ps(&Bp[(k + koff) * NR + 16]); \
    __m512 a0 = _mm512_set1_ps(Ap[0 * K + k + koff]); \
    __m512 a1 = _mm512_set1_ps(Ap[1 * K + k + koff]); \
    __m512 a2 = _mm512_set1_ps(Ap[2 * K + k + koff]); \
    __m512 a3 = _mm512_set1_ps(Ap[3 * K + k + koff]); \
    __m512 a4 = _mm512_set1_ps(Ap[4 * K + k + koff]); \
    __m512 a5 = _mm512_set1_ps(Ap[5 * K + k + koff]); \
    c0_lo = _mm512_fmadd_ps(a0, b_lo, c0_lo); c0_hi = _mm512_fmadd_ps(a0, b_hi, c0_hi); \
    c1_lo = _mm512_fmadd_ps(a1, b_lo, c1_lo); c1_hi = _mm512_fmadd_ps(a1, b_hi, c1_hi); \
    c2_lo = _mm512_fmadd_ps(a2, b_lo, c2_lo); c2_hi = _mm512_fmadd_ps(a2, b_hi, c2_hi); \
    c3_lo = _mm512_fmadd_ps(a3, b_lo, c3_lo); c3_hi = _mm512_fmadd_ps(a3, b_hi, c3_hi); \
    c4_lo = _mm512_fmadd_ps(a4, b_lo, c4_lo); c4_hi = _mm512_fmadd_ps(a4, b_hi, c4_hi); \
    c5_lo = _mm512_fmadd_ps(a5, b_lo, c5_lo); c5_hi = _mm512_fmadd_ps(a5, b_hi, c5_hi); \
}

        PACKED_ITER(0); PACKED_ITER(1); PACKED_ITER(2); PACKED_ITER(3);
#undef PACKED_ITER
    }

    for (; k < K; k++) {
        __m512 b_lo = _mm512_load_ps(&Bp[k * NR]);
        __m512 b_hi = _mm512_load_ps(&Bp[k * NR + 16]);

        c0_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[0 * K + k]), b_lo, c0_lo);
        c0_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[0 * K + k]), b_hi, c0_hi);
        c1_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[1 * K + k]), b_lo, c1_lo);
        c1_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[1 * K + k]), b_hi, c1_hi);
        c2_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[2 * K + k]), b_lo, c2_lo);
        c2_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[2 * K + k]), b_hi, c2_hi);
        c3_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[3 * K + k]), b_lo, c3_lo);
        c3_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[3 * K + k]), b_hi, c3_hi);
        c4_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[4 * K + k]), b_lo, c4_lo);
        c4_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[4 * K + k]), b_hi, c4_hi);
        c5_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[5 * K + k]), b_lo, c5_lo);
        c5_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[5 * K + k]), b_hi, c5_hi);
    }
    _mm512_storeu_ps(&C[0 * ldc], c0_lo); _mm512_storeu_ps(&C[0 * ldc + 16], c0_hi);
    _mm512_storeu_ps(&C[1 * ldc], c1_lo); _mm512_storeu_ps(&C[1 * ldc + 16], c1_hi);
    _mm512_storeu_ps(&C[2 * ldc], c2_lo); _mm512_storeu_ps(&C[2 * ldc + 16], c2_hi);
    _mm512_storeu_ps(&C[3 * ldc], c3_lo); _mm512_storeu_ps(&C[3 * ldc + 16], c3_hi);
    _mm512_storeu_ps(&C[4 * ldc], c4_lo); _mm512_storeu_ps(&C[4 * ldc + 16], c4_hi);
    _mm512_storeu_ps(&C[5 * ldc], c5_lo); _mm512_storeu_ps(&C[5 * ldc + 16], c5_hi);
}
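
/*
 * 256-bit microkernel: 6x16 tile of C in 12 ymm registers. The 8-way
 * unrolled loop uses FMA with a software-pipelined B load (the next row is
 * fetched while the current one is consumed); the mul+add path further
 * down serves builds without FMA.
 */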
static inline void gemm_microkernel_6x16_avx(
    int K,
    const float * __restrict__ A, int lda,
    const float * __restrict__ B, int ldb,
    float * __restrict__ C, int ldc,
    int first_k)
{
    __m256 c0_lo, c0_hi, c1_lo, c1_hi, c2_lo, c2_hi;
    __m256 c3_lo, c3_hi, c4_lo, c4_hi, c5_lo, c5_hi;

    if (first_k) {
        c0_lo = _mm256_setzero_ps(); c0_hi = _mm256_setzero_ps();
        c1_lo = _mm256_setzero_ps(); c1_hi = _mm256_setzero_ps();
        c2_lo = _mm256_setzero_ps(); c2_hi = _mm256_setzero_ps();
        c3_lo = _mm256_setzero_ps(); c3_hi = _mm256_setzero_ps();
        c4_lo = _mm256_setzero_ps(); c4_hi = _mm256_setzero_ps();
        c5_lo = _mm256_setzero_ps(); c5_hi = _mm256_setzero_ps();
    } else {
        c0_lo = _mm256_loadu_ps(&C[0 * ldc]); c0_hi = _mm256_loadu_ps(&C[0 * ldc + 8]);
        c1_lo = _mm256_loadu_ps(&C[1 * ldc]); c1_hi = _mm256_loadu_ps(&C[1 * ldc + 8]);
        c2_lo = _mm256_loadu_ps(&C[2 * ldc]); c2_hi = _mm256_loadu_ps(&C[2 * ldc + 8]);
        c3_lo = _mm256_loadu_ps(&C[3 * ldc]); c3_hi = _mm256_loadu_ps(&C[3 * ldc + 8]);
        c4_lo = _mm256_loadu_ps(&C[4 * ldc]); c4_hi = _mm256_loadu_ps(&C[4 * ldc + 8]);
        c5_lo = _mm256_loadu_ps(&C[5 * ldc]); c5_hi = _mm256_loadu_ps(&C[5 * ldc + 8]);
    }
    _mm_prefetch((const char*)B, _MM_HINT_T0);
    _mm_prefetch((const char*)(B + 32), _MM_HINT_T0);

    int k = 0;

#if defined(__FMA__)
    for (; k <= K - 8; k += 8) {
        _mm_prefetch((const char*)&B[(k + 16) * ldb], _MM_HINT_T0);
        _mm_prefetch((const char*)&B[(k + 16) * ldb + 32], _MM_HINT_T0);
        _mm_prefetch((const char*)&B[(k + 17) * ldb], _MM_HINT_T0);

        __m256 b_lo_next = _mm256_loadu_ps(&B[k * ldb]);
        __m256 b_hi_next = _mm256_loadu_ps(&B[k * ldb + 8]);

#define FMA_ITER(koff) { \
    __m256 b_lo = b_lo_next; \
    __m256 b_hi = b_hi_next; \
    b_lo_next = _mm256_loadu_ps(&B[(k + (koff) + 1) * ldb]); \
    b_hi_next = _mm256_loadu_ps(&B[(k + (koff) + 1) * ldb + 8]); \
    __m256 a0 = _mm256_set1_ps(A[0 * lda + k + (koff)]); \
    __m256 a1 = _mm256_set1_ps(A[1 * lda + k + (koff)]); \
    __m256 a2 = _mm256_set1_ps(A[2 * lda + k + (koff)]); \
    __m256 a3 = _mm256_set1_ps(A[3 * lda + k + (koff)]); \
    __m256 a4 = _mm256_set1_ps(A[4 * lda + k + (koff)]); \
    __m256 a5 = _mm256_set1_ps(A[5 * lda + k + (koff)]); \
    c0_lo = _mm256_fmadd_ps(a0, b_lo, c0_lo); c0_hi = _mm256_fmadd_ps(a0, b_hi, c0_hi); \
    c1_lo = _mm256_fmadd_ps(a1, b_lo, c1_lo); c1_hi = _mm256_fmadd_ps(a1, b_hi, c1_hi); \
    c2_lo = _mm256_fmadd_ps(a2, b_lo, c2_lo); c2_hi = _mm256_fmadd_ps(a2, b_hi, c2_hi); \
    c3_lo = _mm256_fmadd_ps(a3, b_lo, c3_lo); c3_hi = _mm256_fmadd_ps(a3, b_hi, c3_hi); \
    c4_lo = _mm256_fmadd_ps(a4, b_lo, c4_lo); c4_hi = _mm256_fmadd_ps(a4, b_hi, c4_hi); \
    c5_lo = _mm256_fmadd_ps(a5, b_lo, c5_lo); c5_hi = _mm256_fmadd_ps(a5, b_hi, c5_hi); \
}

        FMA_ITER(0); FMA_ITER(1); FMA_ITER(2); FMA_ITER(3);
        FMA_ITER(4); FMA_ITER(5); FMA_ITER(6); FMA_ITER(7);
#undef FMA_ITER
    }

    for (; k < K; k++) {
        __m256 b_lo = _mm256_loadu_ps(&B[k * ldb]);
        __m256 b_hi = _mm256_loadu_ps(&B[k * ldb + 8]);

        __m256 a0 = _mm256_set1_ps(A[0 * lda + k]);
        __m256 a1 = _mm256_set1_ps(A[1 * lda + k]);
        __m256 a2 = _mm256_set1_ps(A[2 * lda + k]);
        __m256 a3 = _mm256_set1_ps(A[3 * lda + k]);
        __m256 a4 = _mm256_set1_ps(A[4 * lda + k]);
        __m256 a5 = _mm256_set1_ps(A[5 * lda + k]);

        c0_lo = _mm256_fmadd_ps(a0, b_lo, c0_lo); c0_hi = _mm256_fmadd_ps(a0, b_hi, c0_hi);
        c1_lo = _mm256_fmadd_ps(a1, b_lo, c1_lo); c1_hi = _mm256_fmadd_ps(a1, b_hi, c1_hi);
        c2_lo = _mm256_fmadd_ps(a2, b_lo, c2_lo); c2_hi = _mm256_fmadd_ps(a2, b_hi, c2_hi);
        c3_lo = _mm256_fmadd_ps(a3, b_lo, c3_lo); c3_hi = _mm256_fmadd_ps(a3, b_hi, c3_hi);
        c4_lo = _mm256_fmadd_ps(a4, b_lo, c4_lo); c4_hi = _mm256_fmadd_ps(a4, b_hi, c4_hi);
        c5_lo = _mm256_fmadd_ps(a5, b_lo, c5_lo); c5_hi = _mm256_fmadd_ps(a5, b_hi, c5_hi);
    }
#else
    /* No FMA: separate multiply and add. */
    for (; k <= K - 4; k += 4) {
        _mm_prefetch((const char*)&B[(k + 8) * ldb], _MM_HINT_T0);

#define AVX_ITER(koff) { \
    __m256 b_lo = _mm256_loadu_ps(&B[(k + koff) * ldb]); \
    __m256 b_hi = _mm256_loadu_ps(&B[(k + koff) * ldb + 8]); \
    __m256 a0 = _mm256_set1_ps(A[0 * lda + k + koff]); \
    __m256 a1 = _mm256_set1_ps(A[1 * lda + k + koff]); \
    __m256 a2 = _mm256_set1_ps(A[2 * lda + k + koff]); \
    __m256 a3 = _mm256_set1_ps(A[3 * lda + k + koff]); \
    __m256 a4 = _mm256_set1_ps(A[4 * lda + k + koff]); \
    __m256 a5 = _mm256_set1_ps(A[5 * lda + k + koff]); \
    c0_lo = _mm256_add_ps(c0_lo, _mm256_mul_ps(a0, b_lo)); \
    c0_hi = _mm256_add_ps(c0_hi, _mm256_mul_ps(a0, b_hi)); \
    c1_lo = _mm256_add_ps(c1_lo, _mm256_mul_ps(a1, b_lo)); \
    c1_hi = _mm256_add_ps(c1_hi, _mm256_mul_ps(a1, b_hi)); \
    c2_lo = _mm256_add_ps(c2_lo, _mm256_mul_ps(a2, b_lo)); \
    c2_hi = _mm256_add_ps(c2_hi, _mm256_mul_ps(a2, b_hi)); \
    c3_lo = _mm256_add_ps(c3_lo, _mm256_mul_ps(a3, b_lo)); \
    c3_hi = _mm256_add_ps(c3_hi, _mm256_mul_ps(a3, b_hi)); \
    c4_lo = _mm256_add_ps(c4_lo, _mm256_mul_ps(a4, b_lo)); \
    c4_hi = _mm256_add_ps(c4_hi, _mm256_mul_ps(a4, b_hi)); \
    c5_lo = _mm256_add_ps(c5_lo, _mm256_mul_ps(a5, b_lo)); \
    c5_hi = _mm256_add_ps(c5_hi, _mm256_mul_ps(a5, b_hi)); \
}

        AVX_ITER(0); AVX_ITER(1); AVX_ITER(2); AVX_ITER(3);
#undef AVX_ITER
    }

    for (; k < K; k++) {
        __m256 b_lo = _mm256_loadu_ps(&B[k * ldb]);
        __m256 b_hi = _mm256_loadu_ps(&B[k * ldb + 8]);

        __m256 a0 = _mm256_set1_ps(A[0 * lda + k]);
        __m256 a1 = _mm256_set1_ps(A[1 * lda + k]);
        __m256 a2 = _mm256_set1_ps(A[2 * lda + k]);
        __m256 a3 = _mm256_set1_ps(A[3 * lda + k]);
        __m256 a4 = _mm256_set1_ps(A[4 * lda + k]);
        __m256 a5 = _mm256_set1_ps(A[5 * lda + k]);

        c0_lo = _mm256_add_ps(c0_lo, _mm256_mul_ps(a0, b_lo));
        c0_hi = _mm256_add_ps(c0_hi, _mm256_mul_ps(a0, b_hi));
        c1_lo = _mm256_add_ps(c1_lo, _mm256_mul_ps(a1, b_lo));
        c1_hi = _mm256_add_ps(c1_hi, _mm256_mul_ps(a1, b_hi));
        c2_lo = _mm256_add_ps(c2_lo, _mm256_mul_ps(a2, b_lo));
        c2_hi = _mm256_add_ps(c2_hi, _mm256_mul_ps(a2, b_hi));
        c3_lo = _mm256_add_ps(c3_lo, _mm256_mul_ps(a3, b_lo));
        c3_hi = _mm256_add_ps(c3_hi, _mm256_mul_ps(a3, b_hi));
        c4_lo = _mm256_add_ps(c4_lo, _mm256_mul_ps(a4, b_lo));
        c4_hi = _mm256_add_ps(c4_hi, _mm256_mul_ps(a4, b_hi));
        c5_lo = _mm256_add_ps(c5_lo, _mm256_mul_ps(a5, b_lo));
        c5_hi = _mm256_add_ps(c5_hi, _mm256_mul_ps(a5, b_hi));
    }
#endif
    _mm256_storeu_ps(&C[0 * ldc], c0_lo); _mm256_storeu_ps(&C[0 * ldc + 8], c0_hi);
    _mm256_storeu_ps(&C[1 * ldc], c1_lo); _mm256_storeu_ps(&C[1 * ldc + 8], c1_hi);
    _mm256_storeu_ps(&C[2 * ldc], c2_lo); _mm256_storeu_ps(&C[2 * ldc + 8], c2_hi);
    _mm256_storeu_ps(&C[3 * ldc], c3_lo); _mm256_storeu_ps(&C[3 * ldc + 8], c3_hi);
    _mm256_storeu_ps(&C[4 * ldc], c4_lo); _mm256_storeu_ps(&C[4 * ldc + 8], c4_hi);
    _mm256_storeu_ps(&C[5 * ldc], c5_lo); _mm256_storeu_ps(&C[5 * ldc + 8], c5_hi);
}
#if defined(__AVX__) && !defined(__FMA__)
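
/*
 * Plain-AVX microkernel (no FMA): a smaller 4x16 tile keeps register
 * pressure down, since every update needs a separate multiply and add.
 */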
static inline void gemm_microkernel_4x16_avx(
    int K,
    const float * __restrict__ A, int lda,
    const float * __restrict__ B, int ldb,
    float * __restrict__ C, int ldc,
    int first_k)
{
    __m256 c0_lo, c0_hi, c1_lo, c1_hi, c2_lo, c2_hi, c3_lo, c3_hi;

    if (first_k) {
        c0_lo = _mm256_setzero_ps(); c0_hi = _mm256_setzero_ps();
        c1_lo = _mm256_setzero_ps(); c1_hi = _mm256_setzero_ps();
        c2_lo = _mm256_setzero_ps(); c2_hi = _mm256_setzero_ps();
        c3_lo = _mm256_setzero_ps(); c3_hi = _mm256_setzero_ps();
    } else {
        c0_lo = _mm256_loadu_ps(&C[0 * ldc]); c0_hi = _mm256_loadu_ps(&C[0 * ldc + 8]);
        c1_lo = _mm256_loadu_ps(&C[1 * ldc]); c1_hi = _mm256_loadu_ps(&C[1 * ldc + 8]);
        c2_lo = _mm256_loadu_ps(&C[2 * ldc]); c2_hi = _mm256_loadu_ps(&C[2 * ldc + 8]);
        c3_lo = _mm256_loadu_ps(&C[3 * ldc]); c3_hi = _mm256_loadu_ps(&C[3 * ldc + 8]);
    }
    _mm_prefetch((const char*)B, _MM_HINT_T0);

    int k = 0;
    for (; k <= K - 4; k += 4) {
        _mm_prefetch((const char*)&B[(k + 8) * ldb], _MM_HINT_T0);

#define AVX4_ITER(koff) { \
    __m256 b_lo = _mm256_loadu_ps(&B[(k + koff) * ldb]); \
    __m256 b_hi = _mm256_loadu_ps(&B[(k + koff) * ldb + 8]); \
    __m256 a0 = _mm256_set1_ps(A[0 * lda + k + koff]); \
    __m256 a1 = _mm256_set1_ps(A[1 * lda + k + koff]); \
    __m256 a2 = _mm256_set1_ps(A[2 * lda + k + koff]); \
    __m256 a3 = _mm256_set1_ps(A[3 * lda + k + koff]); \
    c0_lo = _mm256_add_ps(c0_lo, _mm256_mul_ps(a0, b_lo)); \
    c0_hi = _mm256_add_ps(c0_hi, _mm256_mul_ps(a0, b_hi)); \
    c1_lo = _mm256_add_ps(c1_lo, _mm256_mul_ps(a1, b_lo)); \
    c1_hi = _mm256_add_ps(c1_hi, _mm256_mul_ps(a1, b_hi)); \
    c2_lo = _mm256_add_ps(c2_lo, _mm256_mul_ps(a2, b_lo)); \
    c2_hi = _mm256_add_ps(c2_hi, _mm256_mul_ps(a2, b_hi)); \
    c3_lo = _mm256_add_ps(c3_lo, _mm256_mul_ps(a3, b_lo)); \
    c3_hi = _mm256_add_ps(c3_hi, _mm256_mul_ps(a3, b_hi)); \
}

        AVX4_ITER(0); AVX4_ITER(1); AVX4_ITER(2); AVX4_ITER(3);
#undef AVX4_ITER
    }

    for (; k < K; k++) {
        __m256 b_lo = _mm256_loadu_ps(&B[k * ldb]);
        __m256 b_hi = _mm256_loadu_ps(&B[k * ldb + 8]);

        __m256 a0 = _mm256_set1_ps(A[0 * lda + k]);
        __m256 a1 = _mm256_set1_ps(A[1 * lda + k]);
        __m256 a2 = _mm256_set1_ps(A[2 * lda + k]);
        __m256 a3 = _mm256_set1_ps(A[3 * lda + k]);

        c0_lo = _mm256_add_ps(c0_lo, _mm256_mul_ps(a0, b_lo));
        c0_hi = _mm256_add_ps(c0_hi, _mm256_mul_ps(a0, b_hi));
        c1_lo = _mm256_add_ps(c1_lo, _mm256_mul_ps(a1, b_lo));
        c1_hi = _mm256_add_ps(c1_hi, _mm256_mul_ps(a1, b_hi));
        c2_lo = _mm256_add_ps(c2_lo, _mm256_mul_ps(a2, b_lo));
        c2_hi = _mm256_add_ps(c2_hi, _mm256_mul_ps(a2, b_hi));
        c3_lo = _mm256_add_ps(c3_lo, _mm256_mul_ps(a3, b_lo));
        c3_hi = _mm256_add_ps(c3_hi, _mm256_mul_ps(a3, b_hi));
    }
    _mm256_storeu_ps(&C[0 * ldc], c0_lo); _mm256_storeu_ps(&C[0 * ldc + 8], c0_hi);
    _mm256_storeu_ps(&C[1 * ldc], c1_lo); _mm256_storeu_ps(&C[1 * ldc + 8], c1_hi);
    _mm256_storeu_ps(&C[2 * ldc], c2_lo); _mm256_storeu_ps(&C[2 * ldc + 8], c2_hi);
    _mm256_storeu_ps(&C[3 * ldc], c3_lo); _mm256_storeu_ps(&C[3 * ldc + 8], c3_hi);
}
#endif
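
/* Scalar fallback for partial tiles at the right and bottom edges of C. */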
static void gemm_microkernel_edge(int m, int n, int K,
                                  const float *A, int lda,
                                  const float *B, int ldb,
                                  float *C, int ldc, int first_k)
{
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            float sum = first_k ? 0.0f : C[i * ldc + j];
            for (int k = 0; k < K; k++) {
                sum += A[i * lda + k] * B[k * ldb + j];
            }
            C[i * ldc + j] = sum;
        }
    }
}
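
/*
 * Pack an mc x kc block of A into contiguous MR-row panels, zero-padding
 * the final partial panel, so the microkernel reads A with unit stride.
 */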
static void pack_a_panel(const float *A, int lda, float *Ap,
                         int mc, int kc, int mr)
{
    #pragma omp parallel for schedule(static) if(mc > 64)
    for (int i = 0; i < mc; i += mr) {
        int rows = (i + mr <= mc) ? mr : (mc - i);
        float *Ap_panel = &Ap[(i / mr) * mr * kc];

        for (int p = 0; p < rows; p++) {
            const float *A_row = &A[(i + p) * lda];
            float *Ap_row = &Ap_panel[p * kc];
            int k = 0;
#if defined(__AVX__)
            for (; k <= kc - 8; k += 8) {
                _mm256_storeu_ps(&Ap_row[k], _mm256_loadu_ps(&A_row[k]));
            }
#endif
            for (; k < kc; k++) {
                Ap_row[k] = A_row[k];
            }
        }

        for (int p = rows; p < mr; p++) {
            memset(&Ap_panel[p * kc], 0, kc * sizeof(float));
        }
    }
}
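
/*
 * Pack a kc x nc block of B into K-major panels of NR columns, aligned for
 * the packed kernel's _mm512_load_ps, zero-padding the last partial panel.
 */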
static void pack_b_panel(const float *B, int ldb, float *Bp,
                         int kc, int nc, int nr)
{
    #pragma omp parallel for schedule(static) if(nc > 128)
    for (int j = 0; j < nc; j += nr) {
        int cols = (j + nr <= nc) ? nr : (nc - j);
        float *Bp_panel = &Bp[(j / nr) * nr * kc];

        for (int k = 0; k < kc; k++) {
            const float *B_row = &B[k * ldb + j];
            float *Bp_row = &Bp_panel[k * nr];
            int c = 0;
#if defined(__AVX512F__)
            for (; c <= cols - 16; c += 16) {
                _mm512_store_ps(&Bp_row[c], _mm512_loadu_ps(&B_row[c]));
            }
#elif defined(__AVX__)
            for (; c <= cols - 8; c += 8) {
                _mm256_store_ps(&Bp_row[c], _mm256_loadu_ps(&B_row[c]));
            }
#endif
            for (; c < cols; c++) {
                Bp_row[c] = B_row[c];
            }
            for (; c < nr; c++) {
                Bp_row[c] = 0.0f;
            }
        }
    }
}
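
/*
 * Single-threaded blocked GEMM: K is blocked by KC (first_k tells the
 * microkernel whether to initialize or accumulate), and each full MR x NR
 * tile goes to the widest microkernel compiled in.
 */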
static void gemm_microkernel_sequential(const float *A, const float *B, float *C,
                                        int M, int N, int K)
{
    int mr = MR, nr = NR;

    for (int i = 0; i < M; i++) {
        memset(&C[i * N], 0, N * sizeof(float));
    }

    for (int k0 = 0; k0 < K; k0 += KC) {
        int kb = (k0 + KC <= K) ? KC : (K - k0);
        int first_k = (k0 == 0);

        for (int m0 = 0; m0 < M; m0 += mr) {
            int mr_actual = (m0 + mr <= M) ? mr : (M - m0);

            for (int n0 = 0; n0 < N; n0 += nr) {
                int nr_actual = (n0 + nr <= N) ? nr : (N - n0);

                const float *A_tile = &A[m0 * K + k0];
                const float *B_tile = &B[k0 * N + n0];
                float *C_tile = &C[m0 * N + n0];

                if (mr_actual == mr && nr_actual == nr) {
#if defined(__AVX512F__)
                    gemm_microkernel_6x32_avx512(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
#elif defined(__FMA__)
                    gemm_microkernel_6x16_avx(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
#elif defined(__AVX__)
                    gemm_microkernel_4x16_avx(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
#endif
                } else {
                    gemm_microkernel_edge(mr_actual, nr_actual, kb, A_tile, K, B_tile, N, C_tile, N, first_k);
                }
            }
        }
    }
}
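
/*
 * gemm_init_threads: cap the OpenMP thread count at the physical core
 * count; SMT siblings share the FP units and typically hurt dense GEMM.
 */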
    if (physical_cores > 0) {
        int current_max = omp_get_max_threads();

        if (current_max > physical_cores) {
            omp_set_num_threads(physical_cores);
        }
    }
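
/*
 * gemm_microkernel: native threaded path. C is zeroed in parallel, then
 * each KC slice is processed with a parallel loop over MR-row bands; every
 * band writes only its own rows of C, so no synchronization is needed.
 */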
    /* Small problems: the fork/join overhead dominates, stay sequential. */
    if ((size_t)M * N * K <= 512ULL * 512 * 512) {
        gemm_microkernel_sequential(A, B, C, M, N, K);
        return;
    }
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < M; i++) {
        memset(&C[i * N], 0, N * sizeof(float));
    }

    for (int k0 = 0; k0 < K; k0 += KC) {
        int kb = (k0 + KC <= K) ? KC : (K - k0);
        int first_k = (k0 == 0);

        #pragma omp parallel for schedule(static)
        for (int m0 = 0; m0 < M; m0 += mr) {
            int mr_actual = (m0 + mr <= M) ? mr : (M - m0);

            for (int n0 = 0; n0 < N; n0 += nr) {
                int nr_actual = (n0 + nr <= N) ? nr : (N - n0);

                const float *A_tile = &A[m0 * K + k0];
                const float *B_tile = &B[k0 * N + n0];
                float *C_tile = &C[m0 * N + n0];

                if (mr_actual == mr && nr_actual == nr) {
#if defined(__AVX512F__)
                    gemm_microkernel_6x32_avx512(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
#elif defined(__FMA__)
                    gemm_microkernel_6x16_avx(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
#elif defined(__AVX__)
                    gemm_microkernel_4x16_avx(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
#endif
                } else {
                    gemm_microkernel_edge(mr_actual, nr_actual, kb, A_tile, K, B_tile, N, C_tile, N, first_k);
                }
            }
        }
    }
}
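
/*
 * B-transposed microkernel (B stored N x K): both A and B are read along
 * K, so the kernel loads 16-wide vectors of each and reduces horizontally
 * per C element with _mm512_reduce_add_ps. Simple, but it trades the FMA
 * throughput of the broadcast kernels for reductions.
 */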
#if defined(__AVX512F__)
static inline void gemm_microkernel_6x32_bt_avx512(
    int K,
    const float * __restrict__ A, int lda,
    const float * __restrict__ B, int ldb,
    float * __restrict__ C, int ldc,
    int first_k)
{
    if (first_k) {
        for (int i = 0; i < 6; i++) {
            for (int j = 0; j < 32; j++) {
                C[i * ldc + j] = 0.0f;
            }
        }
    }

    int k = 0;
    for (; k <= K - 16; k += 16) {
        __m512 a0 = _mm512_loadu_ps(&A[0 * lda + k]);
        __m512 a1 = _mm512_loadu_ps(&A[1 * lda + k]);
        __m512 a2 = _mm512_loadu_ps(&A[2 * lda + k]);
        __m512 a3 = _mm512_loadu_ps(&A[3 * lda + k]);
        __m512 a4 = _mm512_loadu_ps(&A[4 * lda + k]);
        __m512 a5 = _mm512_loadu_ps(&A[5 * lda + k]);

        for (int j = 0; j < 32; j++) {
            __m512 b = _mm512_loadu_ps(&B[j * ldb + k]);

            C[0 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a0, b));
            C[1 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a1, b));
            C[2 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a2, b));
            C[3 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a3, b));
            C[4 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a4, b));
            C[5 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a5, b));
        }
    }
    for (; k < K; k++) {
        for (int i = 0; i < 6; i++) {
            float a = A[i * lda + k];
            for (int j = 0; j < 32; j++) {
                C[i * ldc + j] += a * B[j * ldb + k];
            }
        }
    }
}
#endif
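
/*
 * Blocked driver for the B-transposed layout: C is tiled into MC x NC
 * blocks distributed dynamically across threads (collapse(2)), with scalar
 * loops for partial tiles and for builds without AVX-512.
 */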
void gemm_microkernel_blocked_bt(const float *A, const float *B, float *C,
                                 int M, int N, int K)
{
    int mr = MR, nr = NR;

    #pragma omp parallel for schedule(static)
    for (int i = 0; i < M; i++) {
        memset(&C[i * N], 0, N * sizeof(float));
    }

    #pragma omp parallel for schedule(dynamic) collapse(2)
    for (int m0 = 0; m0 < M; m0 += MC) {
        for (int n0 = 0; n0 < N; n0 += NC) {
            int mb = (m0 + MC <= M) ? MC : (M - m0);
            int nb = (n0 + NC <= N) ? NC : (N - n0);

            for (int k0 = 0; k0 < K; k0 += KC) {
                int kb = (k0 + KC <= K) ? KC : (K - k0);
                int first_k = (k0 == 0);

                for (int m1 = 0; m1 < mb; m1 += mr) {
                    int mr_actual = (m1 + mr <= mb) ? mr : (mb - m1);

                    for (int n1 = 0; n1 < nb; n1 += nr) {
                        int nr_actual = (n1 + nr <= nb) ? nr : (nb - n1);

                        const float *A_tile = &A[(m0 + m1) * K + k0];
                        const float *B_tile = &B[(n0 + n1) * K + k0];
                        float *C_tile = &C[(m0 + m1) * N + (n0 + n1)];

                        if (mr_actual == mr && nr_actual == nr) {
#if defined(__AVX512F__)
                            gemm_microkernel_6x32_bt_avx512(kb, A_tile, K, B_tile, K, C_tile, N, first_k);
#else
                            for (int i = 0; i < mr; i++) {
                                for (int j = 0; j < nr; j++) {
                                    float sum = first_k ? 0.0f : C_tile[i * N + j];
                                    for (int kk = 0; kk < kb; kk++) {
                                        sum += A_tile[i * K + kk] * B_tile[j * K + kk];
                                    }
                                    C_tile[i * N + j] = sum;
                                }
                            }
#endif
                        } else {
                            for (int i = 0; i < mr_actual; i++) {
                                for (int j = 0; j < nr_actual; j++) {
                                    float sum = first_k ? 0.0f : C_tile[i * N + j];
                                    for (int kk = 0; kk < kb; kk++) {
                                        sum += A_tile[i * K + kk] * B_tile[j * K + kk];
                                    }
                                    C_tile[i * N + j] = sum;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
#define PACK_THRESHOLD 256
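
/* Dimension threshold beyond which packing A and B into panels pays off. */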
    int M, int N, int K,