Make ESMFold2 dtype-honest: drop in-model autocast, support from_pretrained(dtype=)

ESMFold2 was the only ported model relying on in-model `torch.amp.autocast`
(`use_amp = device.type=="cuda"`) to (a) run matmuls in bf16 and (b) keep
norms/softmax in fp32. That made bf16 GPU-only and silently broke both CPU runs
and `from_pretrained(dtype=bfloat16)`. Replace it with the standard Transformers
idiom, in which the model runs in its loaded dtype and numerically sensitive
ops are explicitly pinned to fp32:
- Add an fp32-pinned `LayerNorm(nn.LayerNorm)` (computes in fp32, returns the
input dtype, keeps fp32 weights under `from_pretrained(dtype=bf16)`) and use it
for every `nn.LayerNorm`; the affine-free `RMSNorm`/`F.rms_norm` stay in the
activation dtype (they ran bf16 under autocast too). Pin the remaining
softmaxes with `dtype=torch.float32`.
- Remove the three in-model autocast regions (trunk, confidence trunk, ESMC
`_lm_precision_context`) and convert the Kabsch/SVD `enabled=False` islands to
plain `.float()`.
- The model interleaves fp32 islands (geometry, coords, one-hot/continuous input
features, the fp32-accumulated confidence pair, the fp32 triangular contract)
with bf16 compute, which autocast used to bridge; add explicit `.to(dtype)`
casts at each island→matmul boundary (atom encoder, inputs embedder, rel_pos,
token_bonds, distogram head, MSA embed, diffusion s/z/coords conditioning,
confidence pair). Also align the ESMC hidden states to the LM projection dtype,
which additionally lets a bf16 backbone feed an fp32 trunk (no more
`esmc_precision="fp32"` requirement on CPU).
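
For readers skimming the log, the three patterns above can be sketched roughly
as follows. This is a minimal illustration, not the code in this commit: the
names `Fp32LayerNorm`, `fp32_softmax` and `project_island` are hypothetical,
casting the softmax output back to the input dtype is an assumption, and the
loading-time handling that keeps the norm weights in fp32 is not shown.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Fp32LayerNorm(nn.LayerNorm):
    """Illustrative stand-in for the fp32-pinned LayerNorm (hypothetical name)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast activations and affine parameters, normalize in fp32, then
        # return the result in the caller's dtype.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        out = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return out.to(x.dtype)


def fp32_softmax(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # dtype=torch.float32 upcasts before the reduction; casting back to the
    # input dtype afterwards is an assumption of this sketch.
    return torch.softmax(scores, dim=dim, dtype=torch.float32).to(scores.dtype)


def project_island(features: torch.Tensor, proj: nn.Linear) -> torch.Tensor:
    # An fp32 "island" tensor has to match the projection's weight dtype before
    # the matmul; without autocast, mixed dtypes raise a RuntimeError.
    return proj(features.to(proj.weight.dtype))
```

With a bf16 input, `Fp32LayerNorm(8)` returns a bf16 tensor while its own
`weight` stays fp32, and `project_island` lets an fp32 feature tensor feed a
bf16 `nn.Linear`.
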
Validated vs the autocast baseline (ubiquitin, seed 0): GPU bf16 0.798/0.737,
GPU fp32 0.800/0.740, CPU fp32 0.801/0.740 — all match 0.80/0.74. CPU bf16 runs
correctly too (its matmul accumulates in fp32, ~4e-3 rel-err), just slower.
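
A rough way to sanity-check a relative error of that order is to compare a
bf16 matmul against its fp32 reference on random inputs. The helper below is a
hypothetical sketch (the shapes, the seed and the name `matmul_rel_err` are
assumptions), not the validation script used for the numbers above.

```python
import torch


def matmul_rel_err(m: int = 64, k: int = 256, n: int = 64, seed: int = 0) -> float:
    """Relative Frobenius-norm error of a bf16 matmul vs. the fp32 reference."""
    gen = torch.Generator().manual_seed(seed)
    a = torch.randn(m, k, generator=gen)
    b = torch.randn(k, n, generator=gen)
    ref = a @ b  # fp32 reference
    # Round the inputs to bf16, multiply in bf16, compare in fp32.
    low = (a.to(torch.bfloat16) @ b.to(torch.bfloat16)).float()
    return ((low - ref).norm() / ref.norm()).item()


if __name__ == "__main__":
    print(f"bf16 matmul rel-err: {matmul_rel_err():.2e}")
```

The result depends on shapes and hardware; it is expected to be within an
order of magnitude of bf16's rounding step (`torch.finfo(torch.bfloat16).eps`
is 2^-7).
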
make typing / check_docstrings / check_config_attributes / ruff clean; 7 tests +
the @slow GPU integration test (now bf16) pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>