GEMM Memory Layout & Offset Calculation
This page explains how weight matrices are stored in memory and why the layout matters for performance. It is essential background for debugging quantization bugs and understanding kernel behavior.
GEMM vs GEMV
GEMM General Matrix Multiply
C = A × B where both A and B are matrices.
- Use case: Prefill (batch processing multiple tokens)
- Dimensions: A[M, K] × B[K, N] = C[M, N]
- Example: 32 tokens × embedding → 32 outputs
// Prefill: M tokens at once
gemm(activations[32, 896], weights[896, 4864], output[32, 4864])
GEMV General Matrix-Vector Multiply
y = x × W, where x is a single row vector (M = 1).
- Use case: Decode (one token at a time)
- Dimensions: A[1, K] × B[K, N] = C[1, N]
- Example: 1 token × embedding → 1 output
// Decode: 1 token at a time
gemv(activation[1, 896], weights[896, 4864], output[1, 4864])
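The decode path above can be sketched as a naive scalar GEMV over NT-layout FP32 weights (reference code only; real kernels add tiling, quantized blocks, and SIMD):

```c
#include <stddef.h>

// Naive FP32 GEMV sketch over NT-layout weights (each output neuron's K
// weights contiguous in memory). Scalar reference code, no blocking.
// y[n] = sum over k of x[k] * W[n*K + k]
void gemv_nt_f32(const float *x, const float *W, float *y, int K, int N) {
    for (int n = 0; n < N; n++) {
        float acc = 0.0f;
        for (int k = 0; k < K; k++)
            acc += x[k] * W[(size_t)n * K + k];  // sequential reads
        y[n] = acc;
    }
}
```

Note that the inner loop walks `W` with stride 1 — this is exactly the cache benefit of the NT layout discussed below.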
Layout Notation: NN, NT, TN, TT
The two-letter suffix indicates how matrices A and B are stored in memory:
| Layout | Matrix A | Matrix B | Use Case |
|---|---|---|---|
| gemm_nn | Normal [M, K] | Normal [K, N] | Textbook layout (rarely used) |
| gemm_nt | Normal [M, K] | Transposed [N, K] | Inference standard |
| gemm_tn | Transposed [K, M] | Normal [K, N] | Gradient computation |
| gemm_tt | Transposed [K, M] | Transposed [N, K] | Specialized cases |
With gemm_nt, each output neuron's weights are contiguous in memory. Computing one output reads sequential bytes (stride = 1). With gemm_nn, you'd hop across rows (stride = N), causing cache misses.
Memory Layout Visualization
Normal vs Transposed Storage
gemm_nn: W[K][N] - Bad Cache
Weights stored with input dimension as row index.
// Memory layout for 896→4864 linear
W[0][0], W[0][1], W[0][2], ..., W[0][4863] // Input 0 to all outputs
W[1][0], W[1][1], W[1][2], ..., W[1][4863] // Input 1 to all outputs
...
W[895][0], W[895][1], ..., W[895][4863] // Input 895 to all outputs
Problem: computing output[0] requires W[0][0], W[1][0], W[2][0], ... - a stride of 4864 elements between consecutive reads!
gemm_nt: W[N][K] - Good Cache
Weights stored with output dimension as row index.
// Memory layout for 896→4864 linear
W[0][0], W[0][1], W[0][2], ..., W[0][895] // All inputs → Output 0
W[1][0], W[1][1], W[1][2], ..., W[1][895] // All inputs → Output 1
...
W[4863][0], W[4863][1], ..., W[4863][895] // All inputs → Output 4863
Benefit: computing output[0] reads W[0][0], W[0][1], W[0][2], ... - purely sequential memory access!
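The stride difference can be made explicit with two tiny flat-array index helpers (illustrative names, not engine API):

```c
#include <stddef.h>

// Flat-array index helpers contrasting the two layouts.
// NN storage W[K][N]: consecutive k values are N elements apart.
size_t idx_nn(int k, int n, int N) { return (size_t)k * N + n; }

// NT storage W[N][K]: consecutive k values are adjacent (stride 1).
size_t idx_nt(int n, int k, int K) { return (size_t)n * K + k; }
```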
Quantized Block Structure
Quantized formats store weights in blocks with a scale factor. The block structure affects offset calculations.
Q8_0 Block Example (34 bytes per block)
struct block_q8_0 {
    ggml_fp16_t d;      // 2 bytes: FP16 scale factor
    int8_t      qs[32]; // 32 bytes: 32 quantized weights
};                      // Total: 34 bytes for 32 weights

// Dequantization: weight[i] = qs[i] * d
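A minimal dequantization sketch follows; a plain float scale is used so the example stands alone (the real block stores ggml_fp16_t, and the struct name here is invented for illustration):

```c
#include <stdint.h>

// Q8_0 dequantization sketch. The real block stores the scale as
// ggml_fp16_t; a plain float keeps this example self-contained.
typedef struct {
    float  d;      // scale factor
    int8_t qs[32]; // 32 quantized weights
} block_q8_0_f32;

// weight[i] = qs[i] * d
void dequant_q8_0(const block_q8_0_f32 *b, float *out) {
    for (int i = 0; i < 32; i++)
        out[i] = (float)b->qs[i] * b->d;
}
```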
Quantization Format Reference
| Format | Block Size | Bytes/Block | Bits/Weight | Description |
|---|---|---|---|---|
| Q8_0 | 32 | 34 | 8.5 | 8-bit quantization with FP16 scale |
| Q5_0 | 32 | 22 | 5.5 | 5-bit with packed high bits + scale |
| Q4_0 | 32 | 18 | 4.5 | 4-bit quantization with scale |
| Q4_K | 256 | 144 | 4.5 | K-quant: super-blocks with sub-scales |
| Q6_K | 256 | 210 | 6.5 | K-quant: higher precision super-blocks |
| FP32 | 1 | 4 | 32 | Full precision (no quantization) |
| BF16 | 1 | 2 | 16 | Brain float (8-bit exponent, 7-bit mantissa) |
Offset Calculation
The Critical Formula
For quantized weights, the offset to row n is:
row_offset = n × (K / block_size) × bytes_per_block
// Example: Q8_0 weights, K=896, accessing row 5
blocks_per_row = 896 / 32 = 28 blocks
row_bytes = 28 × 34 = 952 bytes
row_5_offset = 5 × 952 = 4760 bytes
C Implementation
size_t get_weight_row_offset(
    int n,               // output neuron index
    int K,               // input dimension
    int block_size,      // e.g., 32 for Q8_0
    int bytes_per_block  // e.g., 34 for Q8_0
) {
    int blocks = K / block_size;
    return (size_t)n * blocks * bytes_per_block;
}
Python Implementation
def get_weight_row_offset(
    n: int,               # output neuron index
    K: int,               # input dimension
    block_size: int,      # e.g., 32 for Q8_0
    bytes_per_block: int  # e.g., 34 for Q8_0
) -> int:
    blocks = K // block_size
    return n * blocks * bytes_per_block
Row Bytes by Format (K=896)
| Format | Calculation | Row Bytes |
|---|---|---|
| Q8_0 | (896/32) × 34 = 28 × 34 | 952 |
| Q5_0 | (896/32) × 22 = 28 × 22 | 616 |
| Q4_0 | (896/32) × 18 = 28 × 18 | 504 |
| Q4_K | (896/256) × 144 = 3.5 × 144 | 504* |
| Q6_K | (896/256) × 210 = 3.5 × 210 | 735* |
| FP32 | 896 × 4 | 3584 |
* K-quants require K aligned to 256; 896 = 3.5 super-blocks (handled with partial blocks)
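The 32-block rows of the table can be cross-checked with a one-line helper (K-quant formats with partial super-blocks need extra handling and are excluded here):

```c
#include <stddef.h>

// Row bytes for formats whose block size divides K evenly.
// K-quant formats with partial super-blocks (e.g. K=896 with 256-wide
// blocks) need special handling and are out of scope for this sketch.
size_t row_bytes(int K, int block_size, int bytes_per_block) {
    return (size_t)(K / block_size) * bytes_per_block;
}
```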
v7 Training: IR3 Memory + Parallel Execution Plan
IR2 defines math and gradient topology. IR3 defines memory ownership and thread dispatch. Codegen should emit this plan, not invent it.
IR3 Artifact 1: layout_train.json
- Single contiguous allocation for training runtime
- Sections: weights, activations, grad_activations, grad_weights, optimizer states, temporaries
- Exact offsets and sizes for every tensor
- Canary ranges for memory diagnostics
IR3 Artifact 2: train_exec_plan.json
- Per-op dispatch strategy (serial vs threaded)
- Parallel axis and tile/chunk decisions
- Reduction policy for backward accumulation
- Deterministic barriers and phase boundaries
What axis should GEMM split on?
| Shape pattern | Best split axis | Why |
|---|---|---|
| M = B*S large | M (tokens/rows) | Independent output rows, minimal sync overhead |
| M small (decode-like / tiny microbatch) | N (output channels) | Row split underutilizes threads when M is tiny |
| Backward dW with large reduction dimension | K (split-K + explicit reduce) | Good utilization, but requires deterministic partial reduction |
| Elementwise kernels | Contiguous range chunks | Simple partitioning, predictable cache behavior |
| Attention | Head first, then token blocks | Natural independence across heads; stable memory access |
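A toy sketch of the table's decision logic (the threshold, enum, and function names are invented for illustration, not the real planner's API):

```c
// Toy split-axis chooser mirroring the table above. The M >= 32
// threshold and the names are illustrative only.
typedef enum { SPLIT_M, SPLIT_N, SPLIT_K } split_axis_t;

split_axis_t choose_gemm_split(int M, int is_backward_dw) {
    if (is_backward_dw)
        return SPLIT_K;  // split-K plus a deterministic partial reduce
    if (M >= 32)
        return SPLIT_M;  // enough independent output rows for all threads
    return SPLIT_N;      // tiny M (decode): split output channels instead
}
```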
Minimal execution-plan schema
{
  "schema": "ck.train.exec.v1",
  "runtime": {"threads": 12, "mode": "deterministic"},
  "ops": [
    {
      "op_id": 37,
      "phase": "forward",
      "kernel_id": "gemm_fwd_f32",
      "shape": {"m": 16, "n": 1024, "k": 1024},
      "dispatch": {
        "split_axis": "m",
        "tile_m": 4,
        "tile_n": 128,
        "threads": 12,
        "schedule": "static"
      },
      "reduction": {"type": "none"}
    },
    {
      "op_id": 109,
      "phase": "backward",
      "kernel_id": "gemm_backward_f32",
      "shape": {"m": 1024, "n": 1024, "k": 16},
      "dispatch": {"split_axis": "k", "threads": 12},
      "reduction": {
        "type": "sum",
        "partial_buffer": "tmp.partial.dw.L0.wq",
        "order": "fixed_tree"
      }
    }
  ]
}
Design rule: IR3 owns scheduling decisions, codegen emits fixed calls, kernels stay math-local.
Critical Bug: Row Stride Mismatch
If the code generator calculates a different row stride than the kernel expects, tokens beyond the first will read garbage memory, producing NaN/inf outputs.
Correct
// Generator and kernel agree
row_bytes = (K/32) × 34 = 952
Token 0: offset 0
Token 1: offset 952 ✓ Kernel reads here
Token 2: offset 1904 ✓ Kernel reads here
Bug (Extra +1)
// Generator adds spurious +1
row_bytes = (K/32 + 1) × 34 = 986
Token 0: offset 0
Token 1: offset 986 ✗ Kernel expects 952!
Token 2: offset 1972 ✗ Kernel expects 1904!
Symptom: First token works fine, but all subsequent tokens produce NaN or garbage. The validation system in codegen catches this by verifying dimensions align to block sizes.
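A validation check in the spirit described above can be sketched like this (hypothetical helper, not the real codegen API):

```c
#include <stdbool.h>
#include <stddef.h>

// Hypothetical stride validation: reject dimensions that don't align
// to the block size, and stride values that disagree with what the
// kernel will derive from (K / block_size) * bytes_per_block.
bool validate_row_stride(int K, int block_size, int bytes_per_block,
                         size_t generator_stride) {
    if (K % block_size != 0)
        return false;  // dimension not aligned to the block size
    size_t kernel_stride = (size_t)(K / block_size) * bytes_per_block;
    return kernel_stride == generator_stride;  // must agree exactly
}
```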
Kernel Naming Convention
C-Kernel-Engine kernel names follow this pattern:
gemm_nt_q5_0_q8_0
│    │  │    │
│    │  │    └── Activation quantization (Q8_0)
│    │  └─────── Weight quantization (Q5_0)
│    └────────── Layout: A=Normal, B=Transposed
└─────────────── Operation type (GEMM)
Common Kernel Examples
| Kernel | Weights | Activations | Use Case |
|---|---|---|---|
| gemm_nt_q5_0 | Q5_0 | FP32 | Standard decode (single token) |
| gemm_nt_q5_0_q8_0 | Q5_0 | Q8_0 | INT8 batch prefill (faster) |
| gemm_nt_q4_k | Q4_K | FP32 | K-quant decode |
| gemm_nt_q6_k | Q6_K | FP32 | Higher-precision K-quant |
| gemm_blocked_serial_bf16 | BF16 | BF16 | Brain-float training |
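An illustrative parser for this naming pattern (not part of the real engine; it assumes the op and layout are each a single underscore-delimited token):

```c
#include <stdio.h>
#include <string.h>

// Splits "gemm_nt_q5_0_q8_0" into op ("gemm"), layout ("nt"), and the
// remaining quantization tags ("q5_0_q8_0"). Returns 1 on success.
// Caller supplies op[16], layout[16], quant[32].
int parse_kernel_name(const char *name, char *op, char *layout, char *quant) {
    return sscanf(name, "%15[^_]_%15[^_]_%31s", op, layout, quant) == 3;
}
```

Note this simple split misreads multi-token names like gemm_blocked_serial_bf16 (it would take "blocked" as the layout); a real parser would match tokens against a known set of layouts.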