v7 Cross-Entropy Parity: Why Small Math Choices Matter

This page explains the v7 cross-entropy (CE) kernel in detail: the p - one_hot gradient, PyTorch reduction semantics, long-horizon drift behavior, and the concrete changes that moved CE from fragile to production-grade for the v7 training harness.

Executive Summary
CE looked simple, but long-horizon behavior was sensitive to tiny semantic and numeric details. Matching PyTorch on reduction/ignore behavior and using stable log-sum-exp loss math removed the drift trigger in the 850-step v7 repro.
[Figure: v7 cross-entropy parity map from math derivation to production gates]

The Core Gradient: Why It Is p - one_hot

For one token with logits z and target class y:

\[ p_i = \frac{e^{z_i}}{\sum_j e^{z_j}}, \quad \mathcal{L} = -\log(p_y) \]

\[ \frac{\partial \mathcal{L}}{\partial z_i} = p_i - \mathbf{1}[i = y] \]

So each row gradient is:

\[ \nabla_{z}\mathcal{L} = p - \text{one\_hot}(y) \]

For mean reduction over N_valid targets:

\[ \nabla_{z}\mathcal{L}_{\text{mean}} = \frac{p - \text{one\_hot}(y)}{N_{\text{valid}}} \]
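The derivation can be checked numerically in a few lines. This is an illustrative pure-Python sketch (the helper names `softmax`, `ce_loss`, `ce_grad` are hypothetical, not the kernel's API) comparing the analytic row gradient p - one_hot against a central finite difference of the loss:

```python
# Numeric check that the CE row gradient equals p - one_hot(y).
import math

def softmax(z):
    m = max(z)                        # subtract max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def ce_loss(z, y):
    return -math.log(softmax(z)[y])

def ce_grad(z, y):
    p = softmax(z)
    return [p_i - (1.0 if i == y else 0.0) for i, p_i in enumerate(p)]

z, y, eps = [2.0, -1.0, 0.5], 0, 1e-6
analytic = ce_grad(z, y)
for i in range(len(z)):
    zp, zm = z[:], z[:]
    zp[i] += eps
    zm[i] -= eps
    numeric = (ce_loss(zp, y) - ce_loss(zm, y)) / (2 * eps)
    assert abs(numeric - analytic[i]) < 1e-5  # analytic matches finite diff
```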

Why This Matters

The gradient formula itself is standard. Most long-horizon mismatch comes from how we compute softmax/log-loss and which denominator we use for reduction, not from the symbolic derivative.

What Can Go Wrong in Practice

| Subtle CE Choice | Short-Run Effect | Long-Run Effect |
| --- | --- | --- |
| -log(prob + eps) clamp for loss | Looks stable and safe | Can cap loss and bias the scalar trajectory under extreme logits |
| Different reduction denominator | Tiny scalar mismatch | Persistent optimizer-state drift (AdamW moments integrate it) |
| Mixed fp32/double operation-ordering differences | Often below tolerance early | Can cross the tolerance threshold at later steps |
| Ignore-index mismatch | Only appears on masked rows | Periodic but cumulative parameter divergence |
| Missing grad-accum window averaging (/K) before the optimizer step | Looks like CE drift because loss starts diverging later | Effective update scale inflates by ~K; can dominate late-horizon instability |
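The first row of the table is easy to demonstrate. The sketch below (illustrative Python, not the C kernel) shows that with extreme logits a `-log(prob + eps)` clamp silently caps the loss near `-log(eps)`, while the log-sum-exp form recovers the exact value:

```python
# Probability-clamp loss vs. stable log-sum-exp loss under extreme logits.
import math

def lse(z):
    m = max(z)
    return m + math.log(sum(math.exp(v - m) for v in z))

def loss_clamped(z, y, eps=1e-12):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    p_y = e[y] / sum(e)
    return -math.log(p_y + eps)   # eps clamp silently caps the loss

def loss_lse(z, y):
    return lse(z) - z[y]          # exact -log softmax(z)[y], no clamp

z, y = [200.0, 0.0], 1            # true loss is ~200 nats
print(loss_clamped(z, y))         # capped near -log(1e-12) ~ 27.6
print(loss_lse(z, y))             # ~200.0, the exact value
```

The loss scalar is what the harness compares step by step, so a cap like this biases the trajectory exactly in the "looks fine early, drifts late" pattern the table describes.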

CE /N vs Accum /K

This page focuses on CE internals. But in production drift triage, also verify accumulation-window scaling before AdamW. See v7-grad-accum-windows.html for the full N vs K model and CPU effective-batch examples.
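As a minimal illustration of the accumulation-window issue (the `/K` row above, not CE itself; `accumulated_grad` is a hypothetical helper, not harness code):

```python
# With mean-reduced CE per micro-batch, the K accumulated gradients must
# be averaged (/K) before the optimizer step; skipping it inflates the
# effective update scale by ~K.
def accumulated_grad(micro_grads, average_window=True):
    K = len(micro_grads)
    total = [sum(g[i] for g in micro_grads) for i in range(len(micro_grads[0]))]
    if average_window:
        total = [t / K for t in total]   # the /K the table's last row warns about
    return total

micro = [[0.1, -0.2], [0.3, 0.1], [0.2, 0.0], [-0.1, 0.1]]  # K = 4 windows
print(accumulated_grad(micro, True))    # matches a single big-batch mean
print(accumulated_grad(micro, False))   # ~K x larger update scale
```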

PyTorch CE Semantics We Matched

Primary source files in local PyTorch checkout:

CK v7 CE Variants (Current)

| Variant | Function | Behavior | Purpose |
| --- | --- | --- | --- |
| Default CE | softmax_cross_entropy_loss | Legacy all-valid fast path; PyTorch-style fallback when ignore/invalid targets appear | Stable baseline + semantic correctness where needed |
| PyTorch-reference CE | softmax_cross_entropy_loss_ptref | Always strict index-target mean semantics, strict math path | Reference-style parity experiments |
| Shared strict index-target impl | softmax_cross_entropy_loss_index_mean_impl | ignore_index=-100, N_valid denominator, NaN when all targets are ignored | Semantics parity with PyTorch index-target CE |
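The strict index-target semantics described for the shared impl can be sketched in Python (a behavioral model of the description above, not the C code in loss_kernels.c):

```python
# Strict index-target mean CE: skip ignore_index rows, divide by N_valid,
# return NaN when every target is ignored (PyTorch-style semantics).
import math

def ce_index_mean(logits_rows, targets, ignore_index=-100):
    total, n_valid = 0.0, 0
    for z, y in zip(logits_rows, targets):
        if y == ignore_index:
            continue                  # masked row: no loss, no count
        m = max(z)
        total += (m + math.log(sum(math.exp(v - m) for v in z))) - z[y]
        n_valid += 1
    if n_valid == 0:
        return float("nan")           # all-ignored batch: NaN, not 0.0
    return total / n_valid            # denominator is N_valid, not batch size

rows = [[1.0, 2.0, 0.5], [0.2, 0.1, 0.3], [1.5, -1.0, 0.0]]
print(ce_index_mean(rows, [1, -100, 0]))        # mean over the 2 valid rows
print(ce_index_mean(rows, [-100, -100, -100]))  # nan
```

Note that using the batch size instead of `n_valid` here is exactly the "different reduction denominator" failure mode from the earlier table: a tiny scalar mismatch that AdamW's moments integrate into persistent drift.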

Implementation: src/kernels/loss_kernels.c

Change Log: What We Improved

  1. Switched CE loss calculation to stable log-sum-exp form to avoid probability-clamp artifacts.
  2. Added explicit PyTorch-style index-target mean reduction semantics with ignore_index=-100.
  3. Added a strict reference CE entrypoint (softmax_cross_entropy_loss_ptref) for parity A/B.
  4. Added guarded default routing to preserve all-valid long-horizon stability while handling ignore/invalid targets correctly.
  5. Expanded CE unit tests to include mixed-ignore and all-ignored behavior checks.

Production Evidence (v7 Harness)

| Run | Status | Key Drift Metrics |
| --- | --- | --- |
| /tmp/v7_ce_semantics_c_850.json (older baseline) | FAIL | first_param_fail_step=117, first_loss_fail_step=64 |
| /tmp/v7_ce_semantics_c_850_postfix.json | PASS | first_*_fail_step=None, max_loss_abs_diff=9.54e-07 |
| /tmp/v7_ce_semantics_c_ptref_850_postfix.json | PASS | first_*_fail_step=None, strict CE path also stable |
| /tmp/v7_ce_semantics_torch_850_postfix.json | PASS | Control run with PyTorch CE also passes the same horizon |
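A hypothetical sketch of how first-fail metrics of this shape could be computed (the real harness's comparison logic is not shown here; `first_fail_step` is an illustrative helper): walk two per-step loss traces, track the max absolute difference, and record the first step that exceeds tolerance.

```python
# Compare two per-step scalar traces; return (first failing step or None,
# max absolute difference seen over the whole horizon).
def first_fail_step(trace_a, trace_b, tol):
    max_abs = 0.0
    first_fail = None
    for step, (a, b) in enumerate(zip(trace_a, trace_b)):
        d = abs(a - b)
        max_abs = max(max_abs, d)
        if first_fail is None and d > tol:
            first_fail = step
    return first_fail, max_abs

a = [1.00, 0.90, 0.80, 0.70]
b = [1.00, 0.90, 0.80, 0.75]          # diverges at step 3
print(first_fail_step(a, b, tol=1e-3))  # first fail at step 3
```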

Additional training gates run clean in this validation pass:

What We Still Do Not Match (Yet)

Known Scope Gaps vs Full PyTorch CE API

How To Reproduce and Validate

# CE unit tests (includes ignore_index and all-ignored semantics)
source .venv/bin/activate
python unittest/test_cross_entropy.py

# Long-horizon parity stress (850 steps)
source .venv/bin/activate
CK_NUM_THREADS=8 python version/v7/scripts/train_parity_epochs_v7.py \
  --epochs 8 --seq-len 8 --total-tokens 131072 --grad-accum 8 \
  --optimizer adamw --lr 1e-3 --seed 42 --max-steps 850 \
  --ck-rmsnorm-backend c --ck-swiglu-backend c --ck-loss-backend c \
  --json-out /tmp/v7_ce_semantics_c_850_postfix.json

# A/B against strict CE reference path
source .venv/bin/activate
CK_NUM_THREADS=8 python version/v7/scripts/train_parity_epochs_v7.py \
  --epochs 8 --seq-len 8 --total-tokens 131072 --grad-accum 8 \
  --optimizer adamw --lr 1e-3 --seed 42 --max-steps 850 \
  --ck-rmsnorm-backend c --ck-swiglu-backend c --ck-loss-backend c_ptref \
  --json-out /tmp/v7_ce_semantics_c_ptref_850_postfix.json

Bottom Line

CE is not just a textbook formula in production. Whether PyTorch semantics and numeric behavior are matched at the reduction boundary can alone decide whether long-horizon training stays within tolerance or diverges. In v7, these subtle CE corrections were necessary to make backprop robust.
