GEMM Optimization Deep Dive
How we built a high-performance matrix multiplication kernel that beats Intel MKL, inspired by oneDNN, BLIS, and decades of HPC research.
1.44x Faster than PyTorch/MKL
Standing on the Shoulders of Giants
Our GEMM implementation draws from decades of high-performance computing research and industry-leading libraries:
The 8x8 Microkernel Architecture
The heart of our GEMM is an 8x8 microkernel that keeps all 64 accumulator values in AVX registers throughout the entire K-loop. This is the same strategy used by oneDNN and BLIS.
Matrix Packing for Cache Efficiency
For large matrices, we pack A and B into contiguous memory layouts that maximize cache line utilization. This is the key technique that allowed us to beat MKL.
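A minimal packing sketch (the function name and 8-wide panel width are assumptions to match the microkernel, not the project's exact API): copy an 8-column panel of row-major B into a contiguous K x 8 buffer, so the inner loop reads it strictly sequentially and every fetched cache line is fully used.

```c
// Pack an 8-column panel of row-major B (leading dimension ldb) into a
// contiguous K x 8 buffer Bp. The microkernel then streams Bp linearly,
// which maximizes cache line utilization and prefetcher effectiveness.
static void pack_b_panel(const float *B, float *Bp, int K, int ldb) {
    for (int k = 0; k < K; ++k)
        for (int j = 0; j < 8; ++j)
            Bp[k * 8 + j] = B[k * ldb + j];
}
```

The copy cost is O(K*N) while the multiply is O(M*N*K), so for large matrices packing pays for itself many times over.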
Three-Level Cache Blocking
We tile the computation to fit each level of the memory hierarchy, minimizing data movement between DRAM and CPU.
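The blocking structure is a three-deep loop nest in the BLIS style. In this sketch the block sizes are toy values (not our tuned ones), and packing plus the 8x8 microkernel are replaced by a plain inner loop to keep it short:

```c
#include <string.h>

enum { MC = 32, KC = 32, NC = 64 };  // illustrative block sizes only

// Stand-in for the 8x8 microkernel: multiply an mb x kb block of A by a
// kb x nb block of B, accumulating into C. lda/ldb/ldc are row strides.
static void inner_block(const float *A, const float *B, float *C,
                        int mb, int nb, int kb,
                        int lda, int ldb, int ldc) {
    for (int i = 0; i < mb; ++i)
        for (int j = 0; j < nb; ++j) {
            float acc = C[i * ldc + j];
            for (int p = 0; p < kb; ++p)
                acc += A[i * lda + p] * B[p * ldb + j];
            C[i * ldc + j] = acc;
        }
}

// Three-level blocking: jc carves B into NC-wide panels (L3-resident),
// pc carves K into KC strips (the packed-B granularity, L2/L1), and ic
// carves A into MC-row blocks (L2-resident).
void gemm_blocked(const float *A, const float *B, float *C,
                  int M, int N, int K) {
    memset(C, 0, (size_t)M * N * sizeof(float));
    for (int jc = 0; jc < N; jc += NC)
        for (int pc = 0; pc < K; pc += KC)
            for (int ic = 0; ic < M; ic += MC) {
                int nb = (N - jc < NC) ? N - jc : NC;
                int kb = (K - pc < KC) ? K - pc : KC;
                int mb = (M - ic < MC) ? M - ic : MC;
                inner_block(A + ic * K + pc, B + pc * N + jc,
                            C + ic * N + jc, mb, nb, kb, K, N, N);
            }
}
```

The loop order keeps a KC x NC panel of B hot across all MC-row blocks of A, so each B element is reloaded from DRAM roughly M/MC times fewer than in the naive triple loop.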
Our Optimization Journey
Performance Results
| Matrix Size | PyTorch/MKL | Our Microkernel | Result |
|---|---|---|---|
| 32 x 32 x 32 | 5.17 GFLOPS | 10.19 GFLOPS | 2.0x FASTER |
| 64 x 64 x 64 | 14.03 GFLOPS | 22.94 GFLOPS | 1.6x FASTER |
| 128 x 128 x 128 | 19.75 GFLOPS | 27.75 GFLOPS | 1.4x FASTER |
| 256 x 256 x 256 | 22.41 GFLOPS | 32.76 GFLOPS | 1.5x FASTER |
| 512 x 512 x 512 | 22.50 GFLOPS | 24.22 GFLOPS | 1.1x FASTER |
| 1024 x 1024 x 1024 | 23.00 GFLOPS | 31.48 GFLOPS | 1.4x FASTER |
v7 Training Threadpool Dispatch Playbook
Inference v6.6 and training v7 should both use lowered execution plans. For v7 training, IR2 defines gradient math, while IR3 must define parallel execution and reduction ownership.
IR1
Forward op graph from template + manifest (what to compute).
IR2
Backward synthesis and explicit gradient fanout/fanin accumulation (chain rule routing).
IR3 / Exec Plan
Memory layout + dispatch policy (split axis, tiles, threads, reduction order, barriers).
Dispatch policy by workload shape
| Workload | Preferred split | Reason |
|---|---|---|
| M = B*S large GEMM | split M | Best cache locality and zero reduction overhead for independent rows. |
| Tiny-M GEMM (decode-like) | split N | Improves core utilization when row-parallel work is too small. |
| Backward dW GEMM | split K + explicit partial reduce | Good parallelism, but requires deterministic reduction tree. |
| Elementwise ops | Contiguous range chunks | Simple scheduling and stable memory throughput. |
| Attention | Head split, then token block split | Natural independence across heads, then balanced token work. |
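The policy table can be encoded as a small shape-driven helper. Everything here is a hypothetical sketch: the enum, function name, and thresholds are illustrative, not the project's tuned values.

```c
// Hypothetical split-axis chooser mirroring the dispatch table above.
// Thresholds (m < 8, n >= 256) are illustrative placeholders.
typedef enum { SPLIT_M, SPLIT_N, SPLIT_K } split_axis_t;

static split_axis_t choose_split(int m, int n, int k, int backward_dw) {
    (void)k;
    if (backward_dw)          // dW GEMM: split K + deterministic reduce
        return SPLIT_K;
    if (m < 8 && n >= 256)    // tiny-M, decode-like shape: row work too small
        return SPLIT_N;
    return SPLIT_M;           // default: independent rows, no reduction
}
```

In the playbook's model this decision lives in IR3 lowering, so the emitted plan, not the runtime, records which axis was chosen.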
Execution-plan JSON consumed by codegen
{
  "schema": "ck.train.exec.v1",
  "runtime": {"threads": 12, "simd": "avx2", "mode": "deterministic"},
  "ops": [
    {
      "op_id": 37,
      "phase": "forward",
      "kernel_id": "gemm_fwd_f32",
      "shape": {"m": 16, "n": 1024, "k": 1024},
      "dispatch": {"split_axis": "m", "tile_m": 4, "tile_n": 128, "threads": 12}
    },
    {
      "op_id": 109,
      "phase": "backward",
      "kernel_id": "gemm_backward_f32",
      "shape": {"m": 1024, "n": 1024, "k": 16},
      "dispatch": {"split_axis": "k", "threads": 12},
      "reduction": {"type": "sum", "order": "fixed_tree", "target": "grad.weight.layer.0.wq"}
    }
  ],
  "barriers": [{"after_op": 109, "reason": "grad_accum_boundary"}]
}
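A `"fixed_tree"` reduction can be realized as a pairwise tree over per-thread partial buffers, combined in a fixed order so the floating-point sum is bitwise identical run to run, regardless of which thread finishes first. A minimal sketch (function name assumed; a real implementation would parallelize each tree level):

```c
// Deterministic "fixed_tree" reduction sketch: per-thread partial gradient
// buffers are summed pairwise in a fixed order. Since FP addition is not
// associative, fixing the tree shape is what makes results reproducible.
static void reduce_fixed_tree(float *partials[], int nthreads, int n) {
    for (int stride = 1; stride < nthreads; stride *= 2)
        for (int t = 0; t + stride < nthreads; t += 2 * stride)
            for (int i = 0; i < n; ++i)
                partials[t][i] += partials[t + stride][i];
    // Final gradient lands in partials[0].
}
```

This is what the `"order": "fixed_tree"` field on op 109 commits the runtime to: the reduction shape is part of the plan, not a scheduling accident.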
Rule for maintainability
Codegen should remain dumb: it emits calls from train_exec_plan.json directly rather than inferring split or reduction policy ad hoc. This keeps parity, determinism, and performance tuning auditable.
Using the Microkernel
#include "ckernel_engine.h"
// Basic usage - automatically selects the best implementation for the shape
float A[M * K], B[K * N], C[M * N];
gemm_microkernel(A, B, C, M, N, K, 0); // B not transposed
// For neural network weights (B is [N, K] transposed)
gemm_microkernel(A, B, C, M, N, K, 1); // B transposed
// Direct packed version for large matrices
gemm_microkernel_packed(A, B, C, M, N, K);
Source Files
gemm_microkernel.c
Main microkernel implementation with 8x8 register blocking, matrix packing, and cache blocking.
src/kernels/gemm_microkernel.c
gemm_fused_kernels.c
Fused GEMM operations: GEMM+ReLU, GEMM+GELU, GEMM+SiLU, and the dual-GEMM SwiGLU.
src/kernels/gemm_fused_kernels.c
test_gemm_microkernel.py
Unit test with accuracy verification and performance benchmarks vs PyTorch.
unittest/test_gemm_microkernel.py
References & Further Reading
BLIS: A Framework for Rapidly Instantiating BLAS Functionality
The foundational paper on microkernel-based GEMM design.
Field G. Van Zee, Robert A. van de Geijn
Anatomy of High-Performance Matrix Multiplication
Classic paper explaining cache blocking and register tiling.
Kazushige Goto, Robert A. van de Geijn
oneDNN Developer Guide
Intel's deep learning library with state-of-the-art GEMM kernels.
oneapi-src/oneDNN
How to Optimize GEMM
Practical tutorial on GEMM optimization techniques.
flame.cs.utexas.edu/~flame/web/