GEMM Memory Layout & Offset Calculation

Understanding Weight Storage
This page explains how weight matrices are stored in memory and why the layout matters for performance. This background is essential for debugging quantization bugs and understanding kernel behavior.

GEMM vs GEMV

GEMM General Matrix Multiply

C = A × B where both A and B are matrices.

  • Use case: Prefill (batch processing multiple tokens)
  • Dimensions: A[M, K] × B[K, N] = C[M, N]
  • Example: 32 tokens × embedding → 32 outputs
// Prefill: M tokens at once
gemm(activations[32, 896], weights[896, 4864], output[32, 4864])

GEMV General Matrix-Vector Multiply

y = x × W where the activation x is a row vector (M = 1).

  • Use case: Decode (one token at a time)
  • Dimensions: A[1, K] × B[K, N] = C[1, N]
  • Example: 1 token × embedding → 1 output
// Decode: 1 token at a time
gemv(activation[1, 896], weights[896, 4864], output[1, 4864])
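In pure Python, the decode-time GEMV reduces to one dot product per output column. A minimal reference sketch (plain lists, nn layout for readability; not the real kernel):

```python
def gemv(x, W):
    # x: list of K floats; W: K x N nested lists (row-major).
    # Output n is the dot product of x with column n of W.
    K, N = len(W), len(W[0])
    return [sum(x[k] * W[k][n] for k in range(K)) for n in range(N)]

x = [1.0, 2.0]
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 3.0]]
print(gemv(x, W))  # [1.0, 2.0, 8.0]
```

GEMM is the same computation repeated for each of the M rows of A, which is why prefill amortizes weight loads across tokens while decode cannot.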

Layout Notation: NN, NT, TN, TT

The two-letter suffix indicates how matrices A and B are stored in memory:

Layout Matrix A Matrix B Use Case
gemm_nn Normal [M, K] Normal [K, N] Textbook layout (rarely used)
gemm_nt Normal [M, K] Transposed [N, K] Inference standard
gemm_tn Transposed [K, M] Normal [K, N] Gradient computation
gemm_tt Transposed [K, M] Transposed [N, K] Specialized cases
Why NT for Inference?
With gemm_nt, each output neuron's weights are contiguous in memory. Computing one output reads sequential bytes (stride=1). With gemm_nn, you'd hop across rows (stride=N) causing cache misses.

Memory Layout Visualization

Weight Memory Layout Diagram showing NN vs NT layouts, quantized block structure, and offset calculations

Normal vs Transposed Storage

gemm_nn: W[K][N] - Bad Cache

Weights stored with input dimension as row index.

// Memory layout for 896→4864 linear
W[0][0], W[0][1], W[0][2], ..., W[0][4863]  // Input 0 to all outputs
W[1][0], W[1][1], W[1][2], ..., W[1][4863]  // Input 1 to all outputs
...
W[895][0], W[895][1], ..., W[895][4863]    // Input 895 to all outputs

Problem: To compute output[0], need W[0][0], W[1][0], W[2][0]... - stride of 4864 elements between reads!

gemm_nt: W[N][K] - Good Cache

Weights stored with output dimension as row index.

// Memory layout for 896→4864 linear
W[0][0], W[0][1], W[0][2], ..., W[0][895]   // All inputs → Output 0
W[1][0], W[1][1], W[1][2], ..., W[1][895]   // All inputs → Output 1
...
W[4863][0], W[4863][1], ..., W[4863][895]  // All inputs → Output 4863

Benefit: To compute output[0], read W[0][0], W[0][1], W[0][2]... - sequential memory access!
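The stride difference is easy to see by computing which flat buffer indices the first three reads for output[0] touch under each layout (a sketch, using the 896→4864 shapes above):

```python
K, N = 896, 4864

# gemm_nn: W[K][N] row-major -> element (k, n) lives at flat index k*N + n.
# Computing output[0] walks k with n fixed at 0: stride of N elements.
nn_reads = [k * N + 0 for k in range(3)]

# gemm_nt: W[N][K] row-major -> element (n, k) lives at flat index n*K + k.
# Computing output[0] walks k within row 0: stride of 1 element.
nt_reads = [0 * K + k for k in range(3)]

print(nn_reads)  # [0, 4864, 9728]
print(nt_reads)  # [0, 1, 2]
```

With 4-byte elements, the nn pattern jumps ~19 KB between reads, so every access misses a typical 64-byte cache line; the nt pattern gets 16 weights per line.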

Quantized Block Structure

Quantized formats store weights in blocks with a scale factor. The block structure affects offset calculations.

Q8_0 Block Example (34 bytes per block)

struct block_q8_0 {
    ggml_fp16_t d;      // 2 bytes: FP16 scale factor
    int8_t qs[32];      // 32 bytes: 32 quantized weights
};  // Total: 34 bytes for 32 weights

// Dequantization: weight[i] = qs[i] * d
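The dequantization rule can be sketched as a pure-Python reference that unpacks one 34-byte block (not the real kernel; `"<e"` is the struct format for little-endian IEEE fp16):

```python
import struct

def dequantize_q8_0(block: bytes):
    # One block_q8_0: fp16 scale d (2 bytes) + 32 int8 quants (32 bytes).
    assert len(block) == 34
    (d,) = struct.unpack("<e", block[:2])
    qs = struct.unpack("<32b", block[2:])
    return [q * d for q in qs]  # weight[i] = qs[i] * d

# Round-trip a hand-built block: scale 0.5, quants 0..31.
blk = struct.pack("<e", 0.5) + struct.pack("<32b", *range(32))
print(dequantize_q8_0(blk)[:4])  # [0.0, 0.5, 1.0, 1.5]
```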

Quantization Format Reference

Format Block Size Bytes/Block Bits/Weight Description
Q8_0 32 34 8.5 8-bit quantization with FP16 scale
Q5_0 32 22 5.5 5-bit with packed high bits + scale
Q4_0 32 18 4.5 4-bit quantization with scale
Q4_K 256 144 4.5 K-quant: super-blocks with sub-scales
Q6_K 256 210 6.5625 K-quant: higher precision super-blocks
FP32 1 4 32 Full precision (no quantization)
BF16 1 2 16 Brain float (8-bit exponent, 7-bit mantissa)

Offset Calculation

The Critical Formula

For quantized weights, the offset to row n is:

row_offset = n × (K / block_size) × bytes_per_block

// Example: Q8_0 weights, K=896, accessing row 5
blocks_per_row = 896 / 32 = 28 blocks
row_bytes = 28 × 34 = 952 bytes
row_5_offset = 5 × 952 = 4760 bytes

C Implementation

size_t get_weight_row_offset(
    int n,              // output neuron index
    int K,              // input dimension
    int block_size,     // e.g., 32 for Q8_0
    int bytes_per_block // e.g., 34 for Q8_0
) {
    int blocks = K / block_size;
    return (size_t)n * blocks * bytes_per_block;
}

Python Implementation

def get_weight_row_offset(
    n: int,              # output neuron index
    K: int,              # input dimension
    block_size: int,     # e.g., 32 for Q8_0
    bytes_per_block: int # e.g., 34 for Q8_0
) -> int:
    blocks = K // block_size
    return n * blocks * bytes_per_block
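The worked example above can be checked directly against this implementation:

```python
def get_weight_row_offset(n, K, block_size, bytes_per_block):
    # Repeated here so the snippet is self-contained.
    blocks = K // block_size
    return n * blocks * bytes_per_block

# Q8_0, K=896: 28 blocks/row x 34 bytes = 952 bytes/row.
print(get_weight_row_offset(1, 896, 32, 34))  # 952
print(get_weight_row_offset(5, 896, 32, 34))  # 4760
```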

Row Bytes by Format (K=896)

Format Calculation Row Bytes
Q8_0 (896/32) × 34 = 28 × 34 952
Q5_0 (896/32) × 22 = 28 × 22 616
Q4_0 (896/32) × 18 = 28 × 18 504
Q4_K (896/256) × 144 = 3.5 × 144 504*
Q6_K (896/256) × 210 = 3.5 × 210 735*
FP32 896 × 4 3584

* K-quants require K aligned to 256; 896 = 3.5 super-blocks (handled with partial blocks)
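The fractional super-block counts in the table can be handled exactly by computing total bits first and only then dividing, a sketch assuming the row still ends on a byte boundary:

```python
def row_bytes(K, block_size, bytes_per_block):
    # K * bytes_per_block / block_size, kept in integer arithmetic so
    # fractional block counts (896/256 = 3.5) never lose precision.
    total = K * bytes_per_block
    assert total % block_size == 0, "row does not end on a byte boundary"
    return total // block_size

print(row_bytes(896, 32, 34))    # 952 (Q8_0)
print(row_bytes(896, 256, 144))  # 504 (Q4_K)
print(row_bytes(896, 256, 210))  # 735 (Q6_K)
```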

v7 Training: IR3 Memory + Parallel Execution Plan

Important for v7
IR2 defines math and gradient topology. IR3 defines memory ownership and thread dispatch. Codegen should emit this plan, not invent it.

IR3 Artifact 1: layout_train.json

  • Single contiguous allocation for training runtime
  • Sections: weights, activations, grad_activations, grad_weights, optimizer states, temporaries
  • Exact offsets and sizes for every tensor
  • Canary ranges for memory diagnostics

IR3 Artifact 2: train_exec_plan.json

  • Per-op dispatch strategy (serial vs threaded)
  • Parallel axis and tile/chunk decisions
  • Reduction policy for backward accumulation
  • Deterministic barriers and phase boundaries

What axis should GEMM split on?

Shape pattern Best split axis Why
M = B*S large M (tokens/rows) Independent output rows, minimal sync overhead
M small (decode-like / tiny microbatch) N (output channels) Row split underutilizes threads when M is tiny
Backward dW with large reduction dimension K (split-K + explicit reduce) Good utilization, but requires deterministic partial reduction
Elementwise kernels Contiguous range chunks Simple partitioning, predictable cache behavior
Attention Head first, then token blocks Natural independence across heads; stable memory access
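For the common large-M case, the split is a static partition of output rows into near-equal contiguous chunks, one per thread. A minimal sketch of that partitioning (helper name is illustrative, not from the codebase):

```python
def split_rows(m, threads):
    # Static "split_axis: m" partition: the first (m % threads) threads
    # take one extra row, so chunk sizes differ by at most 1.
    base, rem = divmod(m, threads)
    chunks, start = [], 0
    for t in range(threads):
        size = base + (1 if t < rem else 0)
        chunks.append((start, start + size))
        start += size
    return chunks

print(split_rows(16, 12))  # first 4 threads get 2 rows, the rest get 1
```

Because each chunk writes disjoint output rows, no synchronization is needed until the barrier at the end of the op.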

Minimal execution-plan schema

{
  "schema": "ck.train.exec.v1",
  "runtime": {"threads": 12, "mode": "deterministic"},
  "ops": [
    {
      "op_id": 37,
      "phase": "forward",
      "kernel_id": "gemm_fwd_f32",
      "shape": {"m": 16, "n": 1024, "k": 1024},
      "dispatch": {
        "split_axis": "m",
        "tile_m": 4,
        "tile_n": 128,
        "threads": 12,
        "schedule": "static"
      },
      "reduction": {"type": "none"}
    },
    {
      "op_id": 109,
      "phase": "backward",
      "kernel_id": "gemm_backward_f32",
      "shape": {"m": 1024, "n": 1024, "k": 16},
      "dispatch": {"split_axis": "k", "threads": 12},
      "reduction": {
        "type": "sum",
        "partial_buffer": "tmp.partial.dw.L0.wq",
        "order": "fixed_tree"
      }
    }
  ]
}

Design rule: IR3 owns scheduling decisions, codegen emits fixed calls, kernels stay math-local.
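The "fixed_tree" reduction order in the backward example can be sketched as a pairwise tree sum that always combines partials in the same order, regardless of which thread finishes first (a reference sketch of the policy, not the runtime's code):

```python
def fixed_tree_sum(partials):
    # Deterministic pairwise reduction: combine (0,1), (2,3), ... each
    # round, so the floating-point rounding order is bit-reproducible.
    vals = list(partials)
    while len(vals) > 1:
        vals = [vals[i] + vals[i + 1] if i + 1 < len(vals) else vals[i]
                for i in range(0, len(vals), 2)]
    return vals[0]

print(fixed_tree_sum([1.0, 2.0, 3.0, 4.0, 5.0]))  # 15.0
```

A naive "add partials as threads complete" reduction gives run-to-run differences in the low bits of dW; the fixed tree trades a little latency for exact reproducibility.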

Critical Bug: Row Stride Mismatch

Common Bug: Mismatched Row Stride
If the code generator calculates a different row stride than the kernel expects, tokens beyond the first will read garbage memory, producing NaN/inf outputs.

Correct

// Generator and kernel agree
row_bytes = (K/32) × 34 = 952

Token 0: offset 0
Token 1: offset 952     ✓ Kernel reads here
Token 2: offset 1904    ✓ Kernel reads here

Bug (Extra +1)

// Generator adds spurious +1
row_bytes = (K/32 + 1) × 34 = 986

Token 0: offset 0
Token 1: offset 986     ✗ Kernel expects 952!
Token 2: offset 1972    ✗ Kernel expects 1904!

Symptom: First token works fine, but all subsequent tokens produce NaN or garbage. The validation system in codegen catches this by verifying dimensions align to block sizes.
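A validation check of this kind can be sketched as follows (hypothetical helper; the real codegen validation may differ):

```python
def validate_row_stride(K, block_size, bytes_per_block, generator_row_bytes):
    # Guards against the "+1" class of bug above: the stride the code
    # generator emits must equal what the kernel derives from K.
    if K % block_size != 0:
        raise ValueError(f"K={K} not aligned to block size {block_size}")
    expected = (K // block_size) * bytes_per_block
    if generator_row_bytes != expected:
        raise ValueError(
            f"row stride mismatch: generator={generator_row_bytes}, "
            f"kernel expects {expected}")

validate_row_stride(896, 32, 34, 952)    # OK: matches (896/32) * 34
# validate_row_stride(896, 32, 34, 986)  # would raise: the "+1" bug
```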

Kernel Naming Convention

C-Kernel-Engine kernel names follow this pattern:

gemm_nt_q5_0_q8_0
│    │  │     │
│    │  │     └─ Activation quantization (Q8_0)
│    │  └─────── Weight quantization (Q5_0)
│    └────────── Layout: A=Normal, B=Transposed
└─────────────── Operation type (GEMM)
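The canonical pattern can be parsed mechanically. A hypothetical parser sketch (handles only q-format names matching the pattern above; names like gemm_blocked_serial_bf16 deviate from it and return None):

```python
import re

# op _ layout _ weight-quant [_ activation-quant]; quant names like q5_0, q4_k.
PATTERN = re.compile(
    r"^(gemm|gemv)_(nn|nt|tn|tt)_([a-z0-9]+_[0-9k])(?:_([a-z0-9]+_[0-9k]))?$")

def parse_kernel_name(name):
    m = PATTERN.match(name)
    if not m:
        return None
    op, layout, wq, aq = m.groups()
    # Missing activation suffix means unquantized FP32 activations.
    return {"op": op, "layout": layout, "weights": wq, "activations": aq or "f32"}

print(parse_kernel_name("gemm_nt_q5_0_q8_0"))
```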

Common Kernel Examples

Kernel Weights Activations Use Case
gemm_nt_q5_0 Q5_0 FP32 Standard decode (single token)
gemm_nt_q5_0_q8_0 Q5_0 Q8_0 INT8 batch prefill (faster)
gemm_nt_q4_k Q4_K FP32 K-quant decode
gemm_nt_q6_k Q6_K FP32 Higher precision K-quant
gemm_blocked_serial_bf16 BF16 BF16 Brain float training

Related Documentation

Quantization Guide

Deep dive into all quantization formats and their trade-offs.


GEMM Optimization

Performance tuning and SIMD optimization for GEMM kernels.


Codegen Pipeline

How the IR system generates optimized C code.
