Modernize ESMC rotary to the standard (cos, sin) + apply_rotary_pos_emb convention
Replace the flash-attn-style, cache-based rotary embedding with the standard
Transformers convention used by esm/llama. This is the last code-shaping step
before authoring modular_esmc.py.
Before: each attention module owned a stateful `RotaryEmbedding` that cached
cos/sin and whose `forward(q, k)` returned the rotated tensors, with a custom
`_apply` override to keep `inv_freq` in fp32 across device casts, plus the
`_rotate_half` / `_apply_rotary_emb_torch` helpers.
After:
- `rotate_half` + `apply_rotary_pos_emb` (identical to esm/llama).
- `ESMCRotaryEmbedding(config)` -> `(cos, sin)`, computed once in
  `ESMCModel.forward` and threaded down as `position_embeddings` through the
  stack/block to attention, mirroring esm. `inv_freq` is fp32 and
  non-persistent (matching the old behaviour: no rotary tensors in the
  checkpoint), and cos/sin are built in fp32, then cast.
- Add `config.rope_theta` (default 10000.0, the previously-hardcoded base).
- `_init_weights` recomputes `inv_freq` for `ESMCRotaryEmbedding` (meta-init safe).
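
For reference, the rotate-half convention named above can be sketched in plain
Python for a single head vector. This is an illustration only, not the code in
this commit: the real helpers operate on torch tensors and take extra
broadcasting arguments, and `rotary_cos_sin` is a hypothetical stand-in for
what `ESMCRotaryEmbedding` returns for one position.

```python
import math


def rotary_cos_sin(position, head_dim, rope_theta=10000.0):
    """Return (cos, sin) lists of length head_dim for a single position."""
    half = head_dim // 2
    # One inverse frequency per rotated pair: theta^(-2i / head_dim).
    inv_freq = [rope_theta ** (-(2 * i) / head_dim) for i in range(half)]
    freqs = [position * f for f in inv_freq]
    # Duplicate the angles so element i and element i + half share one angle.
    emb = freqs + freqs
    return [math.cos(a) for a in emb], [math.sin(a) for a in emb]


def rotate_half(x):
    """Map [x1, x2] (two halves) to [-x2, x1]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate q and k with the same (cos, sin): x * cos + rotate_half(x) * sin."""
    def rot(x):
        return [xi * c + ri * s
                for xi, ri, c, s in zip(x, rotate_half(x), cos, sin)]
    return rot(q), rot(k)
```

Because `(cos, sin)` depend only on position and `head_dim`, they can be
computed once per forward pass and reused by every attention layer, which is
the point of threading them down from the model.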
Verified: a strict state_dict load from the saved baseline succeeds (rotary
buffers are non-persistent, so the keys are unchanged and published
checkpoints still load), and `last_hidden_state` is bit-identical (difference
of 0.0) at all valid positions for plain, padding-mask, and multi-chain inputs.
The fp32 matmul-based freqs equal the old `outer(t, inv_freq)`: same RoPE math,
in the idiomatic shape.
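
The freqs equivalence can be illustrated with a small plain-Python sketch. It
assumes the llama-style layout, where a `(half, 1)` column of `inv_freq` is
multiplied by a `(1, seq)` row of positions and the result is transposed. The
helper names are hypothetical and the real code uses torch tensors.

```python
def outer(t, inv_freq):
    """Old form: freqs[i][j] = t[i] * inv_freq[j], shape (seq, half)."""
    return [[ti * f for f in inv_freq] for ti in t]


def matmul(a, b):
    """Naive matrix product of two list-of-lists matrices."""
    inner, cols = len(b), len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for row in a]


def matmul_freqs(t, inv_freq):
    """New form: (half, 1) @ (1, seq), then transpose to (seq, half)."""
    col = [[f] for f in inv_freq]      # shape (half, 1)
    row = [[float(p) for p in t]]      # shape (1, seq)
    product = matmul(col, row)         # shape (half, seq)
    return [list(r) for r in zip(*product)]


head_dim, rope_theta = 8, 10000.0
inv_freq = [rope_theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]
positions = list(range(6))

# The inner dimension is 1, so every entry is a single product and the two
# forms agree exactly, not merely within a tolerance.
same = outer(positions, inv_freq) == matmul_freqs(positions, inv_freq)
```

With an inner dimension of 1 there is no accumulation order to differ, which
is consistent with the bit-identical outputs reported above.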
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>