GEMM Optimization Deep Dive

How we built a high-performance matrix multiplication kernel that beats Intel MKL, inspired by oneDNN, BLIS, and decades of HPC research.

1.44x Faster than PyTorch/MKL

- Peak throughput: 52.95 GFLOPS
- 1.44x vs PyTorch/MKL
- 4.31x vs naive GEMM
- 8x8 register tile

Standing on the Shoulders of Giants

Our GEMM implementation draws from decades of high-performance computing research and industry-leading libraries:

[Diagram: lineage of the C-Kernel Engine (1.44x faster). Techniques drawn from oneDNN (Intel Deep Neural Network Library), BLIS (BLAS-like Library Instantiation Software), and Intel MKL (Math Kernel Library): JIT compilation, post-op fusion, microkernel design, register blocking, matrix packing, and cache optimization.]

The 8x8 Microkernel Architecture

The heart of our GEMM is an 8x8 microkernel that keeps all 64 accumulator values in AVX registers throughout the entire K-loop. This is the same strategy used by oneDNN and BLIS.

[Diagram: the 8x8 register-blocked microkernel computing C[8x8] = A[8xK] @ B[Kx8]. Eight YMM registers c0..c7 each hold one 8-float row of the C tile; one register holds b = B[k, 0:8], and a0..a7 are broadcasts of A[0:8, k]. Each K-loop iteration loads B[k, :], broadcasts the A column, and issues the FMAs c += a * b; C is stored to memory once, after the loop. Key insight: C values never leave registers during the entire K-loop, giving maximum arithmetic intensity.]
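The structure of that K-loop can be sketched in portable C. This is a simplified stand-in for the production kernel, which uses AVX intrinsics: here the 8x8 accumulator tile is a local array the compiler keeps in (YMM) registers, and `lda`/`ldb`/`ldc` are assumed row-major leading dimensions.

```c
#include <stddef.h>

/* 8x8 microkernel sketch: C[8x8] += A[8xK] * B[Kx8].
 * All 64 accumulators live in the local array `acc` for the entire
 * K-loop; C memory is touched exactly once, after the loop. */
static void microkernel_8x8(const float *A, const float *B, float *C,
                            size_t K, size_t lda, size_t ldb, size_t ldc) {
    float acc[8][8] = {{0}};
    for (size_t k = 0; k < K; k++) {
        for (size_t i = 0; i < 8; i++) {
            float a = A[i * lda + k];       /* broadcast of A[i, k]      */
            for (size_t j = 0; j < 8; j++)  /* one 8-wide FMA per row    */
                acc[i][j] += a * B[k * ldb + j];
        }
    }
    for (size_t i = 0; i < 8; i++)          /* single store of the tile  */
        for (size_t j = 0; j < 8; j++)
            C[i * ldc + j] += acc[i][j];
}
```

With `-mavx2 -mfma`, each `acc` row maps onto one YMM register and the inner loop becomes a `vfmadd` per row, which is exactly the register budget the diagram describes.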

Matrix Packing for Cache Efficiency

For large matrices, we pack A and B into contiguous memory layouts that maximize cache line utilization. This is the key technique that allowed us to beat MKL.

[Diagram: matrix packing, before vs after. Before: walking the A matrix [M x K] with a row stride misses cache on every row. After pack_a(), A is stored as contiguous MC x KC panels (Panel 0, Panel 1, Panel 2, ...) read sequentially. Likewise, pack_b() converts B's column stride into row-panel format, where each k step stores 8 contiguous floats (k=0: [8 floats], k=1: [8 floats], ...). Packing converts strided access to sequential access, so every fetched cache line is fully used.]
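A minimal sketch of the B-side packing, assuming row-major B and N a multiple of 8 (the production pack_b() also has to handle ragged edge tiles):

```c
#include <stddef.h>

/* Pack row-major B[K x N] into row-panel format: within each panel of
 * 8 columns, the 8 floats of step k are stored contiguously, so the
 * microkernel's load of B[k, 0:8] is a single sequential 32-byte read. */
static void pack_b_panels(const float *B, float *Bp, size_t K, size_t N) {
    size_t idx = 0;
    for (size_t j0 = 0; j0 < N; j0 += 8)     /* one 8-column panel      */
        for (size_t k = 0; k < K; k++)       /* walk K inside the panel */
            for (size_t j = 0; j < 8; j++)   /* 8 contiguous floats     */
                Bp[idx++] = B[k * N + (j0 + j)];
}
```

The A-side packing is symmetric: it copies MC x KC blocks of A so that the microkernel's broadcasts also read memory in order.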

Three-Level Cache Blocking

We tile the computation to fit each level of the memory hierarchy, minimizing data movement between DRAM and CPU.

[Diagram: memory hierarchy. Registers ~1 cycle; L1 cache (32KB) ~4 cycles; L2 cache (256KB) ~12 cycles; L3 cache (6-8MB) ~40 cycles; DRAM ~200 cycles.]

Our blocking parameters:

- A panel: MC x KC = 64 x 256. 64 × 256 × 4 bytes = 64KB, fits in L2; reused across all N tiles.
- B panel: KC x NC = 256 x 256. 256 × 256 × 4 bytes = 256KB, fits in L3; reused across all M tiles.
- Microkernel: MR x NR = 8 x 8. 8 × 8 × 4 bytes = 256 bytes, fits in 8 YMM registers; all accumulators stay in registers.

Loop nest structure: the outer k0 loop steps by KC (L3 blocking); the n0 loop steps by NC and packs a B panel (L2 blocking); the inner m0 loop steps by MC, packs an A panel, and calls the microkernel (L1 blocking, parallelized).
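The loop nest maps to C roughly as below. This is a structural sketch only: packing is elided and the microkernel is replaced by a plain inner loop, so the thing to read here is the blocking order, not the inner code.

```c
#include <stddef.h>

enum { MC = 64, NC = 256, KC = 256 };   /* blocking parameters from above */

static size_t min_sz(size_t a, size_t b) { return a < b ? a : b; }

/* Three-level blocked GEMM sketch: C[MxN] += A[MxK] * B[KxN], row-major.
 * Loop order matches the figure: k0 blocks for L3, n0 blocks the B panel
 * for L2, m0 blocks the A panel for L1 (pack_a/pack_b elided). */
static void gemm_blocked(const float *A, const float *B, float *C,
                         size_t M, size_t N, size_t K) {
    for (size_t k0 = 0; k0 < K; k0 += KC)
        for (size_t n0 = 0; n0 < N; n0 += NC)        /* pack_b() here  */
            for (size_t m0 = 0; m0 < M; m0 += MC) {  /* pack_a() here  */
                size_t kmax = min_sz(k0 + KC, K);
                size_t nmax = min_sz(n0 + NC, N);
                size_t mmax = min_sz(m0 + MC, M);
                for (size_t i = m0; i < mmax; i++)   /* microkernel in */
                    for (size_t k = k0; k < kmax; k++) {  /* real code */
                        float a = A[i * K + k];
                        for (size_t j = n0; j < nmax; j++)
                            C[i * N + j] += a * B[k * N + j];
                    }
            }
}
```

Because C is updated incrementally across k0 blocks, each block only needs its KC-slice of A and B resident in cache while it runs.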

Our Optimization Journey

Step 1: Naive Implementation
Triple nested loop with poor cache behavior. ~8 GFLOPS.
Step 2: OpenMP Parallelization
Added parallel for loops. Still cache-unfriendly.
Step 3: AVX Vectorization
8-wide SIMD using 256-bit YMM registers. ~15 GFLOPS.
Step 4: 8x8 Register Blocking
Keep 64 accumulators in registers across K-loop. ~20 GFLOPS.
Step 5: Cache Blocking (MC/NC/KC)
Tile for L1/L2/L3 cache hierarchy. ~25 GFLOPS.
Step 6: Matrix Packing
Pack A and B for contiguous access. ~30 GFLOPS.
Step 7: Software Prefetching + Loop Unrolling
Prefetch next cache line, unroll K by 4. 52.95 GFLOPS peak!
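For reference, the Step 1 baseline that all later numbers are measured against is just the textbook triple loop; the inner loop walks B down a column, which is what makes it cache-hostile for large K:

```c
#include <stddef.h>

/* Step 1 baseline: naive i-j-k triple loop, C = A * B, row-major.
 * For each output element the k-loop strides down a column of B,
 * touching a new cache line on nearly every access when K is large. */
static void gemm_naive(const float *A, const float *B, float *C,
                       size_t M, size_t N, size_t K) {
    for (size_t i = 0; i < M; i++)
        for (size_t j = 0; j < N; j++) {
            float s = 0.0f;
            for (size_t k = 0; k < K; k++)
                s += A[i * K + k] * B[k * N + j];  /* strided B access */
            C[i * N + j] = s;
        }
}
```

Every subsequent step in the journey attacks one specific weakness of this loop: parallelism, vector width, register reuse, cache reuse, and finally access pattern.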

Performance Results

| Matrix Size | PyTorch/MKL | Our Microkernel | Result |
| --- | --- | --- | --- |
| 32 x 32 x 32 | 5.17 GFLOPS | 10.19 GFLOPS | 2.0x faster |
| 64 x 64 x 64 | 14.03 GFLOPS | 22.94 GFLOPS | 1.6x faster |
| 128 x 128 x 128 | 19.75 GFLOPS | 27.75 GFLOPS | 1.4x faster |
| 256 x 256 x 256 | 22.41 GFLOPS | 32.76 GFLOPS | 1.5x faster |
| 512 x 512 x 512 | 22.50 GFLOPS | 24.22 GFLOPS | 1.1x faster |
| 1024 x 1024 x 1024 | 23.00 GFLOPS | 31.48 GFLOPS | 1.4x faster |
Our kernel beats MKL at every tested size.

v7 Training Threadpool Dispatch Playbook

Inference v6.6 and training v7 should both use lowered execution plans. For v7 training, IR2 defines gradient math, while IR3 must define parallel execution and reduction ownership.

IR1

Forward op graph from template + manifest (what to compute).

IR2

Backward synthesis and explicit gradient fanout/fanin accumulation (chain rule routing).

IR3 / Exec Plan

Memory layout + dispatch policy (split axis, tiles, threads, reduction order, barriers).

Dispatch policy by workload shape

| Workload | Preferred split | Reason |
| --- | --- | --- |
| Large-M GEMM (M = B*S) | Split M | Best cache locality and zero reduction overhead for independent rows. |
| Tiny-M GEMM (decode-like) | Split N | Improves core utilization when row-parallel work is too small. |
| Backward dW GEMM | Split K + explicit partial reduce | Good parallelism, but requires a deterministic reduction tree. |
| Elementwise ops | Contiguous range chunks | Simple scheduling and stable memory throughput. |
| Attention | Head split, then token-block split | Natural independence across heads, then balanced token work. |
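The table collapses into a small policy function. A sketch with hypothetical names (`split_axis_t`, `choose_split`) and an assumed rows-per-thread threshold; the engine's actual cutoffs may differ:

```c
#include <stdbool.h>
#include <stddef.h>

typedef enum { SPLIT_M, SPLIT_N, SPLIT_K_REDUCE } split_axis_t;

/* Illustrative GEMM dispatch policy: backward dW always splits K and
 * owns a deterministic reduce; forward GEMMs split M when there are at
 * least as many rows as threads, otherwise fall back to splitting N
 * (tiny-M, decode-like shapes). The threshold is an assumption. */
static split_axis_t choose_split(size_t m, bool is_backward_dw, int threads) {
    if (is_backward_dw)
        return SPLIT_K_REDUCE;   /* shared dW output: reduce partials */
    if (m >= (size_t)threads)
        return SPLIT_M;          /* enough independent rows per core  */
    return SPLIT_N;              /* split columns to fill the cores   */
}
```

Keeping this decision in the planner (IR3), not in the kernels, is what lets the JSON below pin the choice per op.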

Execution-plan JSON consumed by codegen

{
  "schema": "ck.train.exec.v1",
  "runtime": {"threads": 12, "simd": "avx2", "mode": "deterministic"},
  "ops": [
    {
      "op_id": 37,
      "phase": "forward",
      "kernel_id": "gemm_fwd_f32",
      "shape": {"m": 16, "n": 1024, "k": 1024},
      "dispatch": {"split_axis": "m", "tile_m": 4, "tile_n": 128, "threads": 12}
    },
    {
      "op_id": 109,
      "phase": "backward",
      "kernel_id": "gemm_backward_f32",
      "shape": {"m": 1024, "n": 1024, "k": 16},
      "dispatch": {"split_axis": "k", "threads": 12},
      "reduction": {"type": "sum", "order": "fixed_tree", "target": "grad.weight.layer.0.wq"}
    }
  ],
  "barriers": [{"after_op": 109, "reason": "grad_accum_boundary"}]
}
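The `"order": "fixed_tree"` field matters because float addition is not associative: summing per-thread dW partials in whatever order threads finish would give run-to-run jitter in gradients. A fixed pairwise tree, sketched below, combines partials in an order that depends only on the thread count:

```c
#include <stddef.h>

/* Deterministic pairwise-tree reduction: partials[t] holds thread t's
 * partial gradient (length n each). The combine order is fixed by
 * t_count alone, never by completion order, so the result is
 * bit-identical across runs. Reduces in place into partials[0]. */
static void reduce_fixed_tree(float **partials, size_t t_count, size_t n) {
    for (size_t stride = 1; stride < t_count; stride *= 2)
        for (size_t t = 0; t + stride < t_count; t += 2 * stride)
            for (size_t i = 0; i < n; i++)
                partials[t][i] += partials[t + stride][i];
}
```

A barrier after the producing op (as in `"barriers"` above) guarantees all partials exist before the tree runs.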

Rule for maintainability

Codegen should remain dumb. It should emit calls from train_exec_plan.json directly, not infer split/reduction policy ad hoc. This keeps parity, determinism, and performance tuning auditable.
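A dumb emitter is essentially a table walk over the parsed plan. This hypothetical sketch shows the shape: `exec_op_t` mirrors the JSON fields after parsing, and `emit_calls` is an illustrative name, not the engine's actual API.

```c
#include <stdio.h>
#include <string.h>

/* Hypothetical lowered form of one exec-plan op after JSON parsing. */
typedef struct {
    int op_id;
    const char *kernel_id;   /* e.g. "gemm_fwd_f32" */
    int m, n, k;
} exec_op_t;

/* Dumb codegen: emit one call per op, in plan order, using the plan's
 * own parameters. No split/reduction policy is inferred here. Returns
 * bytes written, or -1 on truncation. */
static int emit_calls(const exec_op_t *ops, int n_ops, char *out, size_t cap) {
    size_t used = 0;
    for (int i = 0; i < n_ops; i++) {
        int w = snprintf(out + used, cap - used, "%s(op_%d, %d, %d, %d);\n",
                         ops[i].kernel_id, ops[i].op_id,
                         ops[i].m, ops[i].n, ops[i].k);
        if (w < 0 || (size_t)w >= cap - used) return -1;
        used += (size_t)w;
    }
    return (int)used;
}
```

Because the emitter never makes decisions, a diff of two generated sources is exactly a diff of two plans, which is what keeps tuning auditable.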

Using the Microkernel

#include "ckernel_engine.h"

// Basic usage - automatically selects the best implementation
// (M, N, K are the problem dimensions)
float A[M * K], B[K * N], C[M * N];
gemm_microkernel(A, B, C, M, N, K, 0);  // B not transposed

// For neural network weights (B is [N, K] transposed)
gemm_microkernel(A, B, C, M, N, K, 1);  // B transposed

// Direct packed version for large matrices
gemm_microkernel_packed(A, B, C, M, N, K);

Source Files

gemm_microkernel.c

Main microkernel implementation with 8x8 register blocking, matrix packing, and cache blocking.

src/kernels/gemm_microkernel.c

gemm_fused_kernels.c

Fused GEMM operations: GEMM+ReLU, GEMM+GELU, GEMM+SiLU, and the dual-GEMM SwiGLU.

src/kernels/gemm_fused_kernels.c

test_gemm_microkernel.py

Unit test with accuracy verification and performance benchmarks vs PyTorch.

unittest/test_gemm_microkernel.py

References & Further Reading

BLIS: A Framework for Rapidly Instantiating BLAS Functionality

The foundational paper on microkernel-based GEMM design.

Field G. Van Zee, Robert A. van de Geijn

Anatomy of High-Performance Matrix Multiplication

Classic paper explaining cache blocking and register tiling.

Kazushige Goto, Robert A. van de Geijn

oneDNN Developer Guide

Intel's deep learning library with state-of-the-art GEMM kernels.

oneapi-src/oneDNN

How to Optimize GEMM

Practical tutorial on GEMM optimization techniques.

flame.cs.utexas.edu/~flame/web/
