Gated DeltaNet: Recurrent Attention Deep Dive
src/kernels/deltanet_kernels.c — FP32 Gated DeltaNet with REF, AVX, AVX2, and AVX-512 implementations.
Matches the single-token recurrent path used by Qwen3.5 / qwen3next in llama.cpp.
How One DeltaNet Step Works
The equations below walk through the complete per-head recurrent update that runs at every single-token decode step. Unlike standard attention (which recomputes over the full KV-cache), DeltaNet maintains a fixed-size state matrix S that is updated in place.
Mathematical Equations
Per head (h = 0 … num_heads-1):
q̂ = L2_normalize(q) / √state_dim ← scaled unit query
k̂ = L2_normalize(k) ← unit key
β_s = σ(β) = 1/(1+e^(-β)) ← write gate (0,1)
gate = exp(g) ← decay gate (0,∞)
S_decay = gate · S_prev ← element-wise forget
kv_mem = S_decayᵀ · k̂ ← what memory recalls for k̂
δ = β_s · (v − kv_mem) ← correction: new info minus old
S_new = S_decay + outer(k̂, δ) ← rank-1 write to memory
out = S_newᵀ · q̂ ← read from updated memory
Key insight: The delta rule computes the error between what the model wants to store (v) and what the state already recalls (kv_mem). Only the correction is written, gated by β_s. This makes DeltaNet a learned associative memory with selective forgetting.
Memory Layout
All arrays are flat row-major FP32 buffers. The head dimension is outermost for cache-friendly per-head iteration.
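As a concrete sketch of that layout, the flat offset for head h, row i, column j of a d×d state works out as below; the helper name is hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* Illustrative index helper for a flat row-major [num_heads][d][d] FP32
 * buffer with the head dimension outermost. The name is hypothetical. */
static inline size_t state_index(size_t head, size_t row, size_t col,
                                 size_t d) {
    return (head * d + row) * d + col;
}
```

With the head outermost, each head's d×d block is contiguous, so the per-head inner loops stay within one cache-friendly region.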
ISA Dispatch & SIMD Tiers
The public entry point gated_deltanet_autoregressive_forward() selects the best compiled implementation at link time. When strict parity is enabled (via ck_strict_parity_enabled()), it always falls back to the scalar reference.
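A toy sketch of that selection pattern follows. The function names and return codes are invented for illustration; only the strict-parity fallback to the scalar reference mirrors the behavior described above.

```c
#include <assert.h>

/* Toy dispatch sketch: the widest tier the compiler enables wins,
 * but strict-parity mode always forces the scalar reference. */
typedef int (*deltanet_impl_fn)(void);
static int impl_ref(void)    { return 0; } /* scalar reference */
static int impl_avx2(void)   { return 2; } /* illustrative tier */
static int impl_avx512(void) { return 3; } /* illustrative tier */

static deltanet_impl_fn select_deltanet_impl(int strict_parity) {
    if (strict_parity) return impl_ref;    /* parity => bit-exact scalar */
#if defined(__AVX512F__)
    return impl_avx512;
#elif defined(__AVX2__)
    return impl_avx2;
#else
    return impl_ref;
#endif
}
```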
SIMD Optimization Strategy
Reference (_ref)
- Pure scalar C — no intrinsics
- Triple-nested loops: head → row → col
- Used for parity testing against llama.cpp
- Bit-exact reference for correctness
AVX (_avx)
- 256-bit SIMD — 8 floats per instruction
- Pre-normalizes q̂ / k̂ before the state sweep
- Single row walk per state-matrix pass
- Scalar tail loop for non-multiple-of-8 dims
AVX2 + FMA (_avx2)
- Same 256-bit width, adds FMA: a×b+c in 1 cycle
- 2-row unroll: processes rows in pairs to halve loop overhead
- Uses _mm256_fmadd_ps where available
- Falls back to mul+add without FMA flag
AVX-512 (_avx512)
- 512-bit SIMD — 16 floats per instruction
- Native FMA via
_mm512_fmadd_ps - Uses
_mm512_reduce_add_psfor hsum - Largest vector width = fewest iterations
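The FMA-versus-fallback split can be sketched with an 8-float dot product; the function name is illustrative, and the preprocessor guard stands in for the "falls back to mul+add without FMA flag" behavior described above.

```c
#include <assert.h>

#if defined(__FMA__) && defined(__AVX2__)
#include <immintrin.h>
/* FMA path: one fused multiply-add per 8-lane step, then horizontal sum. */
static float fma_dot8(const float *a, const float *b) {
    __m256 acc = _mm256_fmadd_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b),
                                 _mm256_setzero_ps()); /* a*b + 0, fused */
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(acc),
                          _mm256_extractf128_ps(acc, 1));
    s = _mm_hadd_ps(s, s);
    s = _mm_hadd_ps(s, s);
    return _mm_cvtss_f32(s);
}
#else
/* Fallback path: separate multiply and add, same result. */
static float fma_dot8(const float *a, const float *b) {
    float acc = 0.0f;
    for (int i = 0; i < 8; i++) acc += a[i] * b[i];
    return acc;
}
#endif
```

Both paths compute the same dot product; the guard only decides whether the multiply and add are fused into one instruction.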
DeltaNet vs Standard Attention
| Property | Standard Attention | Gated DeltaNet |
|---|---|---|
| Memory per head | O(T × d) — grows with context | O(d²) — fixed |
| Per-token cost | O(T × d) — scans full KV-cache | O(d²) — constant |
| State | KV-cache (append-only) | Matrix S (overwrite via delta rule) |
| Forgetting | None (or windowed) | Exponential decay via exp(g) |
| Write mechanism | Append new K,V rows | Rank-1 correction: outer(k̂, δ) |
| Best for | Precise long-range recall | Streaming / very long contexts |
CK-Engine Kernel Rules
Like all CK kernels, the DeltaNet implementation follows strict rules:
🚫 No malloc / free
Memory comes from a bump allocator; all pointers are passed in. Stack arrays (CK_DELTANET_MAX_STACK_DIM = 4096) are used for temporaries.
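A minimal sketch of the no-heap rule, reusing the CK_DELTANET_MAX_STACK_DIM cap named above; the guard function itself is hypothetical.

```c
#include <assert.h>

#define CK_DELTANET_MAX_STACK_DIM 4096 /* cap stated in the rule above */

/* Hypothetical guard: since the kernel never calls malloc, callers must
 * reject dims that would overflow the fixed-size stack temporaries. */
static int deltanet_dim_fits_stack(int state_dim) {
    return state_dim > 0 && state_dim <= CK_DELTANET_MAX_STACK_DIM;
}
```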
🚫 No OpenMP
Parallelization happens at the orchestrator/codegen layer. Kernels are single-threaded, deterministic units.
✅ Pure computation
No side effects, no global state. Given the same inputs, produces identical outputs. Essential for parity testing.
✅ Defined API contract
Every kernel declares: inputs, outputs, workspace requirements, and memory layouts. The dispatcher selects ISA at compile time.
📊 Back to Kernel Catalog
For the full list of CK-Engine kernels (GEMM, RoPE, Softmax, Loss, etc.), see: