#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#include <immintrin.h>
#endif

#include <stddef.h> /* size_t */
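/* Clamp helper: upper bound for tile edges in the blocked kernels below. */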
static inline int ck_min(int a, int b) { return a < b ? a : b; }
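/*
 * Add the length-N bias row to every row of the M x N matrix C.
 * Rows are independent, so the outer loop parallelizes cleanly.
 */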
static void ck_gemm_add_bias(float *C, const float *bias, int M, int N) {
    if (!bias) return; /* no-op when no bias is supplied */
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < M; ++i) {
        float *c_row = C + (size_t)i * (size_t)N;
        for (int j = 0; j < N; ++j) {
            c_row[j] += bias[j];
        }
    }
}
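/* Horizontal sum of the 8 float lanes of a __m256; AVX has no single
 * reduce instruction, so fold 256 -> 128 bits, then within the 128 lane. */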
#if defined(__AVX__) && !defined(__AVX512F__)
static inline float hsum256_ps(__m256 v) {
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 hi = _mm256_extractf128_ps(v, 1);
    __m128 sum128 = _mm_add_ps(lo, hi);
    __m128 shuf = _mm_movehdup_ps(sum128);
    __m128 sums = _mm_add_ps(sum128, shuf);
    shuf = _mm_movehl_ps(shuf, sums);
    sums = _mm_add_ss(sums, shuf);
    return _mm_cvtss_f32(sums);
}
#endif
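/*
 * Matrix-vector special case (M == 1) with B stored transposed (N x K):
 * each C[j] is a dot product of A against a contiguous row of B, so the
 * N outputs parallelize independently and vectorize along K.
 */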
static void gemm_nt_matvec_parallel(const float *A, const float *B,
                                    const float *bias, float *C,
                                    int N, int K) {
    #pragma omp parallel for schedule(static)
    for (int j = 0; j < N; ++j) {
        const float *b_row = B + (size_t)j * (size_t)K;
        float sum = bias ? bias[j] : 0.0f;
#if defined(__AVX512F__)
        __m512 acc = _mm512_setzero_ps();
        int k = 0;
        for (; k <= K - 16; k += 16) {
            __m512 a_vec = _mm512_loadu_ps(A + k);
            __m512 b_vec = _mm512_loadu_ps(b_row + k);
            acc = _mm512_fmadd_ps(a_vec, b_vec, acc);
        }
        sum += _mm512_reduce_add_ps(acc);
        for (; k < K; ++k) {
            sum += A[k] * b_row[k];
        }
#elif defined(__AVX__)
        __m256 acc = _mm256_setzero_ps();
        int k = 0;
        for (; k <= K - 8; k += 8) {
            __m256 a_vec = _mm256_loadu_ps(A + k);
            __m256 b_vec = _mm256_loadu_ps(b_row + k);
            acc = _mm256_add_ps(acc, _mm256_mul_ps(a_vec, b_vec));
        }
        sum += hsum256_ps(acc);
        for (; k < K; ++k) {
            sum += A[k] * b_row[k];
        }
#else
        for (int k = 0; k < K; ++k) {
            sum += A[k] * b_row[k];
        }
#endif
        C[j] = sum;
    }
}
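/*
 * Serial reference kernel for the B-transposed (N x K) layout, accumulating
 * in double precision; likely the baseline used for strict-parity checks.
 */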
static void gemm_naive_serial_double(const float *A, const float *B,
                                     const float *bias, float *C,
                                     int M, int N, int K) {
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            double sum = bias ? (double)bias[j] : 0.0;
            for (int k = 0; k < K; k++) {
                sum += (double)A[i * K + k] * (double)B[j * K + k];
            }
            C[i * N + j] = (float)sum;
        }
    }
}
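/* Scalar parallel kernel, B transposed: one output row of C per iteration. */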
void gemm_naive_parallel(const float *A, const float *B, const float *bias,
                         float *C, int M, int N, int K) {
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            float bias_val = bias ? bias[j] : 0.0f;
            C[i * N + j] = sum + bias_val;
        }
    }
}
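/*
 * Vectorized dot-product kernel, B transposed: each C[i][j] reduces a
 * contiguous K-run of A and B, 16 lanes with AVX-512 FMAs or 8 with AVX.
 */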
void gemm_avx512_parallel(const float *A, const float *B, const float *bias,
                          float *C, int M, int N, int K) {
#if defined(__AVX512F__)
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            __m512 sum_vec = _mm512_setzero_ps();
            int k;
            for (k = 0; k <= K - 16; k += 16) {
                __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
                __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
                sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
            }
            float sum = _mm512_reduce_add_ps(sum_vec);
            for (; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            float bias_val = bias ? bias[j] : 0.0f;
            C[i * N + j] = sum + bias_val;
        }
    }
#elif defined(__AVX__)
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            __m256 sum_vec = _mm256_setzero_ps();
            int k;
            for (k = 0; k <= K - 8; k += 8) {
                __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
                __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
                __m256 prod = _mm256_mul_ps(a_vec, b_vec);
                sum_vec = _mm256_add_ps(sum_vec, prod);
            }
            float sum = hsum256_ps(sum_vec);
            for (; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            float bias_val = bias ? bias[j] : 0.0f;
            C[i * N + j] = sum + bias_val;
        }
    }
#else
    /* No AVX available: assumed fallback to the scalar parallel kernel. */
    gemm_naive_parallel(A, B, bias, C, M, N, K);
#endif
}
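/*
 * Cache-blocked variant of the dot-product kernel, B transposed: C is
 * seeded with the bias, then every (ii, jj, kk) tile accumulates its
 * partial dot products into C.
 */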
void gemm_fine_grained_parallel(const float *A, const float *B,
                                const float *bias, float *C,
                                int M, int N, int K) {
#if defined(__AVX512F__)
    const int block_size = 64;
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            C[i * N + j] = bias ? bias[j] : 0.0f;
        }
    }

    /* Only the ii/jj loops are collapsed; kk must stay sequential within
       each (ii, jj) block, or concurrent kk iterations would race on the
       C[i * N + j] accumulation. */
    #pragma omp parallel for collapse(2)
    for (int ii = 0; ii < M; ii += block_size) {
        for (int jj = 0; jj < N; jj += block_size) {
            for (int kk = 0; kk < K; kk += block_size) {
                int i_end = ck_min(ii + block_size, M);
                int j_end = ck_min(jj + block_size, N);
                int k_end = ck_min(kk + block_size, K);
                for (int i = ii; i < i_end; i++) {
                    for (int j = jj; j < j_end; j++) {
                        __m512 sum_vec = _mm512_setzero_ps();
                        int k;
                        for (k = kk; k <= k_end - 16; k += 16) {
                            __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
                            __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
                            sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
                        }
                        float partial_sum = _mm512_reduce_add_ps(sum_vec);
                        for (; k < k_end; k++) {
                            partial_sum += A[i * K + k] * B[j * K + k];
                        }
                        C[i * N + j] += partial_sum;
                    }
                }
            }
        }
    }
#elif defined(__AVX__)
    const int block_size = 32;
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            C[i * N + j] = bias ? bias[j] : 0.0f;
        }
    }

    /* See the AVX-512 branch: kk stays sequential to avoid racing on C. */
    #pragma omp parallel for collapse(2)
    for (int ii = 0; ii < M; ii += block_size) {
        for (int jj = 0; jj < N; jj += block_size) {
            for (int kk = 0; kk < K; kk += block_size) {
                int i_end = ck_min(ii + block_size, M);
                int j_end = ck_min(jj + block_size, N);
                int k_end = ck_min(kk + block_size, K);
                for (int i = ii; i < i_end; i++) {
                    for (int j = jj; j < j_end; j++) {
                        __m256 sum_vec = _mm256_setzero_ps();
                        int k;
                        for (k = kk; k <= k_end - 8; k += 8) {
                            __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
                            __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
                            __m256 prod = _mm256_mul_ps(a_vec, b_vec);
                            sum_vec = _mm256_add_ps(sum_vec, prod);
                        }
                        float partial_sum = hsum256_ps(sum_vec);
                        for (; k < k_end; k++) {
                            partial_sum += A[i * K + k] * B[j * K + k];
                        }
                        C[i * N + j] += partial_sum;
                    }
                }
            }
        }
    }
#else
    /* No AVX available: assumed fallback to the scalar parallel kernel. */
    gemm_naive_parallel(A, B, bias, C, M, N, K);
#endif
}
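/* Serial double-precision reference for the untransposed (K x N) B layout. */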
static void gemm_nn_serial_double(const float *A, const float *B,
                                  const float *bias, float *C,
                                  int M, int N, int K) {
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            double sum = bias ? (double)bias[j] : 0.0;
            for (int k = 0; k < K; k++) {
                sum += (double)A[i * K + k] * (double)B[k * N + j];
            }
            C[i * N + j] = (float)sum;
        }
    }
}
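/* Scalar parallel kernel for untransposed B; the B[k * N + j] access is
 * strided by N. */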
void gemm_nn_parallel(const float *A, const float *B, const float *bias,
                      float *C, int M, int N, int K) {
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = bias ? bias[j] : 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[i * K + k] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
}
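/*
 * Vectorized kernel for untransposed B: broadcast one A element and FMA it
 * against 16 (or 8) consecutive columns of B, producing a whole strip of
 * C[i] at once while keeping the B loads contiguous along N.
 */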
void gemm_nn_avx512(const float *A, const float *B, const float *bias,
                    float *C, int M, int N, int K) {
#if defined(__AVX512F__)
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        int j = 0;
        for (; j <= N - 16; j += 16) {
            __m512 sum_vec = bias ? _mm512_loadu_ps(&bias[j]) : _mm512_setzero_ps();
            for (int k = 0; k < K; k++) {
                __m512 a_broadcast = _mm512_set1_ps(A[i * K + k]);
                __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
                sum_vec = _mm512_fmadd_ps(a_broadcast, b_vec, sum_vec);
            }
            _mm512_storeu_ps(&C[i * N + j], sum_vec);
        }
        for (; j < N; j++) {
            float sum = bias ? bias[j] : 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[i * K + k] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
#elif defined(__AVX__)
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        int j = 0;
        for (; j <= N - 8; j += 8) {
            __m256 sum_vec = bias ? _mm256_loadu_ps(&bias[j]) : _mm256_setzero_ps();
            for (int k = 0; k < K; k++) {
                __m256 a_broadcast = _mm256_set1_ps(A[i * K + k]);
                __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
                __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
                sum_vec = _mm256_add_ps(sum_vec, prod);
            }
            _mm256_storeu_ps(&C[i * N + j], sum_vec);
        }
        for (; j < N; j++) {
            float sum = bias ? bias[j] : 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[i * K + k] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
#else
    /* No AVX available: assumed fallback to the scalar parallel kernel. */
    gemm_nn_parallel(A, B, bias, C, M, N, K);
#endif
}
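/*
 * Cache-blocked broadcast kernel for untransposed B. The i/k/j tile order
 * reuses each A element across a full row strip of C.
 */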
void gemm_nn_blocked(const float *A, const float *B, const float *bias,
                     float *C, int M, int N, int K) {
#if defined(__AVX512F__)
    const int block_size = 64;
#elif defined(__AVX__)
    const int block_size = 32;
#else
    const int block_size = 32;
#endif

    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            C[i * N + j] = bias ? bias[j] : 0.0f;
        }
    }

    /* Parallel over row blocks only: each thread owns its rows of C. */
    #pragma omp parallel for
    for (int ii = 0; ii < M; ii += block_size) {
        for (int kk = 0; kk < K; kk += block_size) {
            for (int jj = 0; jj < N; jj += block_size) {
                int i_end = ck_min(ii + block_size, M);
                int k_end = ck_min(kk + block_size, K);
                int j_end = ck_min(jj + block_size, N);
                for (int i = ii; i < i_end; i++) {
                    for (int k = kk; k < k_end; k++) {
                        float a_val = A[i * K + k];
#if defined(__AVX512F__)
                        __m512 a_broadcast = _mm512_set1_ps(a_val);
                        int j;
                        for (j = jj; j <= j_end - 16; j += 16) {
                            __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
                            __m512 c_vec = _mm512_loadu_ps(&C[i * N + j]);
                            c_vec = _mm512_fmadd_ps(a_broadcast, b_vec, c_vec);
                            _mm512_storeu_ps(&C[i * N + j], c_vec);
                        }
                        for (; j < j_end; j++) {
                            C[i * N + j] += a_val * B[k * N + j];
                        }
#elif defined(__AVX__)
                        __m256 a_broadcast = _mm256_set1_ps(a_val);
                        int j;
                        for (j = jj; j <= j_end - 8; j += 8) {
                            __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
                            __m256 c_vec = _mm256_loadu_ps(&C[i * N + j]);
                            __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
                            c_vec = _mm256_add_ps(c_vec, prod);
                            _mm256_storeu_ps(&C[i * N + j], c_vec);
                        }
                        for (; j < j_end; j++) {
                            C[i * N + j] += a_val * B[k * N + j];
                        }
#else
                        for (int j = jj; j < j_end; j++) {
                            C[i * N + j] += a_val * B[k * N + j];
                        }
#endif
                    }
                }
            }
        }
    }
}
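/* Serial double-precision reference for the transposed-A (K x M) layout. */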
static void gemm_tn_serial_double(const float *A, const float *B,
                                  const float *bias, float *C,
                                  int M, int N, int K) {
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            double sum = bias ? (double)bias[j] : 0.0;
            for (int k = 0; k < K; k++) {
                sum += (double)A[k * M + i] * (double)B[k * N + j];
            }
            C[i * N + j] = (float)sum;
        }
    }
}
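/* Scalar parallel kernel for transposed A: A[k * M + i] walks down a column. */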
void gemm_tn_parallel(const float *A, const float *B, const float *bias,
                      float *C, int M, int N, int K) {
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = bias ? bias[j] : 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[k * M + i] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
}
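/*
 * Vectorized kernel for transposed A: broadcast A[k * M + i] and FMA it
 * against contiguous strips of B, mirroring the NN broadcast kernel.
 */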
void gemm_tn_avx512(const float *A, const float *B, const float *bias,
                    float *C, int M, int N, int K) {
#if defined(__AVX512F__)
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        int j = 0;
        for (; j <= N - 16; j += 16) {
            __m512 sum_vec = bias ? _mm512_loadu_ps(&bias[j]) : _mm512_setzero_ps();
            for (int k = 0; k < K; k++) {
                __m512 a_broadcast = _mm512_set1_ps(A[k * M + i]);
                __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
                sum_vec = _mm512_fmadd_ps(a_broadcast, b_vec, sum_vec);
            }
            _mm512_storeu_ps(&C[i * N + j], sum_vec);
        }
        for (; j < N; j++) {
            float sum = bias ? bias[j] : 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[k * M + i] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
#elif defined(__AVX__)
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        int j = 0;
        for (; j <= N - 8; j += 8) {
            __m256 sum_vec = bias ? _mm256_loadu_ps(&bias[j]) : _mm256_setzero_ps();
            for (int k = 0; k < K; k++) {
                __m256 a_broadcast = _mm256_set1_ps(A[k * M + i]);
                __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
                __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
                sum_vec = _mm256_add_ps(sum_vec, prod);
            }
            _mm256_storeu_ps(&C[i * N + j], sum_vec);
        }
        for (; j < N; j++) {
            float sum = bias ? bias[j] : 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[k * M + i] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
#else
    /* No AVX available: assumed fallback to the scalar parallel kernel. */
    gemm_tn_parallel(A, B, bias, C, M, N, K);
#endif
}
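/*
 * Cache-blocked kernel for transposed A. Inside a tile the k/i order keeps
 * the column walk of A cheap while the B and C strips stay vector-friendly.
 */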
void gemm_tn_blocked(const float *A, const float *B, const float *bias,
                     float *C, int M, int N, int K) {
#if defined(__AVX512F__)
    const int block_size = 64;
#elif defined(__AVX__)
    const int block_size = 32;
#else
    const int block_size = 32;
#endif

    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            C[i * N + j] = bias ? bias[j] : 0.0f;
        }
    }

    /* Parallel over row blocks only: each thread owns its rows of C. */
    #pragma omp parallel for
    for (int ii = 0; ii < M; ii += block_size) {
        for (int kk = 0; kk < K; kk += block_size) {
            for (int jj = 0; jj < N; jj += block_size) {
                int i_end = ck_min(ii + block_size, M);
                int k_end = ck_min(kk + block_size, K);
                int j_end = ck_min(jj + block_size, N);
                for (int k = kk; k < k_end; k++) {
                    for (int i = ii; i < i_end; i++) {
                        float a_val = A[k * M + i];
#if defined(__AVX512F__)
                        __m512 a_broadcast = _mm512_set1_ps(a_val);
                        int j;
                        for (j = jj; j <= j_end - 16; j += 16) {
                            __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
                            __m512 c_vec = _mm512_loadu_ps(&C[i * N + j]);
                            c_vec = _mm512_fmadd_ps(a_broadcast, b_vec, c_vec);
                            _mm512_storeu_ps(&C[i * N + j], c_vec);
                        }
                        for (; j < j_end; j++) {
                            C[i * N + j] += a_val * B[k * N + j];
                        }
#elif defined(__AVX__)
                        __m256 a_broadcast = _mm256_set1_ps(a_val);
                        int j;
                        for (j = jj; j <= j_end - 8; j += 8) {
                            __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
                            __m256 c_vec = _mm256_loadu_ps(&C[i * N + j]);
                            __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
                            c_vec = _mm256_add_ps(c_vec, prod);
                            _mm256_storeu_ps(&C[i * N + j], c_vec);
                        }
                        for (; j < j_end; j++) {
                            C[i * N + j] += a_val * B[k * N + j];
                        }
#else
                        for (int j = jj; j < j_end; j++) {
                            C[i * N + j] += a_val * B[k * N + j];
                        }
#endif
                    }
                }
            }
        }
    }
}
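/*
 * Serial blocked kernel with a small dispatch layer. The surviving size
 * checks suggest: large matrix-vector shapes go to the dedicated parallel
 * kernel, big-enough matrices take the tiled path, and everything else
 * falls back to the reference kernel. The exact fallback targets are
 * assumptions; only the conditions survive in the original.
 */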
void gemm_blocked_serial(const float *A, const float *B, const float *bias,
                         float *C, int M, int N, int K) {
    if (M == 1 && (size_t)N * (size_t)K >= 65536) {
        /* Assumed dispatch target: the M == 1 fast path defined above. */
        gemm_nt_matvec_parallel(A, B, bias, C, N, K);
        return;
    }
    if (M >= 32 && N >= 32 && K >= 32) {
#if defined(__AVX512F__)
        const int block_size = 64;
#elif defined(__AVX__)
        const int block_size = 32;
#else
        const int block_size = 32;
#endif
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < N; j++) {
                C[i * N + j] = bias ? bias[j] : 0.0f;
            }
        }

        for (int ii = 0; ii < M; ii += block_size) {
            for (int jj = 0; jj < N; jj += block_size) {
                for (int kk = 0; kk < K; kk += block_size) {
                    int i_end = ck_min(ii + block_size, M);
                    int j_end = ck_min(jj + block_size, N);
                    int k_end = ck_min(kk + block_size, K);
                    for (int i = ii; i < i_end; i++) {
                        for (int j = jj; j < j_end; j++) {
#if defined(__AVX512F__)
                            __m512 sum_vec = _mm512_setzero_ps();
                            int k;
                            for (k = kk; k <= k_end - 16; k += 16) {
                                __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
                                __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
                                sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
                            }
                            float partial_sum = _mm512_reduce_add_ps(sum_vec);
                            for (; k < k_end; k++) {
                                partial_sum += A[i * K + k] * B[j * K + k];
                            }
#elif defined(__AVX__)
                            __m256 sum_vec = _mm256_setzero_ps();
                            int k;
                            for (k = kk; k <= k_end - 8; k += 8) {
                                __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
                                __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
                                __m256 prod = _mm256_mul_ps(a_vec, b_vec);
                                sum_vec = _mm256_add_ps(sum_vec, prod);
                            }
                            float partial_sum = hsum256_ps(sum_vec);
                            for (; k < k_end; k++) {
                                partial_sum += A[i * K + k] * B[j * K + k];
                            }
#else
                            float partial_sum = 0.0f;
                            for (int k = kk; k < k_end; k++) {
                                partial_sum += A[i * K + k] * B[j * K + k];
                            }
#endif
                            C[i * N + j] += partial_sum;
                        }
                    }
                }
            }
        }
        return;
    }
    /* Small problems: assumed fallback to the double-precision reference. */
    gemm_naive_serial_double(A, B, bias, C, M, N, K);
}