Gated DeltaNet: Recurrent Attention Deep Dive

Kernel Source
src/kernels/deltanet_kernels.c — FP32 Gated DeltaNet with REF, AVX, AVX2, and AVX-512 implementations. Matches the single-token recurrent path used by Qwen3.5 / qwen3next in llama.cpp.

How One DeltaNet Step Works

The diagram below shows the complete per-head recurrent update that runs at every single-token decode step. Unlike standard attention (which recomputes over the full KV-cache), DeltaNet maintains a fixed-size state matrix S that gets updated in-place.

[Diagram: Gated DeltaNet · single-token recurrent update.
 Inputs (per head): q, k, v [state_dim]; scalars β, g; recurrent memory matrix S_prev [state_dim × state_dim].
 Steps 1–2 (normalize & gate): q̂ = L2norm(q)/√state_dim, k̂ = L2norm(k), β_s = σ(β), gate = exp(g).
 Step 3 (state decay): S_decay = gate × S_prev, every cell multiplied by exp(g).
 Step 4 (retrieve & delta): kv_mem = S_decayᵀ · k̂ (what the state remembers for k̂); δ = β_s · (v − kv_mem).
 Step 5 (state update): S_new = S_decay + outer(k̂, δ), a rank-1 write of new information.
 Step 6 (readout): out = S_newᵀ · q̂ [state_dim]; S_new feeds back as S_prev for the next step.]

Mathematical Equations

Per head (h = 0 … num_heads-1):

  q̂   = L2_norm(q) / √state_dim      ← scaled unit query
  k̂   = L2_norm(k)                    ← unit key
  β_s = σ(β)        = 1/(1+e^(-β))    ← write gate  (0,1)
  gate = exp(g)                        ← decay gate  (0,∞)

  S_decay = gate · S_prev              ← element-wise forget
  kv_mem  = S_decayᵀ · k̂              ← what memory recalls for k̂
  δ       = β_s · (v − kv_mem)         ← correction: new info minus old
  S_new   = S_decay + outer(k̂, δ)     ← rank-1 write to memory
  out     = S_newᵀ  · q̂               ← read from updated memory

Key insight: The delta rule computes the error between what the model wants to store (v) and what the state already recalls (kv_mem). Only the correction is written, gated by β_s. This makes DeltaNet a learned associative memory with selective forgetting.

Memory Layout

All arrays are flat row-major FP32 buffers. The head dimension is outermost for cache-friendly per-head iteration.

VECTORS  q, k, v, out    [num_heads × state_dim]              stride = state_dim
         Head h occupies [d₀ d₁ … d_{D-1}] at offset h·state_dim.

SCALARS  g, β            [num_heads]                          one value per head

STATE    state_in/out    [num_heads × state_dim × state_dim]  stride = state_dim²
         Head h's D×D block (row-major) starts at offset h·state_dim².
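These strides translate directly into offset arithmetic. A small sketch, with illustrative helper names (the kernel itself may inline this math):

```c
#include <stddef.h>

/* Offset helpers for the flat row-major layout above.
 * D = state_dim. Names are illustrative, not from the kernel source. */
static size_t vec_offset(size_t h, size_t D)    { return h * D; }      /* q/k/v/out: head h's vector */
static size_t scalar_offset(size_t h)           { return h; }          /* g/β: one scalar per head   */
static size_t state_offset(size_t h, size_t D)  { return h * D * D; }  /* head h's D×D state block   */
static size_t state_cell(size_t h, size_t r, size_t c, size_t D) {
    return h * D * D + r * D + c;  /* row-major cell (r, c) inside head h's block */
}
```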

ISA Dispatch & SIMD Tiers

The public entry point gated_deltanet_autoregressive_forward() selects the best compiled implementation at link time. When strict parity is enabled (via ck_strict_parity_enabled()), it always falls back to the scalar reference.

gated_deltanet_autoregressive_forward()
  ├─ strict parity?  YES ─► _ref     (scalar reference)
  ├─ AVX-512?        YES ─► _avx512  (16 floats/iter · FMA)
  ├─ AVX2?           YES ─► _avx2    (8 floats/iter · 2-row unroll)
  ├─ AVX?            YES ─► _avx     (8 floats/iter)
  └─ otherwise           ─► _ref     (fallback)
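The selection above can be sketched with preprocessor feature macros plus a runtime parity check. The `_ref`/`_avx`/`_avx2`/`_avx512` suffixes and `ck_strict_parity_enabled()` come from this document; the `DELTANET_IMPL` macro, stub bodies, and return values below are purely illustrative.

```c
/* Compile-time ISA selection sketch with a runtime strict-parity override. */
static int ck_strict_parity_enabled(void) { return 0; }  /* stub for this sketch */

static int deltanet_ref(void)    { return 0; }
#if defined(__AVX512F__)
static int deltanet_avx512(void) { return 3; }
#define DELTANET_IMPL deltanet_avx512
#elif defined(__AVX2__)
static int deltanet_avx2(void)   { return 2; }
#define DELTANET_IMPL deltanet_avx2
#elif defined(__AVX__)
static int deltanet_avx(void)    { return 1; }
#define DELTANET_IMPL deltanet_avx
#else
#define DELTANET_IMPL deltanet_ref   /* no SIMD compiled in: scalar fallback */
#endif

static int gated_deltanet_forward_dispatch(void) {
    if (ck_strict_parity_enabled())
        return deltanet_ref();   /* strict parity: always the scalar reference */
    return DELTANET_IMPL();      /* otherwise: widest ISA compiled in */
}
```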

SIMD Optimization Strategy

Reference (_ref)

  • Pure scalar C — no intrinsics
  • Triple-nested loops: head → row → col
  • Used for parity testing against llama.cpp
  • Bit-exact reference for correctness

AVX (_avx)

  • 256-bit SIMD — 8 floats per instruction
  • Pre-normalizes q̂ / k̂ before the state sweep
  • Single row walk per state-matrix pass
  • Scalar tail loop for non-multiple-of-8 dims

AVX2 + FMA (_avx2)

  • Same 256-bit width, adds FMA: fused a×b+c in a single instruction
  • 2-row unroll: processes rows in pairs to halve loop overhead
  • Uses _mm256_fmadd_ps where available
  • Falls back to mul+add without FMA flag
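The FMA accumulation plus scalar-tail pattern looks roughly like the dot product below. This is a sketch of the technique, not the kernel's code; it compiles to the scalar path when built without AVX2/FMA flags.

```c
#include <stddef.h>
#if defined(__AVX2__) && defined(__FMA__)
#include <immintrin.h>
#endif

/* Dot product illustrating the AVX2+FMA pattern: 8-wide fmadd accumulation
 * with a scalar tail for n not a multiple of 8. Without the FMA flag the
 * tail loop handles everything. Name is illustrative. */
static float dot_fma(const float *a, const float *b, size_t n) {
    size_t i = 0;
    float sum = 0.0f;
#if defined(__AVX2__) && defined(__FMA__)
    __m256 acc = _mm256_setzero_ps();
    for (; i + 8 <= n; i += 8)
        acc = _mm256_fmadd_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i), acc);
    float lanes[8];
    _mm256_storeu_ps(lanes, acc);
    for (int j = 0; j < 8; j++) sum += lanes[j];  /* horizontal sum of the accumulator */
#endif
    for (; i < n; i++) sum += a[i] * b[i];        /* scalar tail (or full scalar fallback) */
    return sum;
}
```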

AVX-512 (_avx512)

  • 512-bit SIMD — 16 floats per instruction
  • Native FMA via _mm512_fmadd_ps
  • Uses _mm512_reduce_add_ps for hsum
  • Largest vector width = fewest iterations
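The `_mm512_reduce_add_ps` horizontal sum mentioned above collapses a 16-lane accumulator in one call. A guarded sketch (illustrative helper, not the kernel's code), with a scalar stand-in when AVX-512F is not compiled in:

```c
#if defined(__AVX512F__)
#include <immintrin.h>
#endif

/* Horizontal sum of 16 floats: one _mm512_reduce_add_ps with AVX-512F,
 * otherwise a plain loop. Illustrative helper name. */
static float hsum16(const float *lanes /* 16 floats */) {
#if defined(__AVX512F__)
    return _mm512_reduce_add_ps(_mm512_loadu_ps(lanes));
#else
    float s = 0.0f;
    for (int i = 0; i < 16; i++) s += lanes[i];
    return s;
#endif
}
```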

DeltaNet vs Standard Attention

Property         | Standard Attention               | Gated DeltaNet
-----------------|----------------------------------|-------------------------------------
Memory per head  | O(T × d) — grows with context    | O(d²) — fixed
Per-token cost   | O(T × d) — scans full KV-cache   | O(d²) — constant
State            | KV-cache (append-only)           | Matrix S (overwrite via delta rule)
Forgetting       | None (or windowed)               | Exponential decay via exp(g)
Write mechanism  | Append new K,V rows              | Rank-1 correction: outer(k̂, δ)
Best for         | Precise long-range recall        | Streaming / very long contexts

CK-Engine Kernel Rules

Like all CK kernels, the DeltaNet implementation follows strict rules:

🚫 No malloc / free

Memory comes via bump allocator. All pointers are passed in. Stack arrays (CK_DELTANET_MAX_STACK_DIM = 4096) are used for temporaries.
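The no-malloc pattern for temporaries can be sketched as below. The constant name and value come from this document; the function, its return convention, and the bounds check are illustrative assumptions.

```c
#include <string.h>
#include <stddef.h>

#define CK_DELTANET_MAX_STACK_DIM 4096  /* per the document's kernel rules */

/* No-malloc sketch: temporaries live on the stack, sized by the compile-time
 * maximum, and dims that would overflow the buffer are rejected up front.
 * Function name and return convention are illustrative. */
static int deltanet_zero_workspace(int state_dim) {
    if (state_dim <= 0 || state_dim > CK_DELTANET_MAX_STACK_DIM)
        return -1;                         /* caller must handle oversize dims */
    float tmp[CK_DELTANET_MAX_STACK_DIM];  /* stack temporary: no malloc/free */
    memset(tmp, 0, (size_t)state_dim * sizeof(float));
    return (tmp[0] == 0.0f) ? 0 : -1;
}
```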

🚫 No OpenMP

Parallelization happens at the orchestrator/codegen layer. Kernels are single-threaded, deterministic units.

✅ Pure computation

No side effects, no global state. Given the same inputs, produces identical outputs. Essential for parity testing.

✅ Defined API contract

Every kernel declares: inputs, outputs, workspace requirements, and memory layouts. The dispatcher selects ISA at compile time.

📊 Back to Kernel Catalog

For the full list of CK-Engine kernels (GEMM, RoPE, Softmax, Loss, etc.), see:

Kernel Catalog
