Tool: evaluate layer-wise numerical-error propagation #525

Open
jlamypoirier wants to merge 25 commits into main from jlp_evaluate_precision


@jlamypoirier jlamypoirier commented May 26, 2026

Summary

  • New tools/evaluate_precision.py: inherits PretrainedGPTModelConfig (so model: and pretrained: are real typed Config fields) and adds variants:, output_dir:, and a few optional knobs. Runs an fp32 reference plus one trainer invocation per variant in-process; captures per-layer forward activations and input gradients via the standard tensor-logs pipeline; emits per-tensor RMS / max diffs as a console table and as precision_report.json.
  • Variants aren't dtype-only: each is a flat dict of dotted-path overrides (same syntax as Fast-LLM CLI key=value args) so a variant can sweep any config knob — attention implementation, optimizer dtype, fused vs unfused, etc.
  • Per-variant trainer configs are built with TrainerConfig.get_subclass(...).from_dict(base, fp32_dtypes, variant_updates, tool_overrides). Tuple-keyed updates compose in precedence order: forced fp32 → variant overrides (which can re-override fp32) → tool-required debug-logging overrides (which always win).
  • Training, optimizer, and data sections of the trainer config are hardcoded inside the tool (single iteration, no checkpoint save, random tokens, LR 0, fp32 optimization dtype). Only knobs that affect the propagation measurement are user-facing: model, pretrained, variants, output_dir, num_samples, micro_batch_size, sequence_length.
  • Moves compare_tensor_logs.py from tests/utils/ into fast_llm/engine/config_utils/ so it's importable from tools/, and factors a _compute_diff helper out of CompareConfig.compare_tensors so the tool can extract numbers for every tensor — not only those that breach a tolerance. Three test callers updated; behaviour unchanged.
  • Fills in the HF metadata allowlist (fast_llm/engine/checkpoint/huggingface.py) with the generic PretrainedConfig keys newer transformers versions serialize: generation defaults, encoder-decoder flags, family markers, torchscript, is_decoder, etc. Without this, loading any modern HF Llama checkpoint trips the coverage walker. None are architecture knobs Fast-LLM consumes.
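
The precedence order described above (forced fp32, then variant overrides, then tool-required overrides) can be illustrated with a minimal sketch. This is not the Fast-LLM `from_dict` implementation: `apply_updates` and the `model.debug.log_layer_outputs` key are hypothetical names used only for illustration, and plain nested dicts stand in for typed configs.

```python
import copy


def apply_updates(base: dict, *updates: dict) -> dict:
    """Apply override dicts in order; later dicts win on conflicting paths."""
    result = copy.deepcopy(base)
    for update in updates:
        for path, value in update.items():
            # Accept both "a.b.c" strings and ("a", "b", "c") tuples.
            keys = path if isinstance(path, tuple) else tuple(path.split("."))
            node = result
            for key in keys[:-1]:
                node = node.setdefault(key, {})
            node[keys[-1]] = value
    return result


base = {"model": {"distributed": {"compute_dtype": "bfloat16"}}}
fp32_dtypes = {("model", "distributed", "compute_dtype"): "float32"}
variant_updates = {"model.distributed.compute_dtype": "bfloat16"}
# Hypothetical debug key, standing in for the tool-required logging overrides.
tool_overrides = {("model", "debug", "log_layer_outputs"): True}

# A variant re-overrides the forced fp32 dtype; the reference run does not.
config = apply_updates(base, fp32_dtypes, variant_updates, tool_overrides)
reference = apply_updates(base, fp32_dtypes, {}, tool_overrides)
```

Because the updates are applied left to right, a variant can undo the forced fp32 dtype, while the tool overrides in the last position cannot be displaced by a variant.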

Usage

python -m tools.evaluate_precision -c tool.yaml

where tool.yaml contains:

pretrained:
  path: /path/to/local/hf/snapshot
  format: llama
output_dir: /tmp/precision_eval
variants:
  bf16:
    model.distributed.compute_dtype: bfloat16
  bf16_sdpa:
    model.distributed.compute_dtype: bfloat16
    model.base_model.decoder.block.mixer.implementation: sdpa

Fast-LLM's HF loader reads weights from a local directory, so HF Hub IDs need to be snapshot_download'd first. model: and pretrained: can also be combined — pretrained provides architecture+weights, model: overrides individual fields.

Test plan

  • Cluster smoke test on a real HF checkpoint (SmolLM2-135M, snapshot via huggingface_hub). Reference fp32 + bf16 variant ran end-to-end; per-layer RMS/max table populated for all 30 decoder layers + embeddings + head, fw + bw; JSON artifact round-trips through json.load. Output shows propagated error growing with depth, with sharp jumps at layers where activation magnitude regime changes (e.g. ref_scale 6 → 777 around block 11, bf16 RMS rel 10% → 0.7% → back up to 13% at block 28).
  • Existing layer-comparison tests still pass with the moved compare_tensor_logs.py and the refactored compare_tensors.

🤖 Generated with Claude Code

jlamypoirier and others added 4 commits May 27, 2026 15:07
A new `tools/evaluate_precision.py` (`RunnableConfig`) drives an fp32
reference run plus one one-iteration trainer run per named variant from
a Fast-LLM training YAML, then extracts per-layer forward activations
and input gradients from the saved tensor logs and reports per-tensor
RMS and max diffs (absolute and scaled). Variants are flat dicts of
dotted-path overrides, the same syntax as Fast-LLM CLI key=value args,
so they can sweep arbitrary configuration knobs (dtype, attention
implementation, optimizer dtype, etc.) — not just compute_dtype.

Also moves `compare_tensor_logs.py` into the `fast_llm` package so it
is importable from `tools/` (the test tree isn't on sys.path for
script entry points), and factors a `_compute_diff` helper out of
`CompareConfig.compare_tensors` so the tool can extract numbers for
every tensor rather than only those that breach a tolerance. Existing
test callers are unaffected.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
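
For readers unfamiliar with the reported metrics, here is a minimal sketch of per-tensor RMS and max diffs, absolute and scaled. The exact definitions used by `_compute_diff` are not shown in this PR, so the formulas below (scaling by the RMS of the reference tensor) are an assumption, and `diff_stats` is a hypothetical name. Flat Python lists stand in for tensors.

```python
import math


def diff_stats(reference: list[float], variant: list[float]) -> dict[str, float]:
    """Compare a variant tensor against the fp32 reference (both flattened)."""
    if not reference or len(reference) != len(variant):
        raise ValueError("tensors must be non-empty and have the same size")
    diffs = [v - r for r, v in zip(reference, variant)]
    # Root-mean-square of the element-wise difference.
    rms_abs = math.sqrt(sum(d * d for d in diffs) / len(diffs))
    # Scale of the reference tensor, used to normalize the difference.
    ref_scale = math.sqrt(sum(r * r for r in reference) / len(reference))
    max_abs = max(abs(d) for d in diffs)
    return {
        "rms_abs": rms_abs,
        "rms_rel": rms_abs / ref_scale if ref_scale else float("nan"),
        "max_abs": max_abs,
        "ref_scale": ref_scale,
    }


stats = diff_stats([3.0, 4.0], [3.0, 4.5])
```

Under this definition a tensor with a large `ref_scale` can show a small relative error even when its absolute error is large, which is why both are reported.
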
The tool now takes a single YAML containing `pretrained:` (the
checkpoint that defines the model architecture + weights), `variants:`,
`output_dir:` and a few optional knobs (`model_type`, `num_samples`,
`micro_batch_size`, `sequence_length`). The training/optimizer/data
sections of the underlying training config are hardcoded — they have
no bearing on the propagation measurement (1 iteration, no checkpoint
save, random tokens, dummy learning rate, optimization dtype forced to
float32 alongside compute dtype). A variant can still override any of
the hardcoded fields via the dotted-path mechanism if needed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The tool's input mirrors the trainer config's top-level shape: both
`model:` (FastLLMModelConfig dict) and `pretrained:` are user-facing,
and either or both may be set. Pretrained-from-HF is one config choice
among many — a user can also specify the architecture inline, or load
from HF and override individual fields.

The forced fp32 dtypes and tool-required debug levels are now applied
as overrides on top of whatever the user supplies, instead of being
hardcoded into the model section.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The tool now inherits from `PretrainedGPTModelConfig` so `model` and
`pretrained` are typed `FastLLMModelConfig` / `CheckpointLoadConfig`
fields rather than loose dicts — validated, autocompleted, and
introspectable like any other Fast-LLM config block.

Per-variant trainer configs are built with `TrainerConfig.get_subclass(...)
.from_dict(base, *updates)` instead of mutating a dict and round-tripping
through YAML. Updates use tuple-keyed dotted paths so forced-fp32,
variant overrides, and tool-required debug-logging overrides compose
cleanly in the right precedence.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier jlamypoirier force-pushed the jlp_evaluate_precision branch from 6431307 to 4c444d8 Compare May 27, 2026 19:16
jlamypoirier and others added 21 commits May 27, 2026 15:46
`transformers.PretrainedConfig.to_dict()` serializes a growing set of
generic defaults (generation knobs, family markers, encoder-decoder
flags). The Fast-LLM allowlist covered only a subset, so loading any
modern HF Llama checkpoint via `pretrained.format: llama` tripped the
coverage walker on keys like `torchscript`, `is_decoder`,
`is_llama_config`, `rope_interleaved`, and the full set of generation
defaults.

Fill in the missing entries, grouped by category. None of them are
architecture knobs that Fast-LLM consumes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop step / shape / max_rel columns, shorten the tensor name to the
description after the colon, reorder to Tensor / Kind / Relative /
Absolute / Max / Scale, format Relative as percent and the rest with
`.3g`. The JSON report keeps every field.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop the separate Kind column and append `(fw)` / `(bw)` to the
shortened tensor name. Switch numeric formatting to fixed precision:
Relative shows `.2f` percent, Absolute / Max / Scale show `.2e`
scientific. Every column now lines up on a consistent digit count.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Scientific notation was overkill for values that mostly land between
0.01 and a few hundred. `.3f` is more readable while keeping the
per-column digit count consistent.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Fast-LLM's `Run.__init__` picks the next free `runs/<n>` subdirectory
based on what already exists, but `_artifact_path` reads `runs/0`
unconditionally. Without this wipe, re-running the tool against the
same `output_dir` reads stale artifacts from the first invocation and
silently reports old numbers — even though the trainer correctly ran
with the new config.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a `data_path` field to the tool. When set, the tool lazily
generates a tokenized memmap dataset with random advantages and
old_logprobs at the given path (via the test helper
`tests/utils/dataset._get_test_dataset`) and uses it as the training
input. Required for policy-gradient losses like GSPO/GRPO that consume
those fields. Without it, the tool falls back to the random token
generator as before.

Console table now formats numeric columns with `.4g` so 1e-7-scale
GSPO gradients aren't rounded to zero while normal CE-magnitude
values still read as fixed-point numbers.

Rename `download_santacoder_tokenizer` to `download_test_tokenizer` —
it actually downloads the GPT-2 tokenizer.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After the per-tensor tables, emit a short summary block per variant
showing first/last/max/median for forward and backward separately.
Aggregates over the intermediate layers per metric column (max and
median are computed per-column, so each row is a per-metric envelope
of the intermediate band rather than the metrics of any single layer).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Single compact table with one row per variant and columns for fw/bw
first/last/max/median Relative %. Max/median are over intermediate
layers (excluding first/last) when there is at least one intermediate
row.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
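
The aggregation can be sketched as follows for one pass (fw or bw) of one variant. `summarize` is a hypothetical name, and the fallback to all rows when there is no intermediate layer is an assumption, since the commit only specifies the case with at least one intermediate row.

```python
import statistics


def summarize(relative_errors: list[float]) -> dict[str, float]:
    """Collapse per-layer relative errors, ordered along the pass, to four numbers."""
    if not relative_errors:
        raise ValueError("need at least one layer")
    # Intermediate layers exclude the first and last rows when any exist.
    mid = relative_errors[1:-1] or relative_errors
    return {
        "first": relative_errors[0],
        "mid med": statistics.median(mid),
        "mid max": max(mid),
        "last": relative_errors[-1],
    }


summary = summarize([0.0, 1.0, 3.0, 2.0, 5.0])
```
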
Rename `max`/`median` columns to `mid max`/`mid med` and add a header
note (`mid = excluding first/last`) so it's clear the aggregation
excludes the boundary layers. Also fix a column-collision bug where
labels at exactly the cell width touched without separator.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each variant now occupies two rows in the summary (fw on the first,
bw on the second), with the metric columns shared. Reads more
naturally and keeps the table half as wide. Percent precision goes
from .2f to .3f so single-digit-percent differences between variants
are visible.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Top header line groups columns under `fw` / `bw`; the second line
lists the per-pass aggregations. Aggregations are ordered
chronologically along the pass — first → mid med → mid max → last —
so reading left to right traces the propagation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds an `fp32_lm_head` field on `LanguageModelHeadConfig`. When `True`,
the LM head linear's input and weight are upcast to FP32 before the
matmul, matching vLLM's `bf16_last_layer_fp32` quantization. This lets
the trainer compute log-probabilities at the same numerical precision
as the actor's sampling, so the importance-sampling ratio starts near
1.0 instead of being artificially inflated by a trainer/actor precision
mismatch.

The detached FP32 weight has `requires_grad=False`, which makes
`output_parallel_linear_backward` skip the weight-grad path. The FSDP
gradient contract is restored by computing `grad_weight = grad.t() @ saved_input`
explicitly and accumulating into the original BF16 param's `grad_buffer`
via `accumulate_gradient`.

Off by default — disabled path is byte-identical to before.

Cherry-picked from #526 to unblock the precision-evaluation tool's
GSPO smoke test, which compares fp32_lm_head=true vs false.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
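
The explicit weight-gradient expression follows from the standard linear-layer derivative. The symbols below are introduced here for illustration and assume the head computes $Y = X W^{\top}$ over $n$ flattened tokens, hidden size $d$ and vocabulary size $v$:

```latex
\begin{aligned}
Y &= X W^{\top}, \qquad X \in \mathbb{R}^{n \times d},\quad W \in \mathbb{R}^{v \times d},\quad G = \frac{\partial L}{\partial Y} \in \mathbb{R}^{n \times v} \\
\frac{\partial L}{\partial W} &= G^{\top} X \in \mathbb{R}^{v \times d}, \qquad
\frac{\partial L}{\partial X} = G\, W \in \mathbb{R}^{n \times d}
\end{aligned}
```

With $G$ as `grad` and $X$ as `saved_input`, the first expression is `grad.t() @ saved_input`, which has the same shape as the weight and can therefore be accumulated into the original parameter's gradient buffer.
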
Instead of generic `first` / `last` headers in the summary, use the
actual layer name pulled from the matching tensor's `Global <layer>
<kind>:` prefix. For the SmolLM2 smoke run that surfaces as
`embeddings` / `head` on fw and `head` / `decoder.0` on bw — directly
showing which layer the boundary values come from rather than making
the reader guess.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…den_states

Previously the only way to get a non-layer-output tensor (e.g. the LM head's
logits) into `tensor_logs` was to crank `model_debug_level`, which logs every
single `_debug`-emitted tensor (~700 per step for a 30-layer model).

Add a `MultiStageConfig.debug_hidden_states_log: list[str]` field — regex
patterns that get appended to each model input's `output_hidden_states` set.
Matching tensors are still populated into `kwargs[hidden_states]` (existing
contract for the HF inference wrapper); now they're also written to
`tensor_logs` so the precision tool can compare them across variants.

`_debug` already had the `output_hidden_state`-matched branch but only used it
to populate `kwargs[hidden_states]`. Extending it to also call
`log_distributed_tensor` at a fixed verbosity (13, matching the test
convention so samples are recorded) is a small gating change.

Plumbed through `GPTModel.get_preprocessing_config` →
`LanguageModelBatchPreprocessingConfig.output_hidden_states` →
`LanguageModelBatch.get_model_inputs`, which compiles the patterns and unions
them into each `LanguageModelInput.output_hidden_states`.

The precision tool now sets `[r"head\.logits"]` and surfaces logits as a
dedicated `logits` column on the fw side of the summary table.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The head's `logits` tensor has `requires_grad=False` (output of a
custom-autograd Function), so the existing `_debug(logits, ...)` could
only capture the forward value. Add a second `_debug(grad, "logits.grad",
...)` call right after the loss returns the explicit `dL/d_logits` so
the gradient is captured at the same fidelity. With the precision tool's
`output_hidden_states` pattern `r"head\.logits"`, both `head.logits`
and `head.logits.grad` end up in tensor_logs.

Tool summary surfaces both via dedicated `logits` columns — placed at
end-of-fw and start-of-bw chronologically. For GSPO the bw-logits column
reveals that the dL/dlogits computation itself is extremely precise
(~0.001% relative error), and the apparent backward noise actually
enters through the head matmul further downstream.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
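
A minimal sketch of how a single pattern can select both tensors, assuming the patterns are applied with search semantics rather than a full match (an assumption inferred from `head.logits.grad` also being captured). `select_logged` and the decoder tensor name are hypothetical.

```python
import re


def select_logged(tensor_names: list[str], patterns: list[str]) -> list[str]:
    """Return the tensor names matched by at least one regex pattern."""
    compiled = [re.compile(pattern) for pattern in patterns]
    return [
        name for name in tensor_names
        if any(regex.search(name) for regex in compiled)
    ]


names = ["decoder.0.mixer.output", "head.logits", "head.logits.grad"]
selected = select_logged(names, [r"head\.logits"])
```

With a full-match rule, the `.grad` tensor would need its own pattern such as `r"head\.logits(\.grad)?"`.
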
…alues

`.3f%` was rounding the bw-logits values down to 0.001%-0.000%, hiding
real signal. Switch to `.4g%` so values across 5 orders of magnitude
(0.0001% to ~20%) all render with meaningful precision; large values
keep 4 significant figures, tiny ones spell out their leading non-zero
digits or fall back to scientific.

Bw column order is now first / logits / mid med / mid max / last so
`logits` sits right after `head` (the first bw row) — semantically
the gradient at logits is what the head's backward consumes before
producing the gradient at its input.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Keep the prior `.3f%` default in the summary so most columns still
show `0.000%` / `12.672%` style values, but compute a per-column
decimal count based on the smallest non-zero value in that column —
bumping up just enough that every cell carries at least two
significant figures. Decimal count is uniform within a column.

For the GSPO run, only the bw-logits column hits the threshold and
gets bumped from 3 to 5 decimals, surfacing values like `0.00095%`
that previously rounded to `0.001%` or worse.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
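
The per-column rule can be sketched as below. The function names are hypothetical and the tool's actual code may differ; the rule itself (default of 3 decimals, bumped so the smallest non-zero value keeps two significant figures) follows the description above.

```python
import math


def column_decimals(values: list[float], default: int = 3, sig_figs: int = 2) -> int:
    """Decimal count shared by a whole column of percentages."""
    nonzero = [abs(v) for v in values if v != 0]
    if not nonzero:
        return default
    # Position of the first significant digit after the decimal point.
    leading = -math.floor(math.log10(min(nonzero)))
    return max(default, leading + sig_figs - 1)


def format_column(values: list[float]) -> list[str]:
    """Format every cell in the column with the same number of decimals."""
    decimals = column_decimals(values)
    return [f"{v:.{decimals}f}%" for v in values]


bw_logits = format_column([0.00095, 0.00176, 0.00573, 0.00730])
fw_mid_max = format_column([0.0, 12.672])
```
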
Cell width drops from `max_label + 1` to `max_label`, inter-cell sep
from two spaces to one, group sep from four spaces to three. About 18
chars narrower on the GSPO smoke run with no loss of alignment or
readability.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lets `pretrained.path: org/model-id` resolve via huggingface_hub.snapshot_download
when not a local directory, matching transformers' from_pretrained behavior.
Local paths pass through unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two ready-to-run configs for tools/evaluate_precision: smol.yaml sweeps
precision-stability features (full_precision_gradients, full_precision_residual,
fp32_lm_head) on SmolLM2-135M; smol_gspo.yaml repeats the sweep with the GSPO
policy-gradient loss enabled.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A single forward+backward pass with micro_batch_size=1 has no gradient
accumulation, so toggling full_precision_gradients produces bit-identical
results to the bf16 baseline.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier
Collaborator Author

Sample precision-evaluation runs

Output of python -m tools.evaluate_precision -c examples/evaluate_precision/<config> on SmolLM2-135M, sequence length 128, single forward+backward step. Numbers are RMS relative diff vs the forced-fp32 reference, in %.

smol.yaml — pretrained HF weights

Variant            embeddings  fw mid med  fw mid max  fw logits  fw head     bw head  bw logits  bw mid med  bw mid max  bw decoder.0
bf16               0.000%      1.192%      19.904%     43.910%    4.597%      20.375%  14.710%    17.426%     22.495%     14.725%
bf16_fp32_lm_head  0.000%      1.192%      19.904%     43.901%    4.673%      19.559%  15.259%    16.797%     22.118%     14.375%
bf16_fp32_residual 0.000%      0.260%      4.569%      5.348%     0.568%      5.768%   4.132%     4.298%      8.025%      4.401%
bf16_max_precision 0.000%      0.260%      4.569%      5.347%     0.353%      6.653%   4.909%     4.959%      7.643%      4.381%

smol.yaml — random init (pretrained.model_weights=False)

Variant            embeddings  fw mid med  fw mid max  fw logits  fw head     bw head  bw logits  bw mid med  bw mid max  bw decoder.0
bf16               0.168%      1.739%      2.334%      2.425%     0.160%      0.284%   0.621%     2.120%      2.861%      3.188%
bf16_fp32_lm_head  0.168%      1.739%      2.334%      2.421%     0.160%      0.284%   0.617%     2.139%      2.937%      3.196%
bf16_fp32_residual 0.168%      1.372%      1.603%      1.689%     0.040%      0.295%   0.447%     1.435%      2.179%      2.321%
bf16_max_precision 0.168%      1.372%      1.603%      1.686%     0.041%      0.295%   0.434%     1.437%      2.180%      2.232%

smol_gspo.yaml — pretrained HF weights

Variant            embeddings  fw mid med  fw mid max  fw logits  fw head     bw head  bw logits  bw mid med  bw mid max  bw decoder.0
bf16               0.000%      0.242%      12.672%     11.312%    10.852%     0.107%   0.00095%   0.125%      0.340%      5.407%
bf16_fp32_lm_head  0.000%      0.242%      12.672%     11.296%    10.800%     0.105%   0.00176%   0.123%      0.336%      5.403%
bf16_fp32_residual 0.000%      0.227%      6.190%      6.593%     2.798%      0.042%   0.00573%   0.031%      0.068%      0.994%
bf16_max_precision 0.000%      0.227%      6.190%      6.587%     4.650%      0.049%   0.00730%   0.053%      0.128%      2.123%

smol_gspo.yaml — random init (pretrained.model_weights=False)

Variant            embeddings  fw mid med  fw mid max  fw logits  fw head     bw head  bw logits   bw mid med  bw mid max  bw decoder.0
bf16               0.173%      1.783%      2.222%      2.152%     2.296%      0.0133%  0.000095%   0.023%      0.158%      4.387%
bf16_fp32_lm_head  0.173%      1.783%      2.222%      2.143%     2.301%      0.0134%  0.000095%   0.023%      0.163%      4.626%
bf16_fp32_residual 0.173%      1.356%      1.625%      1.566%     0.860%      0.0048%  0.000044%   0.011%      0.081%      2.210%
bf16_max_precision 0.173%      1.356%      1.625%      1.560%     0.939%      0.0057%  0.000045%   0.012%      0.091%      2.611%

Observations

  • Pretrained weights produce much larger forward-pass errors than random init, most visibly at fw mid max (the single worst intermediate layer) and at fw logits (head.logits). With plain bf16, the CE loss config peaks around 20-44% and GSPO around 11-13%, while random init keeps every forward-pass column under ~2.5%.
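  • full_precision_residual is the dominant stability lever. On pretrained CE it cuts the worst forward-pass numbers by roughly 4x to 8x (fw mid max 19.9% → 4.6%, fw logits 43.9% → 5.3%), on pretrained GSPO by roughly half (12.7% → 6.2%, 11.3% → 6.6%), and on random init by about 30% (e.g. fw mid max 2.33% → 1.60% for CE).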
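  • fp32_lm_head (#526, "Add fp32_lm_head flag for vLLM precision parity") has limited effect on its own. Comparing bf16_max_precision with bf16_fp32_residual, it visibly helps only on the pretrained CE run, and there only on the absolute head output (fw head 0.568% → 0.353%), not on logits; on the GSPO runs the fw head error is higher in bf16_max_precision (pretrained 2.798% → 4.650%).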
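  • GSPO backward errors are mostly far smaller than CE backward errors (e.g. bw logits between about 4e-5% and 7e-3% for GSPO vs 0.4-15% for CE), consistent with the GSPO loss producing much smaller logit gradients. The exception is bw decoder.0 on random init, where GSPO (2.2-4.6%) is comparable to or above CE (2.2-3.2%).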
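
As a cross-check on the residual-precision observation, the reduction factors can be recomputed directly from the tables above (bf16 row vs bf16_fp32_residual row). Only the two worst forward-pass columns are included; the dictionary layout and the name `reduction_factor` are illustrative.

```python
# (bf16, bf16_fp32_residual) relative errors in %, copied from the tables above.
worst_forward = {
    "CE pretrained": {"fw mid max": (19.904, 4.569), "fw logits": (43.910, 5.348)},
    "CE random init": {"fw mid max": (2.334, 1.603), "fw logits": (2.425, 1.689)},
    "GSPO pretrained": {"fw mid max": (12.672, 6.190), "fw logits": (11.312, 6.593)},
    "GSPO random init": {"fw mid max": (2.222, 1.625), "fw logits": (2.152, 1.566)},
}


def reduction_factor(baseline: float, improved: float) -> float:
    """How many times smaller the error becomes (2.0 means cut in half)."""
    return baseline / improved


factors = {
    run: {column: reduction_factor(*pair) for column, pair in columns.items()}
    for run, columns in worst_forward.items()
}

for run, columns in factors.items():
    print(run, {column: f"{factor:.2f}x" for column, factor in columns.items()})
```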
