v7 Cross-Entropy Parity: Why Small Math Choices Matter
This page explains the v7 cross-entropy (CE) kernel in detail: the p - one_hot gradient, PyTorch reduction semantics, long-horizon drift behavior, and the concrete changes that moved CE from fragile to production-grade for the v7 training harness.
CE looked simple, but long-horizon behavior was sensitive to tiny semantic and numeric details. Matching PyTorch on reduction/ignore behavior and using stable log-sum-exp loss math removed the drift trigger in the 850-step v7 repro.
The Core Gradient: Why It Is p - one_hot
For one token with logits z and target class y:
\[ p_i = \frac{e^{z_i}}{\sum_j e^{z_j}}, \quad \mathcal{L} = -\log(p_y) \]
\[ \frac{\partial \mathcal{L}}{\partial z_i} = p_i - \mathbf{1}[i = y] \]
So each row gradient is:
\[ \nabla_{z}\mathcal{L} = p - \text{one\_hot}(y) \]
For mean reduction over N_valid targets:
\[ \nabla_{z}\mathcal{L}_{\text{mean}} = \frac{p - \text{one\_hot}(y)}{N_{\text{valid}}} \]
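The derivation above can be checked numerically. The following is an illustrative pure-Python sketch (not the CK C kernel): it computes the `p - one_hot` row gradient with max-subtracted softmax and verifies one component against a central finite difference of the loss.

```python
import math

def ce_row_grad(logits, target):
    """Gradient of -log softmax(logits)[target] w.r.t. logits: p - one_hot(target)."""
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    p = [e / s for e in exps]
    return [p[i] - (1.0 if i == target else 0.0) for i in range(len(logits))]

def ce_row_loss(logits, target):
    """-log p_target in log-sum-exp form: logsumexp(z) - z_target."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]

# Central finite-difference check of one gradient component
logits, y, i, h = [2.0, -1.0, 0.5], 0, 1, 1e-6
zp = list(logits); zp[i] += h
zm = list(logits); zm[i] -= h
fd = (ce_row_loss(zp, y) - ce_row_loss(zm, y)) / (2 * h)
assert abs(fd - ce_row_grad(logits, y)[i]) < 1e-5
```

A useful sanity property: each row gradient sums to zero, because the softmax probabilities and the one-hot vector both sum to one.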
Why This Matters
The gradient formula itself is standard. Most long-horizon mismatch comes from how we compute softmax/log-loss and which denominator we use for reduction, not from the symbolic derivative.
What Can Go Wrong in Practice
| Subtle CE Choice | Short-Run Effect | Long-Run Effect |
|---|---|---|
| `-log(prob + eps)` clamp for loss | Looks stable and safe | Can cap loss and bias scalar trajectory under extreme logits |
| Different reduction denominator | Tiny scalar mismatch | Persistent optimizer state drift (AdamW moments integrate it) |
| Ignore-index mismatch | Only appears on masked rows | Periodic but cumulative parameter divergence |
| Mixed fp32/double operation ordering differences | Often below tolerance early | Can cross tolerance threshold at later steps |
| Missing grad-accum window averaging (`/K`) before optimizer step | Looks like CE drift because loss starts diverging later | Effective update scale inflates by ~K; can dominate late-horizon instability |
CE /N vs Accum /K
This page focuses on CE internals. But in production drift triage, also verify accumulation-window scaling before AdamW. See v7-grad-accum-windows.html for the full N vs K model and CPU effective-batch examples.
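The N-vs-K distinction is easy to see with a toy calculation. This sketch uses made-up numbers: each micro-batch gradient is already a per-token CE mean (`/N`), and the accumulation window still needs its own `/K` before the optimizer step.

```python
# Toy illustration of accumulation-window scaling (values are made up).
# Each micro-batch gradient below stands for an already-/N-averaged CE gradient.
K = 8
micro_grads = [1.0] * K

summed = sum(micro_grads)        # missing /K: effective gradient scales with K
averaged = sum(micro_grads) / K  # correct window mean before the optimizer step

assert summed == K * averaged
```

Because AdamW normalizes by running second moments, a constant K-times scale does not simply divide out; it shifts the effective step size and can surface as apparent late-horizon CE drift.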
PyTorch CE Semantics We Matched
- `F.cross_entropy(...)` routes to `torch._C._nn.cross_entropy_loss`.
- The index-target path is effectively `log_softmax` + `nll_loss`.
- Default `ignore_index = -100` for index targets.
- Mean reduction divides by the non-ignored target count (or by the summed target weights when class weights are used): \(\frac{\sum \ell_t}{N_{\text{valid}}}\).
- All targets ignored with mean reduction returns `NaN` loss; gradients are zero.
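These semantics can be modeled compactly. The following is a pure-Python reference sketch of the index-target mean behavior listed above (it mirrors, but is not, the PyTorch or CK C implementation): skip ignored rows, divide by `N_valid`, and return `NaN` when every target is ignored.

```python
import math

IGNORE_INDEX = -100  # PyTorch's default ignore_index for index targets

def ce_index_mean(logits_rows, targets, ignore_index=IGNORE_INDEX):
    """Reference-style CE: mean over non-ignored targets; NaN if all ignored."""
    total, n_valid = 0.0, 0
    for z, y in zip(logits_rows, targets):
        if y == ignore_index:
            continue  # ignored rows contribute neither loss nor denominator
        m = max(z)
        lse = m + math.log(sum(math.exp(v - m) for v in z))  # log-sum-exp
        total += lse - z[y]  # -log p_y
        n_valid += 1
    return total / n_valid if n_valid else float("nan")

rows = [[1.0, 0.0], [0.0, 1.0]]
partial = ce_index_mean(rows, [0, IGNORE_INDEX])   # only the first row counts
all_ignored = ce_index_mean(rows, [IGNORE_INDEX, IGNORE_INDEX])  # NaN
```

Note the denominator: it is `N_valid`, not the row count, which is exactly the reduction-boundary detail that drove the drift described earlier.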
Primary source files in local PyTorch checkout:
- `pytorch/torch/nn/functional.py` (`cross_entropy` entrypoint)
- `pytorch/aten/src/ATen/native/native_functions.yaml` (`cross_entropy_loss` dispatch)
- `pytorch/aten/src/ATen/native/LossNLL.cpp` (CPU reduction/backward semantics)
CK v7 CE Variants (Current)
| Variant | Function | Behavior | Purpose |
|---|---|---|---|
| Default CE | `softmax_cross_entropy_loss` | Legacy all-valid fast path; PyTorch-style fallback when ignore/invalid targets appear | Stable baseline + semantic correctness where needed |
| PyTorch-reference CE | `softmax_cross_entropy_loss_ptref` | Always strict index-target mean semantics, strict math path | Reference-style parity experiments |
| Shared strict index-target impl | `softmax_cross_entropy_loss_index_mean_impl` | `ignore_index=-100`, `N_valid` denominator, all-ignored NaN | Semantics parity with PyTorch index-target CE |
Implementation: `src/kernels/loss_kernels.c`
Change Log: What We Improved
- Switched CE loss calculation to the stable log-sum-exp form to avoid probability-clamp artifacts.
- Added explicit PyTorch-style index-target mean reduction semantics with `ignore_index=-100`.
- Added a strict reference CE entrypoint (`softmax_cross_entropy_loss_ptref`) for parity A/B runs.
- Added guarded default routing to preserve all-valid long-horizon stability while handling ignore/invalid targets correctly.
- Expanded CE unit tests to include mixed-ignore and all-ignored behavior checks.
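The first change above is worth a concrete demonstration. This illustrative sketch (not the CK kernel itself) contrasts the fragile `-log(p + eps)` form with the log-sum-exp form under extreme logits: the clamp silently caps the loss near `-log(eps)`, while the LSE form returns the true value.

```python
import math

def loss_clamped(logits, y, eps=1e-12):
    # Fragile form: compute p_y, then -log(p_y + eps).
    # When p_y underflows toward 0, the loss is capped at -log(eps).
    m = max(logits)
    p = math.exp(logits[y] - m) / sum(math.exp(z - m) for z in logits)
    return -math.log(p + eps)

def loss_lse(logits, y):
    # Stable form: -log p_y = logsumexp(z) - z_y, no probability clamp.
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[y]

z = [100.0, 0.0]  # extreme logits: the true -log p_y for y=1 is ~100
capped = loss_clamped(z, 1)  # capped near -log(eps) ~ 27.6
true = loss_lse(z, 1)        # ~ 100.0
```

The capped value still looks finite and plausible, which is exactly why this artifact is hard to catch from loss curves alone; it only shows up as a biased scalar trajectory over long horizons.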
Production Evidence (v7 Harness)
| Run | Status | Key Drift Metrics |
|---|---|---|
| `/tmp/v7_ce_semantics_c_850.json` (older baseline) | FAIL | `first_param_fail_step=117`, `first_loss_fail_step=64` |
| `/tmp/v7_ce_semantics_c_850_postfix.json` | PASS | `first_*_fail_step=None`, `max_loss_abs_diff=9.54e-07` |
| `/tmp/v7_ce_semantics_c_ptref_850_postfix.json` | PASS | `first_*_fail_step=None`, strict CE path also stable |
| `/tmp/v7_ce_semantics_torch_850_postfix.json` | PASS | Control run with PyTorch CE also passes same horizon |
Additional training gates run clean in this validation pass:
- `version/v7/.cache/reports/train_parity_drift_smoke_latest.json` (PASS)
- `version/v7/.cache/reports/train_parity_realistic_long_horizon_latest.json` (PASS)
- `version/v7/.cache/reports/optimizer_parity_latest.json` (PASS)
- `version/v7/.cache/reports/fd_gradients_latest.json` (PASS)
- `version/v7/.cache/reports/replay_determinism_latest.json` (PASS)
What We Still Do Not Match (Yet)
Known Scope Gaps vs Full PyTorch CE API
- Label smoothing path is not implemented in CK CE kernel.
- Per-class weighting parity path is not implemented in CK CE kernel.
- Probability-target CE (dense target distributions) is not supported in CK CE kernel.
- CK fused CE is focused on training mean-reduction path used in v7 harness.
How To Reproduce and Validate
```bash
# CE unit tests (includes ignore_index and all-ignored semantics)
source .venv/bin/activate
python unittest/test_cross_entropy.py

# Long-horizon parity stress (850 steps)
source .venv/bin/activate
CK_NUM_THREADS=8 python version/v7/scripts/train_parity_epochs_v7.py \
  --epochs 8 --seq-len 8 --total-tokens 131072 --grad-accum 8 \
  --optimizer adamw --lr 1e-3 --seed 42 --max-steps 850 \
  --ck-rmsnorm-backend c --ck-swiglu-backend c --ck-loss-backend c \
  --json-out /tmp/v7_ce_semantics_c_850_postfix.json

# A/B against strict CE reference path
source .venv/bin/activate
CK_NUM_THREADS=8 python version/v7/scripts/train_parity_epochs_v7.py \
  --epochs 8 --seq-len 8 --total-tokens 131072 --grad-accum 8 \
  --optimizer adamw --lr 1e-3 --seed 42 --max-steps 850 \
  --ck-rmsnorm-backend c --ck-swiglu-backend c --ck-loss-backend c_ptref \
  --json-out /tmp/v7_ce_semantics_c_ptref_850_postfix.json
```
Bottom Line
CE is not just a textbook formula in production. Whether PyTorch semantics and numeric behavior are matched at the reduction boundary can decide whether long-horizon training stays within tolerance or diverges. In v7, those subtle CE corrections were necessary to make backprop robust.