gemm_microkernel.c
1 /**
2  * @file gemm_microkernel.c
3  * @brief GEMM Microkernel - High-Performance Register-Blocked Matrix Multiplication
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * This file implements optimized GEMM microkernels with multiple backends:
15  *
16  * 1. USE_MKL: Intel MKL cblas_sgemm (best performance on Intel CPUs)
17  * 2. USE_ONEDNN: Intel oneDNN matmul primitive (Apache 2.0 licensed)
18  * 3. Native: Our own AVX-512/AVX2/AVX microkernels (no dependencies)
19  *
20  * Build with:
21  * make USE_MKL=1 # Use Intel MKL
22  * make USE_ONEDNN=1 # Use Intel oneDNN
23  * make # Use native kernels
24  *
25  * Layout: C[M,N] = A[M,K] @ B[K,N] (row-major)
26  */
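
/*
 * Usage sketch (illustrative; assumes the prototype exported by
 * ckernel_engine.h matches the definition in this file):
 *
 *   float A[2*3] = {1, 2, 3,  4, 5, 6};      // A is [M=2, K=3], row-major
 *   float B[3*2] = {7, 8,  9, 10,  11, 12};  // B is [K=3, N=2], row-major
 *   float C[2*2];                            // C is [M=2, N=2], overwritten
 *   gemm_microkernel(A, B, C, 2, 2, 3, 0);   // B_transposed = 0
 *   // C == {58, 64, 139, 154}
 */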
27 
28 #include "ckernel_engine.h"
29 #include "cpu_features.h"
30 #include <string.h>
31 #include <stdlib.h>
32 #include <stdio.h>
33 
34 // =============================================================================
35 // Backend Selection: MKL > oneDNN > Native
36 // =============================================================================
37 
38 #if defined(USE_MKL)
39  #include <mkl.h>
40  #define GEMM_BACKEND "MKL"
41 #elif defined(USE_ONEDNN)
42  #include <dnnl.h>
43  #define GEMM_BACKEND "oneDNN"
44 #else
45  #define GEMM_BACKEND "Native"
46 #endif
47 
48 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
49 #include <immintrin.h>
50 #endif
51 
52 #ifdef _OPENMP
53 #include <omp.h>
54 #endif
55 
56 // =============================================================================
57 // MKL Backend Implementation
58 // =============================================================================
59 
60 #if defined(USE_MKL)
61 
62 void gemm_microkernel(
63  const float *A,
64  const float *B,
65  float *C,
66  int M, int N, int K,
67  int B_transposed
68 )
69 {
70  // MKL uses column-major by default, but CblasRowMajor handles row-major
71  // C = alpha * A @ B + beta * C
72  // For B_transposed: C = A @ B^T
73  cblas_sgemm(
74  CblasRowMajor,
75  CblasNoTrans,
76  B_transposed ? CblasTrans : CblasNoTrans,
77  M, N, K,
78  1.0f, // alpha
79  A, K, // lda (A is always [M,K] row-major)
80  B, B_transposed ? K : N, // ldb
81  0.0f, // beta
82  C, N // ldc
83  );
84 }
85 
86 // Thin wrappers for the blocked/packed entry points (MKL handles blocking internally)
87 void gemm_microkernel_blocked(const float *A, const float *B, float *C, int M, int N, int K) {
88  gemm_microkernel(A, B, C, M, N, K, 0);
89 }
90 void gemm_microkernel_packed(const float *A, const float *B, float *C, int M, int N, int K) {
91  gemm_microkernel(A, B, C, M, N, K, 0);
92 }
93 void gemm_microkernel_blocked_bt(const float *A, const float *B, float *C, int M, int N, int K) {
94  gemm_microkernel(A, B, C, M, N, K, 1);
95 }
96 
97 #elif defined(USE_ONEDNN)
98 
99 // =============================================================================
100 // oneDNN Backend Implementation
101 // =============================================================================
102 
103 // Global oneDNN engine and stream (created lazily on first use; no locking, so the first call is assumed single-threaded)
104 static dnnl_engine_t g_engine = NULL;
105 static dnnl_stream_t g_stream = NULL;
106 
107 static void onednn_init(void) {
108  if (g_engine) return;
109  dnnl_engine_create(&g_engine, dnnl_cpu, 0);
110  dnnl_stream_create(&g_stream, g_engine, dnnl_stream_default_flags);
111 }
112 
113 void gemm_microkernel(
114  const float *A,
115  const float *B,
116  float *C,
117  int M, int N, int K,
118  int B_transposed
119 )
120 {
121  onednn_init();
122 
123  // Create memory descriptors for row-major matrices
124  dnnl_memory_desc_t a_md, b_md, c_md;
125  dnnl_dims_t a_dims = {M, K};
126  dnnl_dims_t b_dims = {K, N};
127  dnnl_dims_t c_dims = {M, N};
128  dnnl_dims_t a_strides = {K, 1};
129  dnnl_dims_t b_strides = {N, 1}; /* default: B is [K,N] row-major */
130  dnnl_dims_t c_strides = {N, 1};
131 
132  if (B_transposed) {
133  /*
134  * Our "B_transposed" convention means: caller stores B as [N,K] row-major,
135  * but wants C = A[M,K] @ B[N,K]^T => treat weights as [K,N].
136  *
137  * oneDNN matmul has no transpose flag, so represent B^T as a strided view:
138  * dims = [K, N]
139  * strides = [1, K] (offset(k,n) = k + n*K == B[n*K + k])
140  */
141  b_strides[0] = 1;
142  b_strides[1] = K;
143  }
144 
145  dnnl_memory_desc_create_with_strides(&a_md, 2, a_dims, dnnl_f32, a_strides);
146  dnnl_memory_desc_create_with_strides(&b_md, 2, b_dims, dnnl_f32, b_strides);
147  dnnl_memory_desc_create_with_strides(&c_md, 2, c_dims, dnnl_f32, c_strides);
148 
149  // Create matmul primitive descriptor
150  dnnl_primitive_desc_t matmul_pd;
151  dnnl_matmul_primitive_desc_create(&matmul_pd, g_engine, a_md, b_md, NULL, c_md, NULL);
152 
153  // Create primitive
154  dnnl_primitive_t matmul;
155  dnnl_primitive_create(&matmul, matmul_pd);
156 
157  // Create memory objects
158  dnnl_memory_t a_mem, b_mem, c_mem;
159  dnnl_memory_create(&a_mem, a_md, g_engine, (void*)A);
160  dnnl_memory_create(&b_mem, b_md, g_engine, (void*)B);
161  dnnl_memory_create(&c_mem, c_md, g_engine, (void*)C);
162 
163  // Execute
164  dnnl_exec_arg_t args[3] = {
165  {DNNL_ARG_SRC, a_mem},
166  {DNNL_ARG_WEIGHTS, b_mem},
167  {DNNL_ARG_DST, c_mem}
168  };
169  dnnl_primitive_execute(matmul, g_stream, 3, args);
170  dnnl_stream_wait(g_stream);
171 
172  // Cleanup
173  dnnl_primitive_destroy(matmul);
174  dnnl_primitive_desc_destroy(matmul_pd);
175  dnnl_memory_destroy(a_mem);
176  dnnl_memory_destroy(b_mem);
177  dnnl_memory_destroy(c_mem);
178  dnnl_memory_desc_destroy(a_md);
179  dnnl_memory_desc_destroy(b_md);
180  dnnl_memory_desc_destroy(c_md);
181 }
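
/*
 * Parity-check sketch (illustrative): a scalar reference encoding the same
 * B_transposed convention as the backend above, handy for validating the
 * strided-view trick (offset(k,n) = k + n*K must address B[n*K + k]):
 *
 *   static void gemm_ref(const float *A, const float *B, float *C,
 *                        int M, int N, int K, int B_transposed) {
 *       for (int i = 0; i < M; i++)
 *           for (int j = 0; j < N; j++) {
 *               float sum = 0.0f;
 *               for (int k = 0; k < K; k++)
 *                   sum += A[i*K + k] *
 *                          (B_transposed ? B[j*K + k] : B[k*N + j]);
 *               C[i*N + j] = sum;
 *           }
 *   }
 */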
182 
183 void gemm_microkernel_blocked(const float *A, const float *B, float *C, int M, int N, int K) {
184  gemm_microkernel(A, B, C, M, N, K, 0);
185 }
186 void gemm_microkernel_packed(const float *A, const float *B, float *C, int M, int N, int K) {
187  gemm_microkernel(A, B, C, M, N, K, 0);
188 }
189 void gemm_microkernel_blocked_bt(const float *A, const float *B, float *C, int M, int N, int K) {
190  gemm_microkernel(A, B, C, M, N, K, 1);
191 }
192 
193 #else
194 // =============================================================================
195 // Native Backend (our own AVX-512/AVX2/AVX kernels)
196 // =============================================================================
197 
198 // =============================================================================
199 // Microkernel Configuration
200 //
201 // MR/NR are fixed at compile time (microkernel register usage)
202 // MC/NC/KC are determined at runtime based on detected cache sizes
203 // =============================================================================
204 
205 #if defined(__AVX512F__)
206  // AVX-512: 6x32 microkernel (6 rows, 32 cols = 2 ZMM per row = 12 ZMM accumulators)
207  // 32 ZMM registers available - no spilling
208  #define MR_FIXED 6
209  #define NR_FIXED 32
210 #elif defined(__FMA__)
211  // AVX2+FMA: 6x16 microkernel - FMA hides latency, some spilling acceptable
212  #define MR_FIXED 6
213  #define NR_FIXED 16
214 #elif defined(__AVX__)
215  // AVX (no FMA): 4x16 microkernel to avoid register spilling
216  // Only 16 YMM registers: 8 accum + 2 B + 4 A + 2 temp = 16 (fits!)
217  #define MR_FIXED 4
218  #define NR_FIXED 16
219 #else
220  // Scalar fallback
221  #define MR_FIXED 4
222  #define NR_FIXED 4
223 #endif
224 
225 // MR/NR are compile-time constants; MC/NC/KC read runtime-detected values (initialized once at startup)
226 #define MR (MR_FIXED)
227 #define NR (NR_FIXED)
228 #define MC (get_gemm_params()->MC)
229 #define NC (get_gemm_params()->NC)
230 #define KC (get_gemm_params()->KC)
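
/*
 * Illustrative sketch only - the real values come from get_gemm_params()
 * (not shown in this file). A typical BLIS-style heuristic sizes each
 * packed panel to about half of one cache level; the cache-size variables
 * below are hypothetical:
 *
 *   kc = l1_bytes / (2 * NR * sizeof(float));  // B micro-panel [kc, NR] in L1
 *   mc = l2_bytes / (2 * kc * sizeof(float));  // A panel [mc, kc] in L2
 *   nc = l3_bytes / (2 * kc * sizeof(float));  // B block [kc, nc] in L3
 *   mc -= mc % MR;  nc -= nc % NR;             // round down to whole tiles
 */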
231 
232 // =============================================================================
233 // AVX-512 6x32 Microkernel - oneDNN style
234 //
235 // Computes: C[0:6, 0:32] += A[0:6, 0:K] @ B[0:K, 0:32] (plain = when first_k is set)
236 //
237 // Register usage (32 ZMM registers available):
238 // - c0_lo, c0_hi, c1_lo, c1_hi, ... c5_hi: 12 accumulators (2 per row)
239 // - b_lo, b_hi: 2 registers for B row
240 // - a0-a5: 6 registers for A broadcasts
241 // - Remaining for prefetch and temp
242 // =============================================================================
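
/*
 * Back-of-envelope throughput (a rough estimate, not a measurement): each k
 * iteration issues 12 FMAs on 16-lane vectors = 12 * 16 * 2 = 384 FLOPs,
 * against 2 vector loads of B and 6 broadcasts of A (1.5 FMAs per load),
 * enough to keep two 512-bit FMA ports busy on typical server cores.
 */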
243 
244 #if defined(__AVX512F__)
245 static inline void gemm_microkernel_6x32_avx512(
246  int K,
247  const float * __restrict__ A, int lda,
248  const float * __restrict__ B, int ldb,
249  float * __restrict__ C, int ldc,
250  int first_k
251 )
252 {
253  // 12 accumulators: 6 rows x 2 ZMM (32 floats) per row
254  __m512 c0_lo, c0_hi, c1_lo, c1_hi, c2_lo, c2_hi;
255  __m512 c3_lo, c3_hi, c4_lo, c4_hi, c5_lo, c5_hi;
256 
257  if (first_k) {
258  c0_lo = _mm512_setzero_ps(); c0_hi = _mm512_setzero_ps();
259  c1_lo = _mm512_setzero_ps(); c1_hi = _mm512_setzero_ps();
260  c2_lo = _mm512_setzero_ps(); c2_hi = _mm512_setzero_ps();
261  c3_lo = _mm512_setzero_ps(); c3_hi = _mm512_setzero_ps();
262  c4_lo = _mm512_setzero_ps(); c4_hi = _mm512_setzero_ps();
263  c5_lo = _mm512_setzero_ps(); c5_hi = _mm512_setzero_ps();
264  } else {
265  c0_lo = _mm512_loadu_ps(&C[0 * ldc]); c0_hi = _mm512_loadu_ps(&C[0 * ldc + 16]);
266  c1_lo = _mm512_loadu_ps(&C[1 * ldc]); c1_hi = _mm512_loadu_ps(&C[1 * ldc + 16]);
267  c2_lo = _mm512_loadu_ps(&C[2 * ldc]); c2_hi = _mm512_loadu_ps(&C[2 * ldc + 16]);
268  c3_lo = _mm512_loadu_ps(&C[3 * ldc]); c3_hi = _mm512_loadu_ps(&C[3 * ldc + 16]);
269  c4_lo = _mm512_loadu_ps(&C[4 * ldc]); c4_hi = _mm512_loadu_ps(&C[4 * ldc + 16]);
270  c5_lo = _mm512_loadu_ps(&C[5 * ldc]); c5_hi = _mm512_loadu_ps(&C[5 * ldc + 16]);
271  }
272 
273  // Prefetch first cache lines
274  _mm_prefetch((const char*)&B[0], _MM_HINT_T0);
275  _mm_prefetch((const char*)&B[64], _MM_HINT_T0);
276 
277  // Main K loop - unrolled by 8 with software pipelining for better ILP
278  int k = 0;
279  for (; k <= K - 8; k += 8) {
280  // Prefetch ahead - 16 rows ahead for L1, 32 for L2
281  _mm_prefetch((const char*)&B[(k + 16) * ldb], _MM_HINT_T0);
282  _mm_prefetch((const char*)&B[(k + 16) * ldb + 64], _MM_HINT_T0);
283  _mm_prefetch((const char*)&B[(k + 32) * ldb], _MM_HINT_T1);
284 
285  // Software pipelining: preload first B row
286  __m512 b_lo_next = _mm512_loadu_ps(&B[k * ldb]);
287  __m512 b_hi_next = _mm512_loadu_ps(&B[k * ldb + 16]);
288 
289  #define AVX512_ITER(koff) { \
290  __m512 b_lo = b_lo_next; \
291  __m512 b_hi = b_hi_next; \
292  if ((koff) < 7) { \
293  b_lo_next = _mm512_loadu_ps(&B[(k + (koff) + 1) * ldb]); \
294  b_hi_next = _mm512_loadu_ps(&B[(k + (koff) + 1) * ldb + 16]); \
295  } \
296  __m512 a0 = _mm512_set1_ps(A[0 * lda + k + (koff)]); \
297  __m512 a1 = _mm512_set1_ps(A[1 * lda + k + (koff)]); \
298  __m512 a2 = _mm512_set1_ps(A[2 * lda + k + (koff)]); \
299  __m512 a3 = _mm512_set1_ps(A[3 * lda + k + (koff)]); \
300  __m512 a4 = _mm512_set1_ps(A[4 * lda + k + (koff)]); \
301  __m512 a5 = _mm512_set1_ps(A[5 * lda + k + (koff)]); \
302  c0_lo = _mm512_fmadd_ps(a0, b_lo, c0_lo); c0_hi = _mm512_fmadd_ps(a0, b_hi, c0_hi); \
303  c1_lo = _mm512_fmadd_ps(a1, b_lo, c1_lo); c1_hi = _mm512_fmadd_ps(a1, b_hi, c1_hi); \
304  c2_lo = _mm512_fmadd_ps(a2, b_lo, c2_lo); c2_hi = _mm512_fmadd_ps(a2, b_hi, c2_hi); \
305  c3_lo = _mm512_fmadd_ps(a3, b_lo, c3_lo); c3_hi = _mm512_fmadd_ps(a3, b_hi, c3_hi); \
306  c4_lo = _mm512_fmadd_ps(a4, b_lo, c4_lo); c4_hi = _mm512_fmadd_ps(a4, b_hi, c4_hi); \
307  c5_lo = _mm512_fmadd_ps(a5, b_lo, c5_lo); c5_hi = _mm512_fmadd_ps(a5, b_hi, c5_hi); \
308  }
309 
310  AVX512_ITER(0);
311  AVX512_ITER(1);
312  AVX512_ITER(2);
313  AVX512_ITER(3);
314  AVX512_ITER(4);
315  AVX512_ITER(5);
316  AVX512_ITER(6);
317  AVX512_ITER(7);
318 
319  #undef AVX512_ITER
320  }
321 
322  // Handle K % 8 remainder with 4-unroll
323  for (; k <= K - 4; k += 4) {
324  #define AVX512_ITER4(koff) { \
325  __m512 b_lo = _mm512_loadu_ps(&B[(k + koff) * ldb]); \
326  __m512 b_hi = _mm512_loadu_ps(&B[(k + koff) * ldb + 16]); \
327  __m512 a0 = _mm512_set1_ps(A[0 * lda + k + koff]); \
328  __m512 a1 = _mm512_set1_ps(A[1 * lda + k + koff]); \
329  __m512 a2 = _mm512_set1_ps(A[2 * lda + k + koff]); \
330  __m512 a3 = _mm512_set1_ps(A[3 * lda + k + koff]); \
331  __m512 a4 = _mm512_set1_ps(A[4 * lda + k + koff]); \
332  __m512 a5 = _mm512_set1_ps(A[5 * lda + k + koff]); \
333  c0_lo = _mm512_fmadd_ps(a0, b_lo, c0_lo); c0_hi = _mm512_fmadd_ps(a0, b_hi, c0_hi); \
334  c1_lo = _mm512_fmadd_ps(a1, b_lo, c1_lo); c1_hi = _mm512_fmadd_ps(a1, b_hi, c1_hi); \
335  c2_lo = _mm512_fmadd_ps(a2, b_lo, c2_lo); c2_hi = _mm512_fmadd_ps(a2, b_hi, c2_hi); \
336  c3_lo = _mm512_fmadd_ps(a3, b_lo, c3_lo); c3_hi = _mm512_fmadd_ps(a3, b_hi, c3_hi); \
337  c4_lo = _mm512_fmadd_ps(a4, b_lo, c4_lo); c4_hi = _mm512_fmadd_ps(a4, b_hi, c4_hi); \
338  c5_lo = _mm512_fmadd_ps(a5, b_lo, c5_lo); c5_hi = _mm512_fmadd_ps(a5, b_hi, c5_hi); \
339  }
340  AVX512_ITER4(0);
341  AVX512_ITER4(1);
342  AVX512_ITER4(2);
343  AVX512_ITER4(3);
344  #undef AVX512_ITER4
345  }
346 
347  // Handle remaining K
348  for (; k < K; k++) {
349  __m512 b_lo = _mm512_loadu_ps(&B[k * ldb]);
350  __m512 b_hi = _mm512_loadu_ps(&B[k * ldb + 16]);
351 
352  c0_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[0 * lda + k]), b_lo, c0_lo);
353  c0_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[0 * lda + k]), b_hi, c0_hi);
354  c1_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[1 * lda + k]), b_lo, c1_lo);
355  c1_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[1 * lda + k]), b_hi, c1_hi);
356  c2_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[2 * lda + k]), b_lo, c2_lo);
357  c2_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[2 * lda + k]), b_hi, c2_hi);
358  c3_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[3 * lda + k]), b_lo, c3_lo);
359  c3_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[3 * lda + k]), b_hi, c3_hi);
360  c4_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[4 * lda + k]), b_lo, c4_lo);
361  c4_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[4 * lda + k]), b_hi, c4_hi);
362  c5_lo = _mm512_fmadd_ps(_mm512_set1_ps(A[5 * lda + k]), b_lo, c5_lo);
363  c5_hi = _mm512_fmadd_ps(_mm512_set1_ps(A[5 * lda + k]), b_hi, c5_hi);
364  }
365 
366  // Store results
367  _mm512_storeu_ps(&C[0 * ldc], c0_lo); _mm512_storeu_ps(&C[0 * ldc + 16], c0_hi);
368  _mm512_storeu_ps(&C[1 * ldc], c1_lo); _mm512_storeu_ps(&C[1 * ldc + 16], c1_hi);
369  _mm512_storeu_ps(&C[2 * ldc], c2_lo); _mm512_storeu_ps(&C[2 * ldc + 16], c2_hi);
370  _mm512_storeu_ps(&C[3 * ldc], c3_lo); _mm512_storeu_ps(&C[3 * ldc + 16], c3_hi);
371  _mm512_storeu_ps(&C[4 * ldc], c4_lo); _mm512_storeu_ps(&C[4 * ldc + 16], c4_hi);
372  _mm512_storeu_ps(&C[5 * ldc], c5_lo); _mm512_storeu_ps(&C[5 * ldc + 16], c5_hi);
373 }
374 
375 // Packed version for large matrices
376 static inline void gemm_microkernel_6x32_packed_avx512(
377  int K,
378  const float * __restrict__ Ap, // Packed A: [MR, K] contiguous
379  const float * __restrict__ Bp, // Packed B: [K, NR] contiguous
380  float * __restrict__ C, int ldc,
381  int first_k
382 )
383 {
384  __m512 c0_lo, c0_hi, c1_lo, c1_hi, c2_lo, c2_hi;
385  __m512 c3_lo, c3_hi, c4_lo, c4_hi, c5_lo, c5_hi;
386 
387  if (first_k) {
388  c0_lo = _mm512_setzero_ps(); c0_hi = _mm512_setzero_ps();
389  c1_lo = _mm512_setzero_ps(); c1_hi = _mm512_setzero_ps();
390  c2_lo = _mm512_setzero_ps(); c2_hi = _mm512_setzero_ps();
391  c3_lo = _mm512_setzero_ps(); c3_hi = _mm512_setzero_ps();
392  c4_lo = _mm512_setzero_ps(); c4_hi = _mm512_setzero_ps();
393  c5_lo = _mm512_setzero_ps(); c5_hi = _mm512_setzero_ps();
394  } else {
395  c0_lo = _mm512_loadu_ps(&C[0 * ldc]); c0_hi = _mm512_loadu_ps(&C[0 * ldc + 16]);
396  c1_lo = _mm512_loadu_ps(&C[1 * ldc]); c1_hi = _mm512_loadu_ps(&C[1 * ldc + 16]);
397  c2_lo = _mm512_loadu_ps(&C[2 * ldc]); c2_hi = _mm512_loadu_ps(&C[2 * ldc + 16]);
398  c3_lo = _mm512_loadu_ps(&C[3 * ldc]); c3_hi = _mm512_loadu_ps(&C[3 * ldc + 16]);
399  c4_lo = _mm512_loadu_ps(&C[4 * ldc]); c4_hi = _mm512_loadu_ps(&C[4 * ldc + 16]);
400  c5_lo = _mm512_loadu_ps(&C[5 * ldc]); c5_hi = _mm512_loadu_ps(&C[5 * ldc + 16]);
401  }
402 
403  // Packed B is contiguous: B[k, 0:32] at Bp[k * 32]
404  _mm_prefetch((const char*)Bp, _MM_HINT_T0);
405  _mm_prefetch((const char*)(Bp + 16), _MM_HINT_T0);
406 
407  int k = 0;
408  for (; k <= K - 4; k += 4) {
409  _mm_prefetch((const char*)(Bp + (k + 8) * NR), _MM_HINT_T0);
410  _mm_prefetch((const char*)(Bp + (k + 8) * NR + 16), _MM_HINT_T0);
411 
412  #define PACKED_ITER(koff) { \
413  __m512 b_lo = _mm512_load_ps(&Bp[(k + koff) * NR]); \
414  __m512 b_hi = _mm512_load_ps(&Bp[(k + koff) * NR + 16]); \
415  __m512 a0 = _mm512_set1_ps(Ap[0 * K + k + koff]); \
416  __m512 a1 = _mm512_set1_ps(Ap[1 * K + k + koff]); \
417  __m512 a2 = _mm512_set1_ps(Ap[2 * K + k + koff]); \
418  __m512 a3 = _mm512_set1_ps(Ap[3 * K + k + koff]); \
419  __m512 a4 = _mm512_set1_ps(Ap[4 * K + k + koff]); \
420  __m512 a5 = _mm512_set1_ps(Ap[5 * K + k + koff]); \
421  c0_lo = _mm512_fmadd_ps(a0, b_lo, c0_lo); c0_hi = _mm512_fmadd_ps(a0, b_hi, c0_hi); \
422  c1_lo = _mm512_fmadd_ps(a1, b_lo, c1_lo); c1_hi = _mm512_fmadd_ps(a1, b_hi, c1_hi); \
423  c2_lo = _mm512_fmadd_ps(a2, b_lo, c2_lo); c2_hi = _mm512_fmadd_ps(a2, b_hi, c2_hi); \
424  c3_lo = _mm512_fmadd_ps(a3, b_lo, c3_lo); c3_hi = _mm512_fmadd_ps(a3, b_hi, c3_hi); \
425  c4_lo = _mm512_fmadd_ps(a4, b_lo, c4_lo); c4_hi = _mm512_fmadd_ps(a4, b_hi, c4_hi); \
426  c5_lo = _mm512_fmadd_ps(a5, b_lo, c5_lo); c5_hi = _mm512_fmadd_ps(a5, b_hi, c5_hi); \
427  }
428 
429  PACKED_ITER(0);
430  PACKED_ITER(1);
431  PACKED_ITER(2);
432  PACKED_ITER(3);
433 
434  #undef PACKED_ITER
435  }
436 
437  for (; k < K; k++) {
438  __m512 b_lo = _mm512_load_ps(&Bp[k * NR]);
439  __m512 b_hi = _mm512_load_ps(&Bp[k * NR + 16]);
440 
441  c0_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[0 * K + k]), b_lo, c0_lo);
442  c0_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[0 * K + k]), b_hi, c0_hi);
443  c1_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[1 * K + k]), b_lo, c1_lo);
444  c1_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[1 * K + k]), b_hi, c1_hi);
445  c2_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[2 * K + k]), b_lo, c2_lo);
446  c2_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[2 * K + k]), b_hi, c2_hi);
447  c3_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[3 * K + k]), b_lo, c3_lo);
448  c3_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[3 * K + k]), b_hi, c3_hi);
449  c4_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[4 * K + k]), b_lo, c4_lo);
450  c4_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[4 * K + k]), b_hi, c4_hi);
451  c5_lo = _mm512_fmadd_ps(_mm512_set1_ps(Ap[5 * K + k]), b_lo, c5_lo);
452  c5_hi = _mm512_fmadd_ps(_mm512_set1_ps(Ap[5 * K + k]), b_hi, c5_hi);
453  }
454 
455  _mm512_storeu_ps(&C[0 * ldc], c0_lo); _mm512_storeu_ps(&C[0 * ldc + 16], c0_hi);
456  _mm512_storeu_ps(&C[1 * ldc], c1_lo); _mm512_storeu_ps(&C[1 * ldc + 16], c1_hi);
457  _mm512_storeu_ps(&C[2 * ldc], c2_lo); _mm512_storeu_ps(&C[2 * ldc + 16], c2_hi);
458  _mm512_storeu_ps(&C[3 * ldc], c3_lo); _mm512_storeu_ps(&C[3 * ldc + 16], c3_hi);
459  _mm512_storeu_ps(&C[4 * ldc], c4_lo); _mm512_storeu_ps(&C[4 * ldc + 16], c4_hi);
460  _mm512_storeu_ps(&C[5 * ldc], c5_lo); _mm512_storeu_ps(&C[5 * ldc + 16], c5_hi);
461 }
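
/*
 * Note: the packed variant can use aligned _mm512_load_ps on Bp because the
 * pack routines below emit NR-wide (two cache line) rows; this assumes the
 * packed buffer handed in by the bump allocator is 64-byte aligned.
 */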
462 #endif // __AVX512F__
463 
464 // =============================================================================
465 // AVX/AVX2 6x16 Microkernel with FMA
466 //
467 // KEY FIX: Use _mm256_fmadd_ps instead of separate mul+add
468 // FMA fuses multiply-add into single instruction: c = a*b + c
469 // This gives ~2x throughput improvement on FMA-capable CPUs
470 // =============================================================================
471 
472 #if defined(__AVX__)
473 static inline void gemm_microkernel_6x16_avx(
474  int K,
475  const float * __restrict__ A, int lda,
476  const float * __restrict__ B, int ldb,
477  float * __restrict__ C, int ldc,
478  int first_k
479 )
480 {
481  // 12 accumulators: 6 rows x 2 YMM (16 floats) per row
482  __m256 c0_lo, c0_hi, c1_lo, c1_hi, c2_lo, c2_hi;
483  __m256 c3_lo, c3_hi, c4_lo, c4_hi, c5_lo, c5_hi;
484 
485  if (first_k) {
486  c0_lo = _mm256_setzero_ps(); c0_hi = _mm256_setzero_ps();
487  c1_lo = _mm256_setzero_ps(); c1_hi = _mm256_setzero_ps();
488  c2_lo = _mm256_setzero_ps(); c2_hi = _mm256_setzero_ps();
489  c3_lo = _mm256_setzero_ps(); c3_hi = _mm256_setzero_ps();
490  c4_lo = _mm256_setzero_ps(); c4_hi = _mm256_setzero_ps();
491  c5_lo = _mm256_setzero_ps(); c5_hi = _mm256_setzero_ps();
492  } else {
493  c0_lo = _mm256_loadu_ps(&C[0 * ldc]); c0_hi = _mm256_loadu_ps(&C[0 * ldc + 8]);
494  c1_lo = _mm256_loadu_ps(&C[1 * ldc]); c1_hi = _mm256_loadu_ps(&C[1 * ldc + 8]);
495  c2_lo = _mm256_loadu_ps(&C[2 * ldc]); c2_hi = _mm256_loadu_ps(&C[2 * ldc + 8]);
496  c3_lo = _mm256_loadu_ps(&C[3 * ldc]); c3_hi = _mm256_loadu_ps(&C[3 * ldc + 8]);
497  c4_lo = _mm256_loadu_ps(&C[4 * ldc]); c4_hi = _mm256_loadu_ps(&C[4 * ldc + 8]);
498  c5_lo = _mm256_loadu_ps(&C[5 * ldc]); c5_hi = _mm256_loadu_ps(&C[5 * ldc + 8]);
499  }
500 
501  // Prefetch first cache lines of B
502  _mm_prefetch((const char*)B, _MM_HINT_T0);
503  _mm_prefetch((const char*)(B + 32), _MM_HINT_T0);
504 
505  // Main K loop - unrolled by 8 for better ILP and prefetch hiding
506  int k = 0;
507 
508 #if defined(__FMA__)
509  // FMA path - uses fused multiply-add (single instruction)
510  for (; k <= K - 8; k += 8) {
511  // Prefetch ahead - 16 rows ahead for L1
512  _mm_prefetch((const char*)&B[(k + 16) * ldb], _MM_HINT_T0);
513  _mm_prefetch((const char*)&B[(k + 16) * ldb + 32], _MM_HINT_T0);
514  _mm_prefetch((const char*)&B[(k + 17) * ldb], _MM_HINT_T0);
515 
516  // Software pipelining: load B for iteration 0 before the loop body
517  __m256 b_lo_next = _mm256_loadu_ps(&B[k * ldb]);
518  __m256 b_hi_next = _mm256_loadu_ps(&B[k * ldb + 8]);
519 
520  #define FMA_ITER(koff) { \
521  __m256 b_lo = b_lo_next; \
522  __m256 b_hi = b_hi_next; \
523  if ((koff) < 7) { \
524  b_lo_next = _mm256_loadu_ps(&B[(k + (koff) + 1) * ldb]); \
525  b_hi_next = _mm256_loadu_ps(&B[(k + (koff) + 1) * ldb + 8]); \
526  } \
527  __m256 a0 = _mm256_set1_ps(A[0 * lda + k + (koff)]); \
528  __m256 a1 = _mm256_set1_ps(A[1 * lda + k + (koff)]); \
529  __m256 a2 = _mm256_set1_ps(A[2 * lda + k + (koff)]); \
530  __m256 a3 = _mm256_set1_ps(A[3 * lda + k + (koff)]); \
531  __m256 a4 = _mm256_set1_ps(A[4 * lda + k + (koff)]); \
532  __m256 a5 = _mm256_set1_ps(A[5 * lda + k + (koff)]); \
533  c0_lo = _mm256_fmadd_ps(a0, b_lo, c0_lo); c0_hi = _mm256_fmadd_ps(a0, b_hi, c0_hi); \
534  c1_lo = _mm256_fmadd_ps(a1, b_lo, c1_lo); c1_hi = _mm256_fmadd_ps(a1, b_hi, c1_hi); \
535  c2_lo = _mm256_fmadd_ps(a2, b_lo, c2_lo); c2_hi = _mm256_fmadd_ps(a2, b_hi, c2_hi); \
536  c3_lo = _mm256_fmadd_ps(a3, b_lo, c3_lo); c3_hi = _mm256_fmadd_ps(a3, b_hi, c3_hi); \
537  c4_lo = _mm256_fmadd_ps(a4, b_lo, c4_lo); c4_hi = _mm256_fmadd_ps(a4, b_hi, c4_hi); \
538  c5_lo = _mm256_fmadd_ps(a5, b_lo, c5_lo); c5_hi = _mm256_fmadd_ps(a5, b_hi, c5_hi); \
539  }
540 
541  FMA_ITER(0);
542  FMA_ITER(1);
543  FMA_ITER(2);
544  FMA_ITER(3);
545  FMA_ITER(4);
546  FMA_ITER(5);
547  FMA_ITER(6);
548  FMA_ITER(7);
549 
550  #undef FMA_ITER
551  }
552 
553  // Handle remaining K with FMA
554  for (; k < K; k++) {
555  __m256 b_lo = _mm256_loadu_ps(&B[k * ldb]);
556  __m256 b_hi = _mm256_loadu_ps(&B[k * ldb + 8]);
557 
558  __m256 a0 = _mm256_set1_ps(A[0 * lda + k]);
559  __m256 a1 = _mm256_set1_ps(A[1 * lda + k]);
560  __m256 a2 = _mm256_set1_ps(A[2 * lda + k]);
561  __m256 a3 = _mm256_set1_ps(A[3 * lda + k]);
562  __m256 a4 = _mm256_set1_ps(A[4 * lda + k]);
563  __m256 a5 = _mm256_set1_ps(A[5 * lda + k]);
564 
565  c0_lo = _mm256_fmadd_ps(a0, b_lo, c0_lo); c0_hi = _mm256_fmadd_ps(a0, b_hi, c0_hi);
566  c1_lo = _mm256_fmadd_ps(a1, b_lo, c1_lo); c1_hi = _mm256_fmadd_ps(a1, b_hi, c1_hi);
567  c2_lo = _mm256_fmadd_ps(a2, b_lo, c2_lo); c2_hi = _mm256_fmadd_ps(a2, b_hi, c2_hi);
568  c3_lo = _mm256_fmadd_ps(a3, b_lo, c3_lo); c3_hi = _mm256_fmadd_ps(a3, b_hi, c3_hi);
569  c4_lo = _mm256_fmadd_ps(a4, b_lo, c4_lo); c4_hi = _mm256_fmadd_ps(a4, b_hi, c4_hi);
570  c5_lo = _mm256_fmadd_ps(a5, b_lo, c5_lo); c5_hi = _mm256_fmadd_ps(a5, b_hi, c5_hi);
571  }
572 #else
573  // Non-FMA fallback (older CPUs without FMA)
574  for (; k <= K - 4; k += 4) {
575  _mm_prefetch((const char*)&B[(k + 8) * ldb], _MM_HINT_T0);
576 
577  #define AVX_ITER(koff) { \
578  __m256 b_lo = _mm256_loadu_ps(&B[(k + koff) * ldb]); \
579  __m256 b_hi = _mm256_loadu_ps(&B[(k + koff) * ldb + 8]); \
580  __m256 a0 = _mm256_set1_ps(A[0 * lda + k + koff]); \
581  __m256 a1 = _mm256_set1_ps(A[1 * lda + k + koff]); \
582  __m256 a2 = _mm256_set1_ps(A[2 * lda + k + koff]); \
583  __m256 a3 = _mm256_set1_ps(A[3 * lda + k + koff]); \
584  __m256 a4 = _mm256_set1_ps(A[4 * lda + k + koff]); \
585  __m256 a5 = _mm256_set1_ps(A[5 * lda + k + koff]); \
586  c0_lo = _mm256_add_ps(c0_lo, _mm256_mul_ps(a0, b_lo)); \
587  c0_hi = _mm256_add_ps(c0_hi, _mm256_mul_ps(a0, b_hi)); \
588  c1_lo = _mm256_add_ps(c1_lo, _mm256_mul_ps(a1, b_lo)); \
589  c1_hi = _mm256_add_ps(c1_hi, _mm256_mul_ps(a1, b_hi)); \
590  c2_lo = _mm256_add_ps(c2_lo, _mm256_mul_ps(a2, b_lo)); \
591  c2_hi = _mm256_add_ps(c2_hi, _mm256_mul_ps(a2, b_hi)); \
592  c3_lo = _mm256_add_ps(c3_lo, _mm256_mul_ps(a3, b_lo)); \
593  c3_hi = _mm256_add_ps(c3_hi, _mm256_mul_ps(a3, b_hi)); \
594  c4_lo = _mm256_add_ps(c4_lo, _mm256_mul_ps(a4, b_lo)); \
595  c4_hi = _mm256_add_ps(c4_hi, _mm256_mul_ps(a4, b_hi)); \
596  c5_lo = _mm256_add_ps(c5_lo, _mm256_mul_ps(a5, b_lo)); \
597  c5_hi = _mm256_add_ps(c5_hi, _mm256_mul_ps(a5, b_hi)); \
598  }
599 
600  AVX_ITER(0);
601  AVX_ITER(1);
602  AVX_ITER(2);
603  AVX_ITER(3);
604 
605  #undef AVX_ITER
606  }
607 
608  for (; k < K; k++) {
609  __m256 b_lo = _mm256_loadu_ps(&B[k * ldb]);
610  __m256 b_hi = _mm256_loadu_ps(&B[k * ldb + 8]);
611 
612  __m256 a0 = _mm256_set1_ps(A[0 * lda + k]);
613  __m256 a1 = _mm256_set1_ps(A[1 * lda + k]);
614  __m256 a2 = _mm256_set1_ps(A[2 * lda + k]);
615  __m256 a3 = _mm256_set1_ps(A[3 * lda + k]);
616  __m256 a4 = _mm256_set1_ps(A[4 * lda + k]);
617  __m256 a5 = _mm256_set1_ps(A[5 * lda + k]);
618 
619  c0_lo = _mm256_add_ps(c0_lo, _mm256_mul_ps(a0, b_lo));
620  c0_hi = _mm256_add_ps(c0_hi, _mm256_mul_ps(a0, b_hi));
621  c1_lo = _mm256_add_ps(c1_lo, _mm256_mul_ps(a1, b_lo));
622  c1_hi = _mm256_add_ps(c1_hi, _mm256_mul_ps(a1, b_hi));
623  c2_lo = _mm256_add_ps(c2_lo, _mm256_mul_ps(a2, b_lo));
624  c2_hi = _mm256_add_ps(c2_hi, _mm256_mul_ps(a2, b_hi));
625  c3_lo = _mm256_add_ps(c3_lo, _mm256_mul_ps(a3, b_lo));
626  c3_hi = _mm256_add_ps(c3_hi, _mm256_mul_ps(a3, b_hi));
627  c4_lo = _mm256_add_ps(c4_lo, _mm256_mul_ps(a4, b_lo));
628  c4_hi = _mm256_add_ps(c4_hi, _mm256_mul_ps(a4, b_hi));
629  c5_lo = _mm256_add_ps(c5_lo, _mm256_mul_ps(a5, b_lo));
630  c5_hi = _mm256_add_ps(c5_hi, _mm256_mul_ps(a5, b_hi));
631  }
632 #endif
633 
634  _mm256_storeu_ps(&C[0 * ldc], c0_lo); _mm256_storeu_ps(&C[0 * ldc + 8], c0_hi);
635  _mm256_storeu_ps(&C[1 * ldc], c1_lo); _mm256_storeu_ps(&C[1 * ldc + 8], c1_hi);
636  _mm256_storeu_ps(&C[2 * ldc], c2_lo); _mm256_storeu_ps(&C[2 * ldc + 8], c2_hi);
637  _mm256_storeu_ps(&C[3 * ldc], c3_lo); _mm256_storeu_ps(&C[3 * ldc + 8], c3_hi);
638  _mm256_storeu_ps(&C[4 * ldc], c4_lo); _mm256_storeu_ps(&C[4 * ldc + 8], c4_hi);
639  _mm256_storeu_ps(&C[5 * ldc], c5_lo); _mm256_storeu_ps(&C[5 * ldc + 8], c5_hi);
640 }
641 #endif
642 
643 // =============================================================================
644 // AVX 4x16 Microkernel (for AVX-only CPUs without FMA)
645 //
646 // This smaller tile avoids register spilling on CPUs with only 16 YMM registers.
647 // Register allocation: 8 accumulators + 2 B + 4 A + 2 temp = 16 registers
648 // =============================================================================
649 
650 #if defined(__AVX__) && !defined(__FMA__)
651 static inline void gemm_microkernel_4x16_avx(
652  int K,
653  const float * __restrict__ A, int lda,
654  const float * __restrict__ B, int ldb,
655  float * __restrict__ C, int ldc,
656  int first_k
657 )
658 {
659  // 8 accumulators: 4 rows x 2 YMM (16 floats) per row
660  __m256 c0_lo, c0_hi, c1_lo, c1_hi, c2_lo, c2_hi, c3_lo, c3_hi;
661 
662  if (first_k) {
663  c0_lo = _mm256_setzero_ps(); c0_hi = _mm256_setzero_ps();
664  c1_lo = _mm256_setzero_ps(); c1_hi = _mm256_setzero_ps();
665  c2_lo = _mm256_setzero_ps(); c2_hi = _mm256_setzero_ps();
666  c3_lo = _mm256_setzero_ps(); c3_hi = _mm256_setzero_ps();
667  } else {
668  c0_lo = _mm256_loadu_ps(&C[0 * ldc]); c0_hi = _mm256_loadu_ps(&C[0 * ldc + 8]);
669  c1_lo = _mm256_loadu_ps(&C[1 * ldc]); c1_hi = _mm256_loadu_ps(&C[1 * ldc + 8]);
670  c2_lo = _mm256_loadu_ps(&C[2 * ldc]); c2_hi = _mm256_loadu_ps(&C[2 * ldc + 8]);
671  c3_lo = _mm256_loadu_ps(&C[3 * ldc]); c3_hi = _mm256_loadu_ps(&C[3 * ldc + 8]);
672  }
673 
674  _mm_prefetch((const char*)B, _MM_HINT_T0);
675 
676  // K loop - unrolled by 4 for better ILP
677  int k = 0;
678  for (; k <= K - 4; k += 4) {
679  _mm_prefetch((const char*)&B[(k + 8) * ldb], _MM_HINT_T0);
680 
681  #define AVX4_ITER(koff) { \
682  __m256 b_lo = _mm256_loadu_ps(&B[(k + koff) * ldb]); \
683  __m256 b_hi = _mm256_loadu_ps(&B[(k + koff) * ldb + 8]); \
684  __m256 a0 = _mm256_set1_ps(A[0 * lda + k + koff]); \
685  __m256 a1 = _mm256_set1_ps(A[1 * lda + k + koff]); \
686  __m256 a2 = _mm256_set1_ps(A[2 * lda + k + koff]); \
687  __m256 a3 = _mm256_set1_ps(A[3 * lda + k + koff]); \
688  c0_lo = _mm256_add_ps(c0_lo, _mm256_mul_ps(a0, b_lo)); \
689  c0_hi = _mm256_add_ps(c0_hi, _mm256_mul_ps(a0, b_hi)); \
690  c1_lo = _mm256_add_ps(c1_lo, _mm256_mul_ps(a1, b_lo)); \
691  c1_hi = _mm256_add_ps(c1_hi, _mm256_mul_ps(a1, b_hi)); \
692  c2_lo = _mm256_add_ps(c2_lo, _mm256_mul_ps(a2, b_lo)); \
693  c2_hi = _mm256_add_ps(c2_hi, _mm256_mul_ps(a2, b_hi)); \
694  c3_lo = _mm256_add_ps(c3_lo, _mm256_mul_ps(a3, b_lo)); \
695  c3_hi = _mm256_add_ps(c3_hi, _mm256_mul_ps(a3, b_hi)); \
696  }
697 
698  AVX4_ITER(0);
699  AVX4_ITER(1);
700  AVX4_ITER(2);
701  AVX4_ITER(3);
702 
703  #undef AVX4_ITER
704  }
705 
706  // Handle remaining K
707  for (; k < K; k++) {
708  __m256 b_lo = _mm256_loadu_ps(&B[k * ldb]);
709  __m256 b_hi = _mm256_loadu_ps(&B[k * ldb + 8]);
710 
711  __m256 a0 = _mm256_set1_ps(A[0 * lda + k]);
712  __m256 a1 = _mm256_set1_ps(A[1 * lda + k]);
713  __m256 a2 = _mm256_set1_ps(A[2 * lda + k]);
714  __m256 a3 = _mm256_set1_ps(A[3 * lda + k]);
715 
716  c0_lo = _mm256_add_ps(c0_lo, _mm256_mul_ps(a0, b_lo));
717  c0_hi = _mm256_add_ps(c0_hi, _mm256_mul_ps(a0, b_hi));
718  c1_lo = _mm256_add_ps(c1_lo, _mm256_mul_ps(a1, b_lo));
719  c1_hi = _mm256_add_ps(c1_hi, _mm256_mul_ps(a1, b_hi));
720  c2_lo = _mm256_add_ps(c2_lo, _mm256_mul_ps(a2, b_lo));
721  c2_hi = _mm256_add_ps(c2_hi, _mm256_mul_ps(a2, b_hi));
722  c3_lo = _mm256_add_ps(c3_lo, _mm256_mul_ps(a3, b_lo));
723  c3_hi = _mm256_add_ps(c3_hi, _mm256_mul_ps(a3, b_hi));
724  }
725 
726  _mm256_storeu_ps(&C[0 * ldc], c0_lo); _mm256_storeu_ps(&C[0 * ldc + 8], c0_hi);
727  _mm256_storeu_ps(&C[1 * ldc], c1_lo); _mm256_storeu_ps(&C[1 * ldc + 8], c1_hi);
728  _mm256_storeu_ps(&C[2 * ldc], c2_lo); _mm256_storeu_ps(&C[2 * ldc + 8], c2_hi);
729  _mm256_storeu_ps(&C[3 * ldc], c3_lo); _mm256_storeu_ps(&C[3 * ldc + 8], c3_hi);
730 }
731 #endif
732 
733 // =============================================================================
734 // Edge case handler for non-MRxNR aligned tiles
735 // =============================================================================
736 
737 static void gemm_microkernel_edge(
738  int m, int n, int K,
739  const float *A, int lda,
740  const float *B, int ldb,
741  float *C, int ldc,
742  int first_k
743 )
744 {
745  for (int i = 0; i < m; i++) {
746  for (int j = 0; j < n; j++) {
747  float sum = first_k ? 0.0f : C[i * ldc + j];
748  for (int k = 0; k < K; k++) {
749  sum += A[i * lda + k] * B[k * ldb + j];
750  }
751  C[i * ldc + j] = sum;
752  }
753  }
754 }
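
/*
 * Note: this scalar triple loop defines the reference semantics every SIMD
 * kernel above must match (C = A @ B, accumulated across K panels). Results
 * may still differ in the last bits, since FMA and vector-lane accumulation
 * change the floating-point summation order.
 */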
755 
756 // =============================================================================
757 // Matrix Packing Functions - Parallel for large matrices
758 // =============================================================================
759 
760 // Pack A panel: A[m0:m0+mc, k0:k0+kc] -> Ap[mc, kc] in row-panel format
761 static void pack_a_panel(
762  const float *A, int lda,
763  float *Ap,
764  int mc, int kc, int mr
765 )
766 {
767  #pragma omp parallel for schedule(static) if(mc > 64)
768  for (int i = 0; i < mc; i += mr) {
769  int rows = (i + mr <= mc) ? mr : (mc - i);
770  float *Ap_panel = &Ap[(i / mr) * mr * kc];
771 
772  for (int p = 0; p < rows; p++) {
773  const float *A_row = &A[(i + p) * lda];
774  float *Ap_row = &Ap_panel[p * kc];
775 
776  // Vectorized copy
777  int k = 0;
778 #if defined(__AVX__)
779  for (; k <= kc - 8; k += 8) {
780  _mm256_storeu_ps(&Ap_row[k], _mm256_loadu_ps(&A_row[k]));
781  }
782 #endif
783  for (; k < kc; k++) {
784  Ap_row[k] = A_row[k];
785  }
786  }
787  // Zero pad if partial panel
788  for (int p = rows; p < mr; p++) {
789  memset(&Ap_panel[p * kc], 0, kc * sizeof(float));
790  }
791  }
792 }
793 
794 // Pack B panel: B[k0:k0+kc, n0:n0+nc] -> Bp[kc, nc] in column-panel format
795 static void pack_b_panel(
796  const float *B, int ldb,
797  float *Bp,
798  int kc, int nc, int nr
799 )
800 {
801  #pragma omp parallel for schedule(static) if(nc > 128)
802  for (int j = 0; j < nc; j += nr) {
803  int cols = (j + nr <= nc) ? nr : (nc - j);
804  float *Bp_panel = &Bp[(j / nr) * nr * kc];
805 
806  for (int k = 0; k < kc; k++) {
807  const float *B_row = &B[k * ldb + j];
808  float *Bp_row = &Bp_panel[k * nr];
809 
810  // Copy cols and zero-pad
811  int c = 0;
812 #if defined(__AVX512F__)
813  for (; c <= cols - 16; c += 16) {
814  _mm512_store_ps(&Bp_row[c], _mm512_loadu_ps(&B_row[c]));
815  }
816 #elif defined(__AVX__)
817  for (; c <= cols - 8; c += 8) {
818  _mm256_store_ps(&Bp_row[c], _mm256_loadu_ps(&B_row[c]));
819  }
820 #endif
821  for (; c < cols; c++) {
822  Bp_row[c] = B_row[c];
823  }
824  for (; c < nr; c++) {
825  Bp_row[c] = 0.0f;
826  }
827  }
828  }
829 }
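
/*
 * Packed layouts produced above (index maps derived from the code):
 *   A[i+p, k]  ->  Ap[(i/mr)*mr*kc + p*kc + k]   // mr-row panels, each row contiguous in k
 *   B[k, j+c]  ->  Bp[(j/nr)*nr*kc + k*nr + c]   // nr-col panels, each k-row contiguous
 * so the packed microkernel streams Ap row-by-row and Bp k-by-k with
 * unit-stride accesses (cf. gemm_microkernel_6x32_packed_avx512 above).
 */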
830 
831 // =============================================================================
832 // High-Performance GEMM entry point
833 //
834 // This is the main entry point for large matrices. The packing routines and
835 // packed microkernel above are kept for a future packed path; for now it
836 // delegates to the tile-parallel blocked kernel below, which scales better
837 // on many-core systems.
838 // =============================================================================
839 
840 void gemm_microkernel_packed(
841  const float *A,
842  const float *B,
843  float *C,
844  int M, int N, int K
845 )
846 {
847  // Use tile-parallel blocked version - scales better on many-core systems
848  gemm_microkernel_blocked(A, B, C, M, N, K);
849 }
850 
851 // =============================================================================
852 // Cache-Blocked GEMM
853 //
854 // A 48-core Xeon needs at least 48 parallel tasks. With 1024x1024 matrices
855 // the tile grid is ample: M_tiles = ceil(1024 / 6) = 171 (MR=6) and
856 // N_tiles = 1024 / 32 = 32 (NR=32), about 5472 (M, N) tiles in total.
857 // The non-transposed path parallelizes over blocks of M rows (171 tasks
858 // here); the B-transposed variant below uses collapse(2) across M and N.
859 // =============================================================================
860 
861 // Sequential version for small matrices (avoids OpenMP overhead)
862 static void gemm_microkernel_sequential(
863  const float *A,
864  const float *B,
865  float *C,
866  int M, int N, int K
867 )
868 {
869  // Zero output
870  for (int i = 0; i < M; i++) {
871  memset(&C[i * N], 0, N * sizeof(float));
872  }
873 
874  const int mr = MR;
875  const int nr = NR;
876 
877  // Block over K
878  for (int k0 = 0; k0 < K; k0 += KC) {
879  int kb = (k0 + KC <= K) ? KC : (K - k0);
880  int first_k = (k0 == 0);
881 
882  // Loop over tiles
883  for (int m0 = 0; m0 < M; m0 += mr) {
884  int mr_actual = (m0 + mr <= M) ? mr : (M - m0);
885 
886  for (int n0 = 0; n0 < N; n0 += nr) {
887  int nr_actual = (n0 + nr <= N) ? nr : (N - n0);
888 
889  const float *A_tile = &A[m0 * K + k0];
890  const float *B_tile = &B[k0 * N + n0];
891  float *C_tile = &C[m0 * N + n0];
892 
893  if (mr_actual == mr && nr_actual == nr) {
894 #if defined(__AVX512F__)
895  gemm_microkernel_6x32_avx512(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
896 #elif defined(__FMA__)
897  gemm_microkernel_6x16_avx(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
898 #elif defined(__AVX__)
899  gemm_microkernel_4x16_avx(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
900 #else
901  gemm_microkernel_edge(mr_actual, nr_actual, kb, A_tile, K, B_tile, N, C_tile, N, first_k);
902 #endif
903  } else {
904  gemm_microkernel_edge(mr_actual, nr_actual, kb, A_tile, K, B_tile, N, C_tile, N, first_k);
905  }
906  }
907  }
908  }
909 }
910 
911 // Flag to track if we've set optimal thread count (only used in native backend)
912 static int g_threads_initialized = 0;
913 
914 // Set optimal thread count for GEMM (physical cores only, no hyperthreading)
915 static void gemm_init_threads(void) {
916  if (g_threads_initialized) return;
917 
918 #ifdef _OPENMP
919  const CPUInfo* cpu = get_cpu_info();
920  int physical_cores = cpu->num_cores;
921 
922  // Only use physical cores - hyperthreading hurts compute-bound GEMM
923  if (physical_cores > 0) {
924  int current_max = omp_get_max_threads();
925  // Only reduce if we have more threads than physical cores
926  if (current_max > physical_cores) {
927  omp_set_num_threads(physical_cores);
928  }
929  }
930 #endif
931  g_threads_initialized = 1;
932 }
933 
934 void gemm_microkernel_blocked(
935  const float *A,
936  const float *B,
937  float *C,
938  int M, int N, int K
939 )
940 {
941  const int mr = MR;
942  const int nr = NR;
943 
944  // Use sequential version for small matrices to avoid OpenMP overhead
945  // Threshold tuned for typical 4-8 core systems
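 // (512^3 is ~134M multiply-adds, ~0.27 GFLOP; below that, fork/join
 // overhead tends to eat any parallel speedup - a rough heuristic.)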
946  if ((size_t)M * N * K <= 512ULL * 512 * 512) {
947  gemm_microkernel_sequential(A, B, C, M, N, K);
948  return;
949  }
950 
951  // Initialize thread count to physical cores (once)
952  gemm_init_threads();
953 
954  // Zero output first
955  #pragma omp parallel for schedule(static)
956  for (int i = 0; i < M; i++) {
957  memset(&C[i * N], 0, N * sizeof(float));
958  }
959 
960  // Block over K (outermost - for accumulation across all threads)
961  for (int k0 = 0; k0 < K; k0 += KC) {
962  int kb = (k0 + KC <= K) ? KC : (K - k0);
963  int first_k = (k0 == 0);
964 
965  // Parallelize over M rows - each thread gets a chunk of M
966  // This gives better cache locality than tile-level parallelism
967  #pragma omp parallel for schedule(static)
968  for (int m0 = 0; m0 < M; m0 += mr) {
969  int mr_actual = (m0 + mr <= M) ? mr : (M - m0);
970 
971  // Each thread processes all N tiles for its M rows
972  for (int n0 = 0; n0 < N; n0 += nr) {
973  int nr_actual = (n0 + nr <= N) ? nr : (N - n0);
974 
975  const float *A_tile = &A[m0 * K + k0];
976  const float *B_tile = &B[k0 * N + n0];
977  float *C_tile = &C[m0 * N + n0];
978 
979  if (mr_actual == mr && nr_actual == nr) {
980 #if defined(__AVX512F__)
981  gemm_microkernel_6x32_avx512(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
982 #elif defined(__FMA__)
983  gemm_microkernel_6x16_avx(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
984 #elif defined(__AVX__)
985  gemm_microkernel_4x16_avx(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
986 #else
987  gemm_microkernel_edge(mr_actual, nr_actual, kb, A_tile, K, B_tile, N, C_tile, N, first_k);
988 #endif
989  } else {
990  gemm_microkernel_edge(mr_actual, nr_actual, kb, A_tile, K, B_tile, N, C_tile, N, first_k);
991  }
992  }
993  }
994  }
995 }
996 
997 // =============================================================================
998 // B-transposed GEMM: C[M,N] = A[M,K] @ B[N,K].T
999 // =============================================================================
1000 
1001 #if defined(__AVX512F__)
1002 static inline void gemm_microkernel_6x32_bt_avx512(
1003  int K,
1004  const float * __restrict__ A, int lda,
1005  const float * __restrict__ B, int ldb, // B is [N, K] transposed
1006  float * __restrict__ C, int ldc,
1007  int first_k
1008 )
1009 {
1010  // For B transposed, we need different access pattern
1011  // C[i,j] = sum_k A[i,k] * B[j,k]
1012 
1013  if (first_k) {
1014  for (int i = 0; i < 6; i++) {
1015  for (int j = 0; j < 32; j++) {
1016  C[i * ldc + j] = 0.0f;
1017  }
1018  }
1019  }
1020 
1021  // Process K in chunks of 16 for SIMD
1022  int k = 0;
1023  for (; k <= K - 16; k += 16) {
1024  // Load A[0:6, k:k+16] - 6 rows
1025  __m512 a0 = _mm512_loadu_ps(&A[0 * lda + k]);
1026  __m512 a1 = _mm512_loadu_ps(&A[1 * lda + k]);
1027  __m512 a2 = _mm512_loadu_ps(&A[2 * lda + k]);
1028  __m512 a3 = _mm512_loadu_ps(&A[3 * lda + k]);
1029  __m512 a4 = _mm512_loadu_ps(&A[4 * lda + k]);
1030  __m512 a5 = _mm512_loadu_ps(&A[5 * lda + k]);
1031 
1032  // For each column j of C (row j of B)
1033  for (int j = 0; j < 32; j++) {
1034  __m512 b = _mm512_loadu_ps(&B[j * ldb + k]);
1035 
1036  // Compute dot products using reduction
1037  C[0 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a0, b));
1038  C[1 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a1, b));
1039  C[2 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a2, b));
1040  C[3 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a3, b));
1041  C[4 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a4, b));
1042  C[5 * ldc + j] += _mm512_reduce_add_ps(_mm512_mul_ps(a5, b));
1043  }
1044  }
1045 
1046  // Handle remaining K
1047  for (; k < K; k++) {
1048  for (int i = 0; i < 6; i++) {
1049  float a = A[i * lda + k];
1050  for (int j = 0; j < 32; j++) {
1051  C[i * ldc + j] += a * B[j * ldb + k];
1052  }
1053  }
1054  }
1055 }
1056 #endif
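
/*
 * Note: the horizontal _mm512_reduce_add_ps per output element makes this
 * kernel reduction-bound; it trades peak FLOPs for reading B[N,K] with
 * unit stride instead of transposing it first.
 */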
1057 
1058 void gemm_microkernel_blocked_bt(
1059  const float *A,
1060  const float *B,
1061  float *C,
1062  int M, int N, int K
1063 )
1064 {
1065  // Zero output first
1066  #pragma omp parallel for schedule(static)
1067  for (int i = 0; i < M; i++) {
1068  memset(&C[i * N], 0, N * sizeof(float));
1069  }
1070 
1071  const int mr = MR;
1072  const int nr = NR;
1073 
1074  #pragma omp parallel for schedule(dynamic) collapse(2)
1075  for (int m0 = 0; m0 < M; m0 += MC) {
1076  for (int n0 = 0; n0 < N; n0 += NC) {
1077  int mb = (m0 + MC <= M) ? MC : (M - m0);
1078  int nb = (n0 + NC <= N) ? NC : (N - n0);
1079 
1080  for (int k0 = 0; k0 < K; k0 += KC) {
1081  int kb = (k0 + KC <= K) ? KC : (K - k0);
1082  int first_k = (k0 == 0);
1083 
1084  for (int m1 = 0; m1 < mb; m1 += mr) {
1085  int mr_actual = (m1 + mr <= mb) ? mr : (mb - m1);
1086 
1087  for (int n1 = 0; n1 < nb; n1 += nr) {
1088  int nr_actual = (n1 + nr <= nb) ? nr : (nb - n1);
1089 
1090  const float *A_tile = &A[(m0 + m1) * K + k0];
1091  const float *B_tile = &B[(n0 + n1) * K + k0];
1092  float *C_tile = &C[(m0 + m1) * N + (n0 + n1)];
1093 
1094  if (mr_actual == mr && nr_actual == nr) {
1095 #if defined(__AVX512F__)
1096  gemm_microkernel_6x32_bt_avx512(kb, A_tile, K, B_tile, K, C_tile, N, first_k);
1097 #else
1098  // Scalar fallback for B-transposed
1099  for (int i = 0; i < mr; i++) {
1100  for (int j = 0; j < nr; j++) {
1101  float sum = first_k ? 0.0f : C_tile[i * N + j];
1102  for (int kk = 0; kk < kb; kk++) {
1103  sum += A_tile[i * K + kk] * B_tile[j * K + kk];
1104  }
1105  C_tile[i * N + j] = sum;
1106  }
1107  }
1108 #endif
1109  } else {
1110  // Edge case
1111  for (int i = 0; i < mr_actual; i++) {
1112  for (int j = 0; j < nr_actual; j++) {
1113  float sum = first_k ? 0.0f : C_tile[i * N + j];
1114  for (int kk = 0; kk < kb; kk++) {
1115  sum += A_tile[i * K + kk] * B_tile[j * K + kk];
1116  }
1117  C_tile[i * N + j] = sum;
1118  }
1119  }
1120  }
1121  }
1122  }
1123  }
1124  }
1125  }
1126 }
1127 
1128 // =============================================================================
1129 // Public API
1130 // =============================================================================
1131 
1132 #define PACK_THRESHOLD 256 // Use packing for matrices >= 256
1133 
1134 void gemm_microkernel(
1135  const float *A,
1136  const float *B,
1137  float *C,
1138  int M, int N, int K,
1139  int B_transposed
1140 )
1141 {
1142  if (B_transposed) {
1143  gemm_microkernel_blocked_bt(A, B, C, M, N, K);
1144  } else {
1145  // Use packed version for large matrices
1146  if (M >= PACK_THRESHOLD && N >= PACK_THRESHOLD && K >= PACK_THRESHOLD) {
1147  gemm_microkernel_packed(A, B, C, M, N, K);
1148  } else {
1149  gemm_microkernel_blocked(A, B, C, M, N, K);
1150  }
1151  }
1152 }
1153 
1154 #endif // !USE_MKL && !USE_ONEDNN (Native backend)
1155 
1156 // =============================================================================
1157 // Query which backend is in use
1158 // =============================================================================
1159 
1160 const char* gemm_get_backend(void) {
1161  return GEMM_BACKEND;
1162 }
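
/*
 * Usage sketch (illustrative): report the compiled-in backend at startup.
 *
 *   printf("GEMM backend: %s\n", gemm_get_backend());  // "MKL", "oneDNN" or "Native"
 */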