transformers
5bbaa6e1 - Modernize ESMC rotary to the standard (cos, sin) + apply_rotary_pos_emb convention

Commit
2 days ago
Replace the flash-attn-style cache-based rotary with the standard Transformers convention used by esm/llama, as the last code-shaping step before authoring modular_esmc.py.

Before: a stateful `RotaryEmbedding` (per-attention-module) caching cos/sin, `forward(q, k)` returning rotated tensors, a custom `_apply` override to keep `inv_freq` fp32 across device casts, plus `_rotate_half` / `_apply_rotary_emb_torch`.

After:
- `rotate_half` + `apply_rotary_pos_emb` (identical to esm/llama).
- `ESMCRotaryEmbedding(config)` -> `(cos, sin)`, computed once in `ESMCModel.forward` and threaded down (position_embeddings) through the stack/block to attention, mirroring esm. `inv_freq` is fp32 and non-persistent (matches the old behaviour: no rotary tensors in the checkpoint), and cos/sin are built in fp32 then cast.
- Add `config.rope_theta` (default 10000.0, the previously-hardcoded base).
- `_init_weights` recomputes `inv_freq` for `ESMCRotaryEmbedding` (meta-init safe).

Verified: strict state_dict load from the saved baseline succeeds (rotary buffers are non-persistent, so keys are unchanged -> published checkpoints still load), and last_hidden_state is bit-identical (0.0) at all valid positions for plain, padding-mask, and multi-chain. The fp32 matmul-based freqs equal the old `outer(t, inv_freq)`; same RoPE math, idiomatic shape.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
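For readers unfamiliar with the `(cos, sin)` convention described above, here is a minimal PyTorch sketch of the pattern. It is not the code from this commit: `SketchRotaryEmbedding` is a hypothetical stand-in for `ESMCRotaryEmbedding`, it takes `head_dim` and `rope_theta` directly instead of a config object, and the tensor layout `(batch, heads, seq, head_dim)` is an assumption borrowed from the llama-style helpers.

```python
import torch
from torch import nn


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the two halves of the last dim (x1, x2) to (-x2, x1)."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate q and k with precomputed cos/sin of shape (batch, seq, head_dim)."""
    cos = cos.unsqueeze(unsqueeze_dim)  # broadcast over the heads dimension
    sin = sin.unsqueeze(unsqueeze_dim)
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot


class SketchRotaryEmbedding(nn.Module):
    """Hypothetical stand-in: stateless w.r.t. sequence length, returns (cos, sin)."""

    def __init__(self, head_dim: int, rope_theta: float = 10000.0):
        super().__init__()
        exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
        inv_freq = 1.0 / (rope_theta ** exponents)
        # Non-persistent: the buffer never appears in state_dict, so checkpoint keys are unaffected.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        # x is used only for its dtype; position_ids has shape (batch, seq).
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        pos = position_ids[:, None, :].float()
        # (batch, head_dim/2, 1) @ (batch, 1, seq) -> transpose to (batch, seq, head_dim/2)
        freqs = (inv_freq @ pos).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Build in fp32, then cast to the activation dtype.
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


# Usage: compute (cos, sin) once, then pass them down to every attention layer.
torch.manual_seed(0)
batch, heads, seq, head_dim = 2, 4, 6, 8
rope = SketchRotaryEmbedding(head_dim)
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1)
cos, sin = rope(q, position_ids)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
```

Computing `(cos, sin)` once at the model level and threading it through the layers avoids a per-attention-module cache, which is presumably why the custom `_apply` override could be dropped.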