gemm_kernels.c File Reference

General matrix multiply (GEMM) kernels with SIMD (SSE/AVX/AVX512).

#include "ckernel_engine.h"
#include <omp.h>


Functions

static void ck_gemm_add_bias (float *C, const float *bias, int M, int N)
 
static int ck_min (int a, int b)
 
void gemm_avx512_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_blocked_serial (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_fine_grained_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_naive_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
static void gemm_naive_serial_double (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nn_avx512 (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nn_blocked (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nn_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
static void gemm_nn_serial_double (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
static void gemm_nt_matvec_parallel (const float *A, const float *B, const float *bias, float *C, int N, int K)
 
void gemm_tn_avx512 (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_tn_blocked (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_tn_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
static void gemm_tn_serial_double (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 

Detailed Description

General matrix multiply (GEMM) kernels with SIMD (SSE/AVX/AVX512)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

LEGACY EXCEPTION: This file contains OpenMP for backward compatibility. New kernels should NOT use OpenMP internally.

GEMM: C = alpha * A @ B + beta * C (with optional bias)

Definition in file gemm_kernels.c.
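
As a quick orientation, the sketch below shows how a caller might invoke one of the NT-style kernels (gemm_blocked_serial) with the layouts used throughout this file: A is M x K row-major, B is stored N x K row-major (already transposed, so C[i,j] accumulates A[i*K+k] * B[j*K+k]), bias is length N or NULL, and C is M x N. This is an illustrative sketch, not part of gemm_kernels.c; the main() harness and the toy values are made up.

#include "ckernel_engine.h"
#include <stdio.h>

int main(void)
{
    enum { M = 2, N = 3, K = 4 };

    /* A is [M x K] row-major: A[i*K + k] */
    const float A[M * K] = {
        1, 2, 3, 4,
        5, 6, 7, 8,
    };
    /* B is [N x K] row-major (pre-transposed): B[j*K + k] */
    const float B[N * K] = {
        1, 0, 0, 0,
        0, 1, 0, 0,
        0, 0, 1, 0,
    };
    const float bias[N] = { 0.5f, 0.5f, 0.5f };
    float C[M * N];

    /* C = A @ B^T + bias */
    gemm_blocked_serial(A, B, bias, C, M, N, K);

    /* Expected output: 1.50 2.50 3.50 / 5.50 6.50 7.50 */
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            printf("%6.2f ", C[i * N + j]);
        }
        printf("\n");
    }
    return 0;
}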

Function Documentation

◆ ck_gemm_add_bias()

static void ck_gemm_add_bias ( float *  C,
const float *  bias,
int  M,
int  N 
)
inline static

Definition at line 28 of file gemm_kernels.c.

29 {
30  if (!bias) {
31  return;
32  }
33 #pragma omp parallel for schedule(static)
34  for (int i = 0; i < M; ++i) {
35  float *c_row = C + (size_t)i * (size_t)N;
36  for (int j = 0; j < N; ++j) {
37  c_row[j] += bias[j];
38  }
39  }
40 }


Referenced by gemm_blocked_serial().

◆ ck_min()

static int ck_min ( int  a,
int  b 
)
inline static

Definition at line 26 of file gemm_kernels.c.

26 { return a < b ? a : b; }

Referenced by gemm_blocked_serial(), gemm_fine_grained_parallel(), gemm_nn_blocked(), and gemm_tn_blocked().

◆ gemm_avx512_parallel()

void gemm_avx512_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 149 of file gemm_kernels.c.

154 {
155  if (ck_strict_parity_enabled()) {
156  gemm_naive_serial_double(A, B, bias, C, M, N, K);
157  return;
158  }
159 #if defined(__AVX512F__)
160 #pragma omp parallel for
161  for (int i = 0; i < M; i++) {
162  for (int j = 0; j < N; j++) {
163  __m512 sum_vec = _mm512_setzero_ps();
164  int k;
165  for (k = 0; k <= K - 16; k += 16) {
166  __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
167  __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
168  sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
169  }
170  float sum = _mm512_reduce_add_ps(sum_vec);
171  for (; k < K; k++) {
172  sum += A[i * K + k] * B[j * K + k];
173  }
174  float bias_val = bias ? bias[j] : 0.0f;
175  C[i * N + j] = sum + bias_val;
176  }
177  }
178 #elif defined(__AVX__)
179  // AVX1 path: 256-bit vectors, no FMA (use mul + add)
180 #pragma omp parallel for
181  for (int i = 0; i < M; i++) {
182  for (int j = 0; j < N; j++) {
183  __m256 sum_vec = _mm256_setzero_ps();
184  int k;
185  for (k = 0; k <= K - 8; k += 8) {
186  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
187  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
188  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
189  sum_vec = _mm256_add_ps(sum_vec, prod);
190  }
191  float sum = hsum256_ps(sum_vec);
192  for (; k < K; k++) {
193  sum += A[i * K + k] * B[j * K + k];
194  }
195  float bias_val = bias ? bias[j] : 0.0f;
196  C[i * N + j] = sum + bias_val;
197  }
198  }
199 #else
200  gemm_naive_parallel(A, B, bias, C, M, N, K);
201 #endif
202 }

References ck_strict_parity_enabled(), gemm_naive_parallel(), and gemm_naive_serial_double().

◆ gemm_blocked_serial()

void gemm_blocked_serial ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 661 of file gemm_kernels.c.

666 {
667  // Ensure threads are initialized (auto-detects on first call)
668  (void)ck_get_num_threads();
669 
670  if (ck_strict_parity_enabled()) {
671  gemm_naive_serial_double(A, B, bias, C, M, N, K);
672  return;
673  }
674 
675  // Decode-time matvec (M=1) is extremely common and benefits from parallelism over N.
676  // Lower threshold to parallelize more ops; OpenMP overhead is ~1-2μs per barrier.
677  // For N*K >= 64K elements, parallel is worthwhile.
678  if (M == 1 && (size_t)N * (size_t)K >= 65536) {
679  gemm_nt_matvec_parallel(A, B, bias, C, N, K);
680  return;
681  }
682 
683  /*
684  * Use gemm_microkernel for large matrices - it uses MKL/oneDNN when available,
685  * which is substantially faster than our hand-written SIMD kernels.
686  * B is stored as [N x K] (transposed), so we pass B_transposed=1.
687  * Note: Use threshold of 32 to avoid numerical precision issues with small matrices.
688  */
689  if (M >= 32 && N >= 32 && K >= 32) {
690  gemm_microkernel(A, B, C, M, N, K, 1); // B_transposed=1
691  ck_gemm_add_bias(C, bias, M, N);
692  return;
693  }
694 #if defined(__AVX512F__)
695  const int block_size = 64;
696 #elif defined(__AVX__)
697  const int block_size = 32;
698 #else
699  const int block_size = 32;
700 #endif
701  for (int i = 0; i < M; i++) {
702  for (int j = 0; j < N; j++) {
703  C[i * N + j] = bias ? bias[j] : 0.0f;
704  }
705  }
706  for (int ii = 0; ii < M; ii += block_size) {
707  for (int jj = 0; jj < N; jj += block_size) {
708  for (int kk = 0; kk < K; kk += block_size) {
709  int i_end = ck_min(ii + block_size, M);
710  int j_end = ck_min(jj + block_size, N);
711  int k_end = ck_min(kk + block_size, K);
712 
713  for (int i = ii; i < i_end; i++) {
714  for (int j = jj; j < j_end; j++) {
715 #if defined(__AVX512F__)
716  __m512 sum_vec = _mm512_setzero_ps();
717  int k;
718  for (k = kk; k <= k_end - 16; k += 16) {
719  __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
720  __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
721  sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
722  }
723  float partial_sum = _mm512_reduce_add_ps(sum_vec);
724  for (; k < k_end; k++) {
725  partial_sum += A[i * K + k] * B[j * K + k];
726  }
727 #elif defined(__AVX__)
728  __m256 sum_vec = _mm256_setzero_ps();
729  int k;
730  for (k = kk; k <= k_end - 8; k += 8) {
731  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
732  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
733  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
734  sum_vec = _mm256_add_ps(sum_vec, prod);
735  }
736  float partial_sum = hsum256_ps(sum_vec);
737  for (; k < k_end; k++) {
738  partial_sum += A[i * K + k] * B[j * K + k];
739  }
740 #else
741  float partial_sum = 0.0f;
742  for (int k = kk; k < k_end; k++) {
743  partial_sum += A[i * K + k] * B[j * K + k];
744  }
745 #endif
746  C[i * N + j] += partial_sum;
747  }
748  }
749  }
750  }
751  }
752 }

References ck_gemm_add_bias(), ck_get_num_threads(), ck_min(), ck_strict_parity_enabled(), gemm_microkernel(), gemm_naive_serial_double(), and gemm_nt_matvec_parallel().

Referenced by ck_attention_project_head_major(), ck_gemm_nt_quant(), ck_mlp_swiglu_forward(), ck_mlp_swiglu_forward_fused_token(), ck_qkv_project_head_major(), ck_qkv_project_head_major_token(), mlp_token_parallel(), and mlp_token_parallel_exact().
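
The dispatch logic above (strict-parity fallback, the M == 1 mat-vec fast path, and the gemm_microkernel path once every dimension is at least 32) means callers simply invoke gemm_blocked_serial and let it choose a route. Below is a hedged sketch of the decode-time case, where a single hidden-state row is projected through a pre-transposed N x K weight matrix; the function name and shapes are illustrative, not taken from the engine.

#include "ckernel_engine.h"

/* Hypothetical decode-step projection: with M == 1 and N*K >= 65536 the
 * call is routed to the parallel mat-vec path (gemm_nt_matvec_parallel). */
static void decode_project_example(const float *hidden, /* [1 x K] */
                                   const float *W_t,    /* [N x K], pre-transposed */
                                   const float *bias,   /* [N] or NULL */
                                   float *out,          /* [1 x N] */
                                   int N, int K)
{
    gemm_blocked_serial(hidden, W_t, bias, out, /*M=*/1, N, K);
}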

◆ gemm_fine_grained_parallel()

void gemm_fine_grained_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 205 of file gemm_kernels.c.

210 {
211  if (ck_strict_parity_enabled()) {
212  gemm_naive_serial_double(A, B, bias, C, M, N, K);
213  return;
214  }
215 #if defined(__AVX512F__)
216  const int block_size = 64;
217 #pragma omp parallel for
218  for (int i = 0; i < M; i++) {
219  for (int j = 0; j < N; j++) {
220  C[i * N + j] = bias ? bias[j] : 0.0f;
221  }
222  }
223 #pragma omp parallel for collapse(3)
224  for (int ii = 0; ii < M; ii += block_size) {
225  for (int jj = 0; jj < N; jj += block_size) {
226  for (int kk = 0; kk < K; kk += block_size) {
227  int i_end = ck_min(ii + block_size, M);
228  int j_end = ck_min(jj + block_size, N);
229  int k_end = ck_min(kk + block_size, K);
230 
231  for (int i = ii; i < i_end; i++) {
232  for (int j = jj; j < j_end; j++) {
233  __m512 sum_vec = _mm512_setzero_ps();
234  int k;
235  for (k = kk; k <= k_end - 16; k += 16) {
236  __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
237  __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
238  sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
239  }
240  float partial_sum = _mm512_reduce_add_ps(sum_vec);
241  for (; k < k_end; k++) {
242  partial_sum += A[i * K + k] * B[j * K + k];
243  }
244 #pragma omp atomic
245  C[i * N + j] += partial_sum;
246  }
247  }
248  }
249  }
250  }
251 #elif defined(__AVX__)
252  // AVX1 cache-blocked version
253  const int block_size = 32; // Smaller block for L1 cache
254 #pragma omp parallel for
255  for (int i = 0; i < M; i++) {
256  for (int j = 0; j < N; j++) {
257  C[i * N + j] = bias ? bias[j] : 0.0f;
258  }
259  }
260 #pragma omp parallel for collapse(3)
261  for (int ii = 0; ii < M; ii += block_size) {
262  for (int jj = 0; jj < N; jj += block_size) {
263  for (int kk = 0; kk < K; kk += block_size) {
264  int i_end = ck_min(ii + block_size, M);
265  int j_end = ck_min(jj + block_size, N);
266  int k_end = ck_min(kk + block_size, K);
267 
268  for (int i = ii; i < i_end; i++) {
269  for (int j = jj; j < j_end; j++) {
270  __m256 sum_vec = _mm256_setzero_ps();
271  int k;
272  for (k = kk; k <= k_end - 8; k += 8) {
273  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
274  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
275  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
276  sum_vec = _mm256_add_ps(sum_vec, prod);
277  }
278  float partial_sum = hsum256_ps(sum_vec);
279  for (; k < k_end; k++) {
280  partial_sum += A[i * K + k] * B[j * K + k];
281  }
282 #pragma omp atomic
283  C[i * N + j] += partial_sum;
284  }
285  }
286  }
287  }
288  }
289 #else
290  gemm_naive_parallel(A, B, bias, C, M, N, K);
291 #endif
292 }

References ck_min(), ck_strict_parity_enabled(), gemm_naive_parallel(), and gemm_naive_serial_double().

◆ gemm_naive_parallel()

void gemm_naive_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 125 of file gemm_kernels.c.

130 {
131  if (ck_strict_parity_enabled()) {
132  gemm_naive_serial_double(A, B, bias, C, M, N, K);
133  return;
134  }
135 #pragma omp parallel for
136  for (int i = 0; i < M; i++) {
137  for (int j = 0; j < N; j++) {
138  float sum = 0.0f;
139  for (int k = 0; k < K; k++) {
140  sum += A[i * K + k] * B[j * K + k];
141  }
142  float bias_val = bias ? bias[j] : 0.0f;
143  C[i * N + j] = sum + bias_val;
144  }
145  }
146 }

References ck_strict_parity_enabled(), and gemm_naive_serial_double().

Referenced by ck_attention_project_head_major_ref(), ck_mlp_swiglu_forward_ref(), ck_qkv_project_head_major_ref(), gemm_avx512_parallel(), and gemm_fine_grained_parallel().

◆ gemm_naive_serial_double()

static void gemm_naive_serial_double ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)
static

Definition at line 107 of file gemm_kernels.c.

112 {
113  for (int i = 0; i < M; i++) {
114  for (int j = 0; j < N; j++) {
115  double sum = bias ? (double)bias[j] : 0.0;
116  for (int k = 0; k < K; k++) {
117  sum += (double)A[i * K + k] * (double)B[j * K + k];
118  }
119  C[i * N + j] = (float)sum;
120  }
121  }
122 }


Referenced by gemm_avx512_parallel(), gemm_blocked_serial(), gemm_fine_grained_parallel(), and gemm_naive_parallel().
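
gemm_naive_serial_double is the strict-parity reference for the NT-style kernels above: they fall back to it whenever ck_strict_parity_enabled() returns true. Below is a hedged sketch of a caller-side check against the same double-precision math; check_against_double_ref and the tolerance argument are assumptions for illustration, not part of the engine's test suite.

#include <math.h>
#include <stdio.h>

/* Hypothetical caller-side parity check for the NT layout (B stored [N x K]);
 * mirrors the accumulation order of gemm_naive_serial_double. */
static int check_against_double_ref(const float *A, const float *B,
                                    const float *bias, const float *C_fast,
                                    int M, int N, int K, double tol)
{
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            double ref = bias ? (double)bias[j] : 0.0;
            for (int k = 0; k < K; k++) {
                ref += (double)A[i * K + k] * (double)B[j * K + k];
            }
            if (fabs((double)C_fast[i * N + j] - ref) > tol) {
                fprintf(stderr, "mismatch at (%d,%d)\n", i, j);
                return 0;
            }
        }
    }
    return 1;
}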

◆ gemm_nn_avx512()

void gemm_nn_avx512 ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 339 of file gemm_kernels.c.

344 {
345  if (ck_strict_parity_enabled()) {
346  gemm_nn_serial_double(A, B, bias, C, M, N, K);
347  return;
348  }
349 #if defined(__AVX512F__)
350  // For gemm_nn, we can't vectorize over K easily since B[k,j] has stride N.
351  // Instead, vectorize over N (output columns) when N >= 16.
352 #pragma omp parallel for
353  for (int i = 0; i < M; i++) {
354  int j = 0;
355  // Process 16 output columns at a time
356  for (; j <= N - 16; j += 16) {
357  __m512 sum_vec = bias ? _mm512_loadu_ps(&bias[j]) : _mm512_setzero_ps();
358  for (int k = 0; k < K; k++) {
359  __m512 a_broadcast = _mm512_set1_ps(A[i * K + k]);
360  __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
361  sum_vec = _mm512_fmadd_ps(a_broadcast, b_vec, sum_vec);
362  }
363  _mm512_storeu_ps(&C[i * N + j], sum_vec);
364  }
365  // Handle remaining columns
366  for (; j < N; j++) {
367  float sum = bias ? bias[j] : 0.0f;
368  for (int k = 0; k < K; k++) {
369  sum += A[i * K + k] * B[k * N + j];
370  }
371  C[i * N + j] = sum;
372  }
373  }
374 #elif defined(__AVX__)
375  // AVX1: vectorize over N (8 columns at a time)
376 #pragma omp parallel for
377  for (int i = 0; i < M; i++) {
378  int j = 0;
379  for (; j <= N - 8; j += 8) {
380  __m256 sum_vec = bias ? _mm256_loadu_ps(&bias[j]) : _mm256_setzero_ps();
381  for (int k = 0; k < K; k++) {
382  __m256 a_broadcast = _mm256_set1_ps(A[i * K + k]);
383  __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
384  __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
385  sum_vec = _mm256_add_ps(sum_vec, prod);
386  }
387  _mm256_storeu_ps(&C[i * N + j], sum_vec);
388  }
389  for (; j < N; j++) {
390  float sum = bias ? bias[j] : 0.0f;
391  for (int k = 0; k < K; k++) {
392  sum += A[i * K + k] * B[k * N + j];
393  }
394  C[i * N + j] = sum;
395  }
396  }
397 #else
398  gemm_nn_parallel(A, B, bias, C, M, N, K);
399 #endif
400 }

References ck_strict_parity_enabled(), gemm_nn_parallel(), and gemm_nn_serial_double().

Referenced by fc1_backward_kernel(), and fc2_backward_kernel().
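
As the comments in the listing note, the NN kernels index B as B[k * N + j]: B is a K x N row-major matrix (not pre-transposed as in the NT kernels above), and vectorization runs over the N output columns. A small hedged layout sketch follows; nn_layout_example and the toy values are illustrative only.

#include "ckernel_engine.h"

static void nn_layout_example(void)
{
    enum { M = 2, N = 4, K = 3 };
    const float A[M * K] = {          /* [M x K]: A[i*K + k] */
        1, 2, 3,
        4, 5, 6,
    };
    const float B[K * N] = {          /* [K x N]: B[k*N + j], NOT transposed */
        1, 0, 0, 0,
        0, 1, 0, 0,
        0, 0, 1, 0,
    };
    float C[M * N];

    /* C = A @ B; bias == NULL is allowed. */
    gemm_nn_avx512(A, B, /*bias=*/NULL, C, M, N, K);
    /* Expected C: row 0 = 1 2 3 0, row 1 = 4 5 6 0 */
}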

◆ gemm_nn_blocked()

void gemm_nn_blocked ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 402 of file gemm_kernels.c.

407 {
408  if (ck_strict_parity_enabled()) {
409  gemm_nn_serial_double(A, B, bias, C, M, N, K);
410  return;
411  }
412 #if defined(__AVX512F__)
413  const int block_size = 64;
414 #elif defined(__AVX__)
415  const int block_size = 32;
416 #else
417  const int block_size = 32;
418 #endif
419  // Initialize C with bias (parallelized)
420 #pragma omp parallel for
421  for (int i = 0; i < M; i++) {
422  for (int j = 0; j < N; j++) {
423  C[i * N + j] = bias ? bias[j] : 0.0f;
424  }
425  }
426  // Blocked multiply-accumulate (parallelized over M blocks)
427 #pragma omp parallel for
428  for (int ii = 0; ii < M; ii += block_size) {
429  for (int kk = 0; kk < K; kk += block_size) {
430  for (int jj = 0; jj < N; jj += block_size) {
431  int i_end = ck_min(ii + block_size, M);
432  int k_end = ck_min(kk + block_size, K);
433  int j_end = ck_min(jj + block_size, N);
434 
435  for (int i = ii; i < i_end; i++) {
436  for (int k = kk; k < k_end; k++) {
437  float a_val = A[i * K + k];
438 #if defined(__AVX512F__)
439  __m512 a_broadcast = _mm512_set1_ps(a_val);
440  int j;
441  for (j = jj; j <= j_end - 16; j += 16) {
442  __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
443  __m512 c_vec = _mm512_loadu_ps(&C[i * N + j]);
444  c_vec = _mm512_fmadd_ps(a_broadcast, b_vec, c_vec);
445  _mm512_storeu_ps(&C[i * N + j], c_vec);
446  }
447  for (; j < j_end; j++) {
448  C[i * N + j] += a_val * B[k * N + j];
449  }
450 #elif defined(__AVX__)
451  __m256 a_broadcast = _mm256_set1_ps(a_val);
452  int j;
453  for (j = jj; j <= j_end - 8; j += 8) {
454  __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
455  __m256 c_vec = _mm256_loadu_ps(&C[i * N + j]);
456  __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
457  c_vec = _mm256_add_ps(c_vec, prod);
458  _mm256_storeu_ps(&C[i * N + j], c_vec);
459  }
460  for (; j < j_end; j++) {
461  C[i * N + j] += a_val * B[k * N + j];
462  }
463 #else
464  for (int j = jj; j < j_end; j++) {
465  C[i * N + j] += a_val * B[k * N + j];
466  }
467 #endif
468  }
469  }
470  }
471  }
472  }
473 }

References ck_min(), ck_strict_parity_enabled(), and gemm_nn_serial_double().

◆ gemm_nn_parallel()

void gemm_nn_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 317 of file gemm_kernels.c.

322 {
323  if (ck_strict_parity_enabled()) {
324  gemm_nn_serial_double(A, B, bias, C, M, N, K);
325  return;
326  }
327 #pragma omp parallel for
328  for (int i = 0; i < M; i++) {
329  for (int j = 0; j < N; j++) {
330  float sum = bias ? bias[j] : 0.0f;
331  for (int k = 0; k < K; k++) {
332  sum += A[i * K + k] * B[k * N + j];
333  }
334  C[i * N + j] = sum;
335  }
336  }
337 }

References ck_strict_parity_enabled(), and gemm_nn_serial_double().

Referenced by gemm_nn_avx512().

◆ gemm_nn_serial_double()

static void gemm_nn_serial_double ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)
static

Definition at line 300 of file gemm_kernels.c.

305 {
306  for (int i = 0; i < M; i++) {
307  for (int j = 0; j < N; j++) {
308  double sum = bias ? (double)bias[j] : 0.0;
309  for (int k = 0; k < K; k++) {
310  sum += (double)A[i * K + k] * (double)B[k * N + j];
311  }
312  C[i * N + j] = (float)sum;
313  }
314  }
315 }


Referenced by gemm_nn_avx512(), gemm_nn_blocked(), and gemm_nn_parallel().

◆ gemm_nt_matvec_parallel()

static void gemm_nt_matvec_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  N,
int  K 
)
static

Definition at line 61 of file gemm_kernels.c.

67 {
68 #pragma omp parallel for schedule(static)
69  for (int j = 0; j < N; ++j) {
70  const float *b_row = B + (size_t)j * (size_t)K;
71  float sum = bias ? bias[j] : 0.0f;
72 
73 #if defined(__AVX512F__)
74  __m512 acc = _mm512_setzero_ps();
75  int k = 0;
76  for (; k <= K - 16; k += 16) {
77  __m512 a_vec = _mm512_loadu_ps(A + k);
78  __m512 b_vec = _mm512_loadu_ps(b_row + k);
79  acc = _mm512_fmadd_ps(a_vec, b_vec, acc);
80  }
81  sum += _mm512_reduce_add_ps(acc);
82  for (; k < K; ++k) {
83  sum += A[k] * b_row[k];
84  }
85 #elif defined(__AVX__)
86  __m256 acc = _mm256_setzero_ps();
87  int k = 0;
88  for (; k <= K - 8; k += 8) {
89  __m256 a_vec = _mm256_loadu_ps(A + k);
90  __m256 b_vec = _mm256_loadu_ps(b_row + k);
91  acc = _mm256_add_ps(acc, _mm256_mul_ps(a_vec, b_vec));
92  }
93  sum += hsum256_ps(acc);
94  for (; k < K; ++k) {
95  sum += A[k] * b_row[k];
96  }
97 #else
98  for (int k = 0; k < K; ++k) {
99  sum += A[k] * b_row[k];
100  }
101 #endif
102 
103  C[j] = sum;
104  }
105 }


Referenced by gemm_blocked_serial().

◆ gemm_tn_avx512()

void gemm_tn_avx512 ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 521 of file gemm_kernels.c.

526 {
527  if (ck_strict_parity_enabled()) {
528  gemm_tn_serial_double(A, B, bias, C, M, N, K);
529  return;
530  }
531 #if defined(__AVX512F__)
532  // Vectorize over N (output columns)
533 #pragma omp parallel for
534  for (int i = 0; i < M; i++) {
535  int j = 0;
536  for (; j <= N - 16; j += 16) {
537  __m512 sum_vec = bias ? _mm512_loadu_ps(&bias[j]) : _mm512_setzero_ps();
538  for (int k = 0; k < K; k++) {
539  __m512 a_broadcast = _mm512_set1_ps(A[k * M + i]);
540  __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
541  sum_vec = _mm512_fmadd_ps(a_broadcast, b_vec, sum_vec);
542  }
543  _mm512_storeu_ps(&C[i * N + j], sum_vec);
544  }
545  for (; j < N; j++) {
546  float sum = bias ? bias[j] : 0.0f;
547  for (int k = 0; k < K; k++) {
548  sum += A[k * M + i] * B[k * N + j];
549  }
550  C[i * N + j] = sum;
551  }
552  }
553 #elif defined(__AVX__)
554  // AVX1: vectorize over N (8 columns at a time)
555 #pragma omp parallel for
556  for (int i = 0; i < M; i++) {
557  int j = 0;
558  for (; j <= N - 8; j += 8) {
559  __m256 sum_vec = bias ? _mm256_loadu_ps(&bias[j]) : _mm256_setzero_ps();
560  for (int k = 0; k < K; k++) {
561  __m256 a_broadcast = _mm256_set1_ps(A[k * M + i]);
562  __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
563  __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
564  sum_vec = _mm256_add_ps(sum_vec, prod);
565  }
566  _mm256_storeu_ps(&C[i * N + j], sum_vec);
567  }
568  for (; j < N; j++) {
569  float sum = bias ? bias[j] : 0.0f;
570  for (int k = 0; k < K; k++) {
571  sum += A[k * M + i] * B[k * N + j];
572  }
573  C[i * N + j] = sum;
574  }
575  }
576 #else
577  gemm_tn_parallel(A, B, bias, C, M, N, K);
578 #endif
579 }

References ck_strict_parity_enabled(), gemm_tn_parallel(), and gemm_tn_serial_double().

Referenced by fc1_backward_kernel(), and fc2_backward_kernel().

◆ gemm_tn_blocked()

void gemm_tn_blocked ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 581 of file gemm_kernels.c.

586 {
587  if (ck_strict_parity_enabled()) {
588  gemm_tn_serial_double(A, B, bias, C, M, N, K);
589  return;
590  }
591 #if defined(__AVX512F__)
592  const int block_size = 64;
593 #elif defined(__AVX__)
594  const int block_size = 32;
595 #else
596  const int block_size = 32;
597 #endif
598  // Initialize C with bias (parallelized)
599 #pragma omp parallel for
600  for (int i = 0; i < M; i++) {
601  for (int j = 0; j < N; j++) {
602  C[i * N + j] = bias ? bias[j] : 0.0f;
603  }
604  }
605  // Blocked multiply-accumulate (parallelized over M blocks)
606 #pragma omp parallel for
607  for (int ii = 0; ii < M; ii += block_size) {
608  for (int kk = 0; kk < K; kk += block_size) {
609  for (int jj = 0; jj < N; jj += block_size) {
610  int i_end = ck_min(ii + block_size, M);
611  int k_end = ck_min(kk + block_size, K);
612  int j_end = ck_min(jj + block_size, N);
613 
614  for (int k = kk; k < k_end; k++) {
615  for (int i = ii; i < i_end; i++) {
616  float a_val = A[k * M + i];
617 #if defined(__AVX512F__)
618  __m512 a_broadcast = _mm512_set1_ps(a_val);
619  int j;
620  for (j = jj; j <= j_end - 16; j += 16) {
621  __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
622  __m512 c_vec = _mm512_loadu_ps(&C[i * N + j]);
623  c_vec = _mm512_fmadd_ps(a_broadcast, b_vec, c_vec);
624  _mm512_storeu_ps(&C[i * N + j], c_vec);
625  }
626  for (; j < j_end; j++) {
627  C[i * N + j] += a_val * B[k * N + j];
628  }
629 #elif defined(__AVX__)
630  __m256 a_broadcast = _mm256_set1_ps(a_val);
631  int j;
632  for (j = jj; j <= j_end - 8; j += 8) {
633  __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
634  __m256 c_vec = _mm256_loadu_ps(&C[i * N + j]);
635  __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
636  c_vec = _mm256_add_ps(c_vec, prod);
637  _mm256_storeu_ps(&C[i * N + j], c_vec);
638  }
639  for (; j < j_end; j++) {
640  C[i * N + j] += a_val * B[k * N + j];
641  }
642 #else
643  for (int j = jj; j < j_end; j++) {
644  C[i * N + j] += a_val * B[k * N + j];
645  }
646 #endif
647  }
648  }
649  }
650  }
651  }
652 }

References ck_min(), ck_strict_parity_enabled(), and gemm_tn_serial_double().

◆ gemm_tn_parallel()

void gemm_tn_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 499 of file gemm_kernels.c.

504 {
505  if (ck_strict_parity_enabled()) {
506  gemm_tn_serial_double(A, B, bias, C, M, N, K);
507  return;
508  }
509 #pragma omp parallel for
510  for (int i = 0; i < M; i++) {
511  for (int j = 0; j < N; j++) {
512  float sum = bias ? bias[j] : 0.0f;
513  for (int k = 0; k < K; k++) {
514  sum += A[k * M + i] * B[k * N + j];
515  }
516  C[i * N + j] = sum;
517  }
518  }
519 }

References ck_strict_parity_enabled(), and gemm_tn_serial_double().

Referenced by gemm_tn_avx512().

◆ gemm_tn_serial_double()

static void gemm_tn_serial_double ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)
static

Definition at line 481 of file gemm_kernels.c.

486 {
487  for (int i = 0; i < M; i++) {
488  for (int j = 0; j < N; j++) {
489  double sum = bias ? (double)bias[j] : 0.0;
490  for (int k = 0; k < K; k++) {
491  // A.T[i,k] = A[k,i] = A[k*M + i]
492  sum += (double)A[k * M + i] * (double)B[k * N + j];
493  }
494  C[i * N + j] = (float)sum;
495  }
496  }
497 }


Referenced by gemm_tn_avx512(), gemm_tn_blocked(), and gemm_tn_parallel().
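
The TN kernels read the transposed A operand as A.T[i,k] = A[k * M + i], so A is stored K x M row-major and B is K x N as in the NN kernels, giving C = A^T @ B, the shape used by the backward kernels. A hedged layout sketch follows; tn_layout_example and the toy values are illustrative only.

#include "ckernel_engine.h"

static void tn_layout_example(void)
{
    enum { M = 2, N = 2, K = 3 };
    const float A[K * M] = {          /* [K x M]: A.T[i,k] = A[k*M + i] */
        1, 4,
        2, 5,
        3, 6,
    };
    const float B[K * N] = {          /* [K x N]: B[k*N + j] */
        1, 0,
        0, 1,
        1, 1,
    };
    float C[M * N];

    /* C = A^T @ B; bias == NULL is allowed. */
    gemm_tn_parallel(A, B, /*bias=*/NULL, C, M, N, K);
    /* Expected C = [[4, 5], [10, 11]] */
}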