transformers
7c5b8c89 - Make ESMFold2 dtype-honest: drop in-model autocast, support from_pretrained(dtype=)

Commit
14 hours ago
ESMFold2 was the only ported model relying on in-model `torch.amp.autocast` (`use_amp = device.type=="cuda"`) to (a) run matmuls in bf16 and (b) keep norms/softmax in fp32. That made it GPU-only for bf16 and silently broke CPU and `from_pretrained(dtype=bfloat16)`.

Replace it with the standard Transformers idiom: the model runs in its loaded dtype, with numerically-sensitive ops explicitly pinned to fp32:

- Add an fp32-pinned `LayerNorm(nn.LayerNorm)` (computes in fp32, returns the input dtype, keeps fp32 weights under `from_pretrained(dtype=bf16)`) and use it for every `nn.LayerNorm`; the affine-free `RMSNorm`/`F.rms_norm` stay in the activation dtype (they ran bf16 under autocast too). Pin the remaining softmaxes with `dtype=torch.float32`.
- Remove the three in-model autocast regions (trunk, confidence trunk, ESMC `_lm_precision_context`) and convert the Kabsch/SVD `enabled=False` islands to plain `.float()`.
- The model interleaves fp32 islands (geometry, coords, one-hot/continuous input features, the fp32-accumulated confidence pair, the fp32 triangular contract) with bf16 compute, which autocast used to bridge; add explicit `.to(dtype)` casts at each island→matmul boundary (atom encoder, inputs embedder, rel_pos, token_bonds, distogram head, MSA embed, diffusion s/z/coords conditioning, confidence pair). Also align the ESMC hidden states to the LM projection dtype, which additionally lets a bf16 backbone feed an fp32 trunk (no more `esmc_precision="fp32"` requirement on CPU).

Validated vs the autocast baseline (ubiquitin, seed 0): GPU bf16 0.798/0.737, GPU fp32 0.800/0.740, CPU fp32 0.801/0.740; all match 0.80/0.74. CPU bf16 runs correctly too (its matmul accumulates in fp32, ~4e-3 rel-err), just slower.

make typing / check_docstrings / check_config_attributes / ruff clean; 7 tests + the @slow GPU integration test (now bf16) pass.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
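
For readers unfamiliar with the pattern described in the message, here is a minimal sketch of the three ideas: an fp32-pinned LayerNorm, a softmax pinned with `dtype=torch.float32`, and an explicit cast at an fp32-island→matmul boundary. It is not the code from this commit. The names `Fp32LayerNorm`, `attention_probs`, and `project_island` are hypothetical, and the shapes are arbitrary; only the PyTorch calls themselves are real.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Fp32LayerNorm(nn.LayerNorm):
    """Hypothetical name: computes in fp32, returns the input dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast parameters too, so this still works if the module was cast to bf16.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        out = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return out.to(x.dtype)


def attention_probs(scores: torch.Tensor) -> torch.Tensor:
    # The dtype argument casts the input to fp32 before the softmax reduction.
    return torch.softmax(scores, dim=-1, dtype=torch.float32).to(scores.dtype)


def project_island(features_fp32: torch.Tensor, proj: nn.Linear) -> torch.Tensor:
    # Explicit cast at the boundary; autocast previously did this implicitly.
    return proj(features_fp32.to(proj.weight.dtype))


torch.manual_seed(0)
norm = Fp32LayerNorm(8)                       # parameters stay fp32 here
proj = nn.Linear(8, 4).to(torch.bfloat16)     # stands in for a bf16 matmul layer
x = torch.randn(2, 3, 8, dtype=torch.bfloat16)

y = norm(x)                                   # bf16 in, bf16 out
p = attention_probs(x)                        # fp32 reduction, bf16 result
z = project_island(torch.randn(2, 3, 8), proj)  # fp32 features into a bf16 layer
```

In this sketch the norm weights are fp32 only because the module is never cast; how the actual model keeps them in fp32 under `from_pretrained(dtype=bf16)` is not shown here.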