gemm_kernels.c
1 /**
2  * @file gemm_kernels.c
3  * @brief General matrix multiply (GEMM) kernels with SIMD (SSE/AVX/AVX512)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * LEGACY EXCEPTION: This file contains OpenMP for backward compatibility.
15  * New kernels should NOT use OpenMP internally.
16  *
17  * GEMM: C = alpha * A @ B + beta * C (with optional bias); the kernels in this file fix alpha = 1 and beta = 0.
18  */
19 
20 #include "ckernel_engine.h"
21 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
22 #include <immintrin.h>
23 #endif
24 #include <omp.h>
25 
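// Call-site sketch (illustrative only, not part of this file's API): every
// kernel below operates on caller-owned buffers and never allocates. Assuming
// a bump-allocator helper named ck_arena_alloc (a hypothetical name used only
// for this sketch), a typical NT-layout call looks like:
//
//   float *A    = ck_arena_alloc(arena, (size_t)M * K * sizeof(float)); /* [M x K] row-major */
//   float *B    = ck_arena_alloc(arena, (size_t)N * K * sizeof(float)); /* [N x K] (transposed layout) */
//   float *bias = NULL;                                                 /* optional, [N] */
//   float *C    = ck_arena_alloc(arena, (size_t)M * N * sizeof(float)); /* [M x N] row-major */
//   gemm_blocked_serial(A, B, bias, C, M, N, K);
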
26 static inline int ck_min(int a, int b) { return a < b ? a : b; }
27 
28 static inline void ck_gemm_add_bias(float *C, const float *bias, int M, int N)
29 {
30  if (!bias) {
31  return;
32  }
33 #pragma omp parallel for schedule(static)
34  for (int i = 0; i < M; ++i) {
35  float *c_row = C + (size_t)i * (size_t)N;
36  for (int j = 0; j < N; ++j) {
37  c_row[j] += bias[j];
38  }
39  }
40 }
41 
42 // AVX1 horizontal sum helper (no _mm256_reduce_add_ps in AVX1)
43 #if defined(__AVX__) && !defined(__AVX512F__)
44 static inline float hsum256_ps(__m256 v) {
45  // Sum upper and lower 128-bit lanes
46  __m128 lo = _mm256_castps256_ps128(v);
47  __m128 hi = _mm256_extractf128_ps(v, 1);
48  __m128 sum128 = _mm_add_ps(lo, hi);
49  // Horizontal add within 128-bit
50  __m128 shuf = _mm_movehdup_ps(sum128); // [1,1,3,3]
51  __m128 sums = _mm_add_ps(sum128, shuf); // [0+1,1+1,2+3,3+3]
52  shuf = _mm_movehl_ps(shuf, sums); // [2+3,3+3,...]
53  sums = _mm_add_ss(sums, shuf); // [0+1+2+3,...]
54  return _mm_cvtss_f32(sums);
55 }
56 #endif
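
// Reference semantics: hsum256_ps returns the same value as this scalar sketch,
// computed in three vector additions instead of seven scalar ones (the result
// may differ in the last bits because the addition order differs):
//
//   float hsum256_ref(const float v[8]) {
//       float s = 0.0f;
//       for (int i = 0; i < 8; ++i) s += v[i];
//       return s;
//   }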
57 
58 // Fast path for M=1: parallelize across output channels (j).
59 // This is the common decode-time shape (matrix-vector) and is otherwise single-threaded
60 // in the blocked GEMM code because M=1 provides no parallelism on the row dimension.
61 static void gemm_nt_matvec_parallel(const float *A, // [K]
62  const float *B, // [N x K] (row-major, transposed layout)
63  const float *bias, // [N] or NULL
64  float *C, // [N]
65  int N,
66  int K)
67 {
68 #pragma omp parallel for schedule(static)
69  for (int j = 0; j < N; ++j) {
70  const float *b_row = B + (size_t)j * (size_t)K;
71  float sum = bias ? bias[j] : 0.0f;
72 
73 #if defined(__AVX512F__)
74  __m512 acc = _mm512_setzero_ps();
75  int k = 0;
76  for (; k <= K - 16; k += 16) {
77  __m512 a_vec = _mm512_loadu_ps(A + k);
78  __m512 b_vec = _mm512_loadu_ps(b_row + k);
79  acc = _mm512_fmadd_ps(a_vec, b_vec, acc);
80  }
81  sum += _mm512_reduce_add_ps(acc);
82  for (; k < K; ++k) {
83  sum += A[k] * b_row[k];
84  }
85 #elif defined(__AVX__)
86  __m256 acc = _mm256_setzero_ps();
87  int k = 0;
88  for (; k <= K - 8; k += 8) {
89  __m256 a_vec = _mm256_loadu_ps(A + k);
90  __m256 b_vec = _mm256_loadu_ps(b_row + k);
91  acc = _mm256_add_ps(acc, _mm256_mul_ps(a_vec, b_vec));
92  }
93  sum += hsum256_ps(acc);
94  for (; k < K; ++k) {
95  sum += A[k] * b_row[k];
96  }
97 #else
98  for (int k = 0; k < K; ++k) {
99  sum += A[k] * b_row[k];
100  }
101 #endif
102 
103  C[j] = sum;
104  }
105 }
106 
107 static void gemm_naive_serial_double(const float *A,
108  const float *B,
109  const float *bias,
110  float *C,
111  int M, int N, int K)
112 {
113  for (int i = 0; i < M; i++) {
114  for (int j = 0; j < N; j++) {
115  double sum = bias ? (double)bias[j] : 0.0;
116  for (int k = 0; k < K; k++) {
117  sum += (double)A[i * K + k] * (double)B[j * K + k];
118  }
119  C[i * N + j] = (float)sum;
120  }
121  }
122 }
123 
124 // Naive parallel GEMM (reference baseline) – copied from C-Transformer.
125 void gemm_naive_parallel(const float *A,
126  const float *B,
127  const float *bias,
128  float *C,
129  int M, int N, int K)
130 {
131  if (ck_strict_parity_enabled()) {
132  gemm_naive_serial_double(A, B, bias, C, M, N, K);
133  return;
134  }
135 #pragma omp parallel for
136  for (int i = 0; i < M; i++) {
137  for (int j = 0; j < N; j++) {
138  float sum = 0.0f;
139  for (int k = 0; k < K; k++) {
140  sum += A[i * K + k] * B[j * K + k];
141  }
142  float bias_val = bias ? bias[j] : 0.0f;
143  C[i * N + j] = sum + bias_val;
144  }
145  }
146 }
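
// Layout reminder for the NT kernels in this file: A is [M x K] row-major and
// B is [N x K] row-major (B holds the transpose of the mathematical [K x N]
// operand), so both operands of the dot product are read contiguously:
//
//   C[i*N + j] = bias[j] + sum_k A[i*K + k] * B[j*K + k]   (bias treated as 0 when NULL)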
147 
148 // AVX-512 optimized GEMM with AVX1 fallback
149 void gemm_avx512_parallel(const float *A,
150  const float *B,
151  const float *bias,
152  float *C,
153  int M, int N, int K)
154 {
155  if (ck_strict_parity_enabled()) {
156  gemm_naive_serial_double(A, B, bias, C, M, N, K);
157  return;
158  }
159 #if defined(__AVX512F__)
160 #pragma omp parallel for
161  for (int i = 0; i < M; i++) {
162  for (int j = 0; j < N; j++) {
163  __m512 sum_vec = _mm512_setzero_ps();
164  int k;
165  for (k = 0; k <= K - 16; k += 16) {
166  __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
167  __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
168  sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
169  }
170  float sum = _mm512_reduce_add_ps(sum_vec);
171  for (; k < K; k++) {
172  sum += A[i * K + k] * B[j * K + k];
173  }
174  float bias_val = bias ? bias[j] : 0.0f;
175  C[i * N + j] = sum + bias_val;
176  }
177  }
178 #elif defined(__AVX__)
179  // AVX1 path: 256-bit vectors, no FMA (use mul + add)
180 #pragma omp parallel for
181  for (int i = 0; i < M; i++) {
182  for (int j = 0; j < N; j++) {
183  __m256 sum_vec = _mm256_setzero_ps();
184  int k;
185  for (k = 0; k <= K - 8; k += 8) {
186  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
187  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
188  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
189  sum_vec = _mm256_add_ps(sum_vec, prod);
190  }
191  float sum = hsum256_ps(sum_vec);
192  for (; k < K; k++) {
193  sum += A[i * K + k] * B[j * K + k];
194  }
195  float bias_val = bias ? bias[j] : 0.0f;
196  C[i * N + j] = sum + bias_val;
197  }
198  }
199 #else
200  gemm_naive_parallel(A, B, bias, C, M, N, K);
201 #endif
202 }
203 
204 // Cache-blocked GEMM with fine-grained parallelism and AVX1 fallback
205 void gemm_fine_grained_parallel(const float *A,
206  const float *B,
207  const float *bias,
208  float *C,
209  int M, int N, int K)
210 {
211  if (ck_strict_parity_enabled()) {
212  gemm_naive_serial_double(A, B, bias, C, M, N, K);
213  return;
214  }
215 #if defined(__AVX512F__)
216  const int block_size = 64;
217 #pragma omp parallel for
218  for (int i = 0; i < M; i++) {
219  for (int j = 0; j < N; j++) {
220  C[i * N + j] = bias ? bias[j] : 0.0f;
221  }
222  }
223 #pragma omp parallel for collapse(3)
224  for (int ii = 0; ii < M; ii += block_size) {
225  for (int jj = 0; jj < N; jj += block_size) {
226  for (int kk = 0; kk < K; kk += block_size) {
227  int i_end = ck_min(ii + block_size, M);
228  int j_end = ck_min(jj + block_size, N);
229  int k_end = ck_min(kk + block_size, K);
230 
231  for (int i = ii; i < i_end; i++) {
232  for (int j = jj; j < j_end; j++) {
233  __m512 sum_vec = _mm512_setzero_ps();
234  int k;
235  for (k = kk; k <= k_end - 16; k += 16) {
236  __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
237  __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
238  sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
239  }
240  float partial_sum = _mm512_reduce_add_ps(sum_vec);
241  for (; k < k_end; k++) {
242  partial_sum += A[i * K + k] * B[j * K + k];
243  }
244 #pragma omp atomic
245  C[i * N + j] += partial_sum;
246  }
247  }
248  }
249  }
250  }
251 #elif defined(__AVX__)
252  // AVX1 cache-blocked version
253  const int block_size = 32; // Smaller block for L1 cache
254 #pragma omp parallel for
255  for (int i = 0; i < M; i++) {
256  for (int j = 0; j < N; j++) {
257  C[i * N + j] = bias ? bias[j] : 0.0f;
258  }
259  }
260 #pragma omp parallel for collapse(3)
261  for (int ii = 0; ii < M; ii += block_size) {
262  for (int jj = 0; jj < N; jj += block_size) {
263  for (int kk = 0; kk < K; kk += block_size) {
264  int i_end = ck_min(ii + block_size, M);
265  int j_end = ck_min(jj + block_size, N);
266  int k_end = ck_min(kk + block_size, K);
267 
268  for (int i = ii; i < i_end; i++) {
269  for (int j = jj; j < j_end; j++) {
270  __m256 sum_vec = _mm256_setzero_ps();
271  int k;
272  for (k = kk; k <= k_end - 8; k += 8) {
273  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
274  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
275  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
276  sum_vec = _mm256_add_ps(sum_vec, prod);
277  }
278  float partial_sum = hsum256_ps(sum_vec);
279  for (; k < k_end; k++) {
280  partial_sum += A[i * K + k] * B[j * K + k];
281  }
282 #pragma omp atomic
283  C[i * N + j] += partial_sum;
284  }
285  }
286  }
287  }
288  }
289 #else
290  gemm_naive_parallel(A, B, bias, C, M, N, K);
291 #endif
292 }
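
// Note on gemm_fine_grained_parallel: the collapse(3) schedule distributes
// (ii, jj, kk) blocks across threads, so two threads may accumulate into the
// same C[i, j] from different kk blocks; the "#pragma omp atomic" on the final
// add is what keeps those concurrent updates correct.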
293 
294 // =============================================================================
295 // GEMM_NN: C[M,N] = A[M,K] @ B[K,N] + bias[N]
296 // B is stored row-major as [K,N] (no transpose)
297 // Used for backward d_input computation: d_input = d_output @ W
298 // =============================================================================
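
// Index form of the kernels in this section:
//
//   C[i*N + j] = bias[j] + sum_k A[i*K + k] * B[k*N + j]
//
// The B access strides by N, so the SIMD variants broadcast A[i,k] and
// vectorize across j (output columns) rather than across k.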
299 
300 static void gemm_nn_serial_double(const float *A,
301  const float *B,
302  const float *bias,
303  float *C,
304  int M, int N, int K)
305 {
306  for (int i = 0; i < M; i++) {
307  for (int j = 0; j < N; j++) {
308  double sum = bias ? (double)bias[j] : 0.0;
309  for (int k = 0; k < K; k++) {
310  sum += (double)A[i * K + k] * (double)B[k * N + j];
311  }
312  C[i * N + j] = (float)sum;
313  }
314  }
315 }
316 
317 void gemm_nn_parallel(const float *A,
318  const float *B,
319  const float *bias,
320  float *C,
321  int M, int N, int K)
322 {
323  if (ck_strict_parity_enabled()) {
324  gemm_nn_serial_double(A, B, bias, C, M, N, K);
325  return;
326  }
327 #pragma omp parallel for
328  for (int i = 0; i < M; i++) {
329  for (int j = 0; j < N; j++) {
330  float sum = bias ? bias[j] : 0.0f;
331  for (int k = 0; k < K; k++) {
332  sum += A[i * K + k] * B[k * N + j];
333  }
334  C[i * N + j] = sum;
335  }
336  }
337 }
338 
339 void gemm_nn_avx512(const float *A,
340  const float *B,
341  const float *bias,
342  float *C,
343  int M, int N, int K)
344 {
345  if (ck_strict_parity_enabled()) {
346  gemm_nn_serial_double(A, B, bias, C, M, N, K);
347  return;
348  }
349 #if defined(__AVX512F__)
350  // For gemm_nn, we can't vectorize over K easily since B[k,j] has stride N.
351  // Instead, vectorize over N (output columns) when N >= 16.
352 #pragma omp parallel for
353  for (int i = 0; i < M; i++) {
354  int j = 0;
355  // Process 16 output columns at a time
356  for (; j <= N - 16; j += 16) {
357  __m512 sum_vec = bias ? _mm512_loadu_ps(&bias[j]) : _mm512_setzero_ps();
358  for (int k = 0; k < K; k++) {
359  __m512 a_broadcast = _mm512_set1_ps(A[i * K + k]);
360  __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
361  sum_vec = _mm512_fmadd_ps(a_broadcast, b_vec, sum_vec);
362  }
363  _mm512_storeu_ps(&C[i * N + j], sum_vec);
364  }
365  // Handle remaining columns
366  for (; j < N; j++) {
367  float sum = bias ? bias[j] : 0.0f;
368  for (int k = 0; k < K; k++) {
369  sum += A[i * K + k] * B[k * N + j];
370  }
371  C[i * N + j] = sum;
372  }
373  }
374 #elif defined(__AVX__)
375  // AVX1: vectorize over N (8 columns at a time)
376 #pragma omp parallel for
377  for (int i = 0; i < M; i++) {
378  int j = 0;
379  for (; j <= N - 8; j += 8) {
380  __m256 sum_vec = bias ? _mm256_loadu_ps(&bias[j]) : _mm256_setzero_ps();
381  for (int k = 0; k < K; k++) {
382  __m256 a_broadcast = _mm256_set1_ps(A[i * K + k]);
383  __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
384  __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
385  sum_vec = _mm256_add_ps(sum_vec, prod);
386  }
387  _mm256_storeu_ps(&C[i * N + j], sum_vec);
388  }
389  for (; j < N; j++) {
390  float sum = bias ? bias[j] : 0.0f;
391  for (int k = 0; k < K; k++) {
392  sum += A[i * K + k] * B[k * N + j];
393  }
394  C[i * N + j] = sum;
395  }
396  }
397 #else
398  gemm_nn_parallel(A, B, bias, C, M, N, K);
399 #endif
400 }
401 
402 void gemm_nn_blocked(const float *A,
403  const float *B,
404  const float *bias,
405  float *C,
406  int M, int N, int K)
407 {
408  if (ck_strict_parity_enabled()) {
409  gemm_nn_serial_double(A, B, bias, C, M, N, K);
410  return;
411  }
412 #if defined(__AVX512F__)
413  const int block_size = 64;
414 #elif defined(__AVX__)
415  const int block_size = 32;
416 #else
417  const int block_size = 32;
418 #endif
419  // Initialize C with bias (parallelized)
420 #pragma omp parallel for
421  for (int i = 0; i < M; i++) {
422  for (int j = 0; j < N; j++) {
423  C[i * N + j] = bias ? bias[j] : 0.0f;
424  }
425  }
426  // Blocked multiply-accumulate (parallelized over M blocks)
427 #pragma omp parallel for
428  for (int ii = 0; ii < M; ii += block_size) {
429  for (int kk = 0; kk < K; kk += block_size) {
430  for (int jj = 0; jj < N; jj += block_size) {
431  int i_end = ck_min(ii + block_size, M);
432  int k_end = ck_min(kk + block_size, K);
433  int j_end = ck_min(jj + block_size, N);
434 
435  for (int i = ii; i < i_end; i++) {
436  for (int k = kk; k < k_end; k++) {
437  float a_val = A[i * K + k];
438 #if defined(__AVX512F__)
439  __m512 a_broadcast = _mm512_set1_ps(a_val);
440  int j;
441  for (j = jj; j <= j_end - 16; j += 16) {
442  __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
443  __m512 c_vec = _mm512_loadu_ps(&C[i * N + j]);
444  c_vec = _mm512_fmadd_ps(a_broadcast, b_vec, c_vec);
445  _mm512_storeu_ps(&C[i * N + j], c_vec);
446  }
447  for (; j < j_end; j++) {
448  C[i * N + j] += a_val * B[k * N + j];
449  }
450 #elif defined(__AVX__)
451  __m256 a_broadcast = _mm256_set1_ps(a_val);
452  int j;
453  for (j = jj; j <= j_end - 8; j += 8) {
454  __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
455  __m256 c_vec = _mm256_loadu_ps(&C[i * N + j]);
456  __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
457  c_vec = _mm256_add_ps(c_vec, prod);
458  _mm256_storeu_ps(&C[i * N + j], c_vec);
459  }
460  for (; j < j_end; j++) {
461  C[i * N + j] += a_val * B[k * N + j];
462  }
463 #else
464  for (int j = jj; j < j_end; j++) {
465  C[i * N + j] += a_val * B[k * N + j];
466  }
467 #endif
468  }
469  }
470  }
471  }
472  }
473 }
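
// Ownership note for gemm_nn_blocked (and gemm_tn_blocked below): the parallel
// loop runs over ii only, so every C[i, j] is updated by exactly one thread.
// That is why C can be pre-filled with the bias and accumulated in place
// without the "#pragma omp atomic" needed by gemm_fine_grained_parallel.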
474 
475 // =============================================================================
476 // GEMM_TN: C[M,N] = A[K,M].T @ B[K,N] + bias[N]
477 // A is stored row-major as [K,M], B is stored row-major as [K,N]
478 // Used for backward d_W computation: d_W = d_output.T @ input
479 // =============================================================================
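
// Index form of the kernels in this section (A stored as [K x M], read column-wise):
//
//   C[i*N + j] = bias[j] + sum_k A[k*M + i] * B[k*N + j]
//
// As in the NN case, the scalar A[k,i] is broadcast and the SIMD variants
// vectorize across j.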
480 
481 static void gemm_tn_serial_double(const float *A,
482  const float *B,
483  const float *bias,
484  float *C,
485  int M, int N, int K)
486 {
487  for (int i = 0; i < M; i++) {
488  for (int j = 0; j < N; j++) {
489  double sum = bias ? (double)bias[j] : 0.0;
490  for (int k = 0; k < K; k++) {
491  // A.T[i,k] = A[k,i] = A[k*M + i]
492  sum += (double)A[k * M + i] * (double)B[k * N + j];
493  }
494  C[i * N + j] = (float)sum;
495  }
496  }
497 }
498 
499 void gemm_tn_parallel(const float *A,
500  const float *B,
501  const float *bias,
502  float *C,
503  int M, int N, int K)
504 {
505  if (ck_strict_parity_enabled()) {
506  gemm_tn_serial_double(A, B, bias, C, M, N, K);
507  return;
508  }
509 #pragma omp parallel for
510  for (int i = 0; i < M; i++) {
511  for (int j = 0; j < N; j++) {
512  float sum = bias ? bias[j] : 0.0f;
513  for (int k = 0; k < K; k++) {
514  sum += A[k * M + i] * B[k * N + j];
515  }
516  C[i * N + j] = sum;
517  }
518  }
519 }
520 
521 void gemm_tn_avx512(const float *A,
522  const float *B,
523  const float *bias,
524  float *C,
525  int M, int N, int K)
526 {
527  if (ck_strict_parity_enabled()) {
528  gemm_tn_serial_double(A, B, bias, C, M, N, K);
529  return;
530  }
531 #if defined(__AVX512F__)
532  // Vectorize over N (output columns)
533 #pragma omp parallel for
534  for (int i = 0; i < M; i++) {
535  int j = 0;
536  for (; j <= N - 16; j += 16) {
537  __m512 sum_vec = bias ? _mm512_loadu_ps(&bias[j]) : _mm512_setzero_ps();
538  for (int k = 0; k < K; k++) {
539  __m512 a_broadcast = _mm512_set1_ps(A[k * M + i]);
540  __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
541  sum_vec = _mm512_fmadd_ps(a_broadcast, b_vec, sum_vec);
542  }
543  _mm512_storeu_ps(&C[i * N + j], sum_vec);
544  }
545  for (; j < N; j++) {
546  float sum = bias ? bias[j] : 0.0f;
547  for (int k = 0; k < K; k++) {
548  sum += A[k * M + i] * B[k * N + j];
549  }
550  C[i * N + j] = sum;
551  }
552  }
553 #elif defined(__AVX__)
554  // AVX1: vectorize over N (8 columns at a time)
555 #pragma omp parallel for
556  for (int i = 0; i < M; i++) {
557  int j = 0;
558  for (; j <= N - 8; j += 8) {
559  __m256 sum_vec = bias ? _mm256_loadu_ps(&bias[j]) : _mm256_setzero_ps();
560  for (int k = 0; k < K; k++) {
561  __m256 a_broadcast = _mm256_set1_ps(A[k * M + i]);
562  __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
563  __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
564  sum_vec = _mm256_add_ps(sum_vec, prod);
565  }
566  _mm256_storeu_ps(&C[i * N + j], sum_vec);
567  }
568  for (; j < N; j++) {
569  float sum = bias ? bias[j] : 0.0f;
570  for (int k = 0; k < K; k++) {
571  sum += A[k * M + i] * B[k * N + j];
572  }
573  C[i * N + j] = sum;
574  }
575  }
576 #else
577  gemm_tn_parallel(A, B, bias, C, M, N, K);
578 #endif
579 }
580 
581 void gemm_tn_blocked(const float *A,
582  const float *B,
583  const float *bias,
584  float *C,
585  int M, int N, int K)
586 {
587  if (ck_strict_parity_enabled()) {
588  gemm_tn_serial_double(A, B, bias, C, M, N, K);
589  return;
590  }
591 #if defined(__AVX512F__)
592  const int block_size = 64;
593 #elif defined(__AVX__)
594  const int block_size = 32;
595 #else
596  const int block_size = 32;
597 #endif
598  // Initialize C with bias (parallelized)
599 #pragma omp parallel for
600  for (int i = 0; i < M; i++) {
601  for (int j = 0; j < N; j++) {
602  C[i * N + j] = bias ? bias[j] : 0.0f;
603  }
604  }
605  // Blocked multiply-accumulate (parallelized over M blocks)
606 #pragma omp parallel for
607  for (int ii = 0; ii < M; ii += block_size) {
608  for (int kk = 0; kk < K; kk += block_size) {
609  for (int jj = 0; jj < N; jj += block_size) {
610  int i_end = ck_min(ii + block_size, M);
611  int k_end = ck_min(kk + block_size, K);
612  int j_end = ck_min(jj + block_size, N);
613 
614  for (int k = kk; k < k_end; k++) {
615  for (int i = ii; i < i_end; i++) {
616  float a_val = A[k * M + i];
617 #if defined(__AVX512F__)
618  __m512 a_broadcast = _mm512_set1_ps(a_val);
619  int j;
620  for (j = jj; j <= j_end - 16; j += 16) {
621  __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
622  __m512 c_vec = _mm512_loadu_ps(&C[i * N + j]);
623  c_vec = _mm512_fmadd_ps(a_broadcast, b_vec, c_vec);
624  _mm512_storeu_ps(&C[i * N + j], c_vec);
625  }
626  for (; j < j_end; j++) {
627  C[i * N + j] += a_val * B[k * N + j];
628  }
629 #elif defined(__AVX__)
630  __m256 a_broadcast = _mm256_set1_ps(a_val);
631  int j;
632  for (j = jj; j <= j_end - 8; j += 8) {
633  __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
634  __m256 c_vec = _mm256_loadu_ps(&C[i * N + j]);
635  __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
636  c_vec = _mm256_add_ps(c_vec, prod);
637  _mm256_storeu_ps(&C[i * N + j], c_vec);
638  }
639  for (; j < j_end; j++) {
640  C[i * N + j] += a_val * B[k * N + j];
641  }
642 #else
643  for (int j = jj; j < j_end; j++) {
644  C[i * N + j] += a_val * B[k * N + j];
645  }
646 #endif
647  }
648  }
649  }
650  }
651  }
652 }
653 
654 // =============================================================================
655 // Original GEMM_NT: C[M,N] = A[M,K] @ B[N,K].T + bias[N]
656 // B is stored row-major as [N,K] (transposed in the multiply)
657 // =============================================================================
658 
659 // Serial cache-blocked GEMM with SIMD (AVX/AVX512).
660 // Note: B is stored as [N x K] (transposed layout).
661 void gemm_blocked_serial(const float *A,
662  const float *B,
663  const float *bias,
664  float *C,
665  int M, int N, int K)
666 {
667  // Ensure threads are initialized (auto-detects on first call)
668  (void)ck_get_num_threads();
669 
670  if (ck_strict_parity_enabled()) {
671  gemm_naive_serial_double(A, B, bias, C, M, N, K);
672  return;
673  }
674 
675  // Decode-time matvec (M=1) is extremely common and benefits from parallelism over N.
676  // Lower threshold to parallelize more ops; OpenMP overhead is ~1-2μs per barrier.
677  // For N*K >= 64K elements, parallel is worthwhile.
678  if (M == 1 && (size_t)N * (size_t)K >= 65536) {
679  gemm_nt_matvec_parallel(A, B, bias, C, N, K);
680  return;
681  }
682 
683  /*
684  * Use gemm_microkernel for large matrices - it uses MKL/oneDNN when available,
685  * which is substantially faster than our hand-written SIMD kernels.
686  * B is stored as [N x K] (transposed), so we pass B_transposed=1.
687  * Note: Use threshold of 32 to avoid numerical precision issues with small matrices.
688  */
689  if (M >= 32 && N >= 32 && K >= 32) {
690  gemm_microkernel(A, B, C, M, N, K, 1); // B_transposed=1
691  ck_gemm_add_bias(C, bias, M, N);
692  return;
693  }
694 #if defined(__AVX512F__)
695  const int block_size = 64;
696 #elif defined(__AVX__)
697  const int block_size = 32;
698 #else
699  const int block_size = 32;
700 #endif
701  for (int i = 0; i < M; i++) {
702  for (int j = 0; j < N; j++) {
703  C[i * N + j] = bias ? bias[j] : 0.0f;
704  }
705  }
706  for (int ii = 0; ii < M; ii += block_size) {
707  for (int jj = 0; jj < N; jj += block_size) {
708  for (int kk = 0; kk < K; kk += block_size) {
709  int i_end = ck_min(ii + block_size, M);
710  int j_end = ck_min(jj + block_size, N);
711  int k_end = ck_min(kk + block_size, K);
712 
713  for (int i = ii; i < i_end; i++) {
714  for (int j = jj; j < j_end; j++) {
715 #if defined(__AVX512F__)
716  __m512 sum_vec = _mm512_setzero_ps();
717  int k;
718  for (k = kk; k <= k_end - 16; k += 16) {
719  __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
720  __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
721  sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
722  }
723  float partial_sum = _mm512_reduce_add_ps(sum_vec);
724  for (; k < k_end; k++) {
725  partial_sum += A[i * K + k] * B[j * K + k];
726  }
727 #elif defined(__AVX__)
728  __m256 sum_vec = _mm256_setzero_ps();
729  int k;
730  for (k = kk; k <= k_end - 8; k += 8) {
731  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
732  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
733  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
734  sum_vec = _mm256_add_ps(sum_vec, prod);
735  }
736  float partial_sum = hsum256_ps(sum_vec);
737  for (; k < k_end; k++) {
738  partial_sum += A[i * K + k] * B[j * K + k];
739  }
740 #else
741  float partial_sum = 0.0f;
742  for (int k = kk; k < k_end; k++) {
743  partial_sum += A[i * K + k] * B[j * K + k];
744  }
745 #endif
746  C[i * N + j] += partial_sum;
747  }
748  }
749  }
750  }
751  }
752 }
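
// Dispatch order inside gemm_blocked_serial, as implemented above:
//   1. strict-parity mode          -> gemm_naive_serial_double (double accumulation)
//   2. M == 1 and N*K >= 65536     -> gemm_nt_matvec_parallel (parallel over N)
//   3. M, N, K all >= 32           -> gemm_microkernel, then ck_gemm_add_bias
//   4. otherwise                   -> serial cache-blocked SIMD loop above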