transformers
5fe5ac45 - Route ESMFold2 plain self-attention through the v5 attention interface

Commit
1 day ago
Route ESMFold2 plain self-attention through the v5 attention interface SWA3DRoPEAttention's plain softmax(QKᵀ)V core now dispatches through ALL_ATTENTION_FUNCTIONS / a local eager_attention_forward, keyed on config._attn_implementation, with the sliding window expressed as an additive attention mask. The custom flash-attention path (native bidirectional window_size + varlen for packed inputs) is kept as an opt-in backend, now gated on _attn_implementation == "flash_attention_2" instead of auto-selecting whenever flash-attn is importable — so the default is sdpa (matching the fork's SDPA fallback bit-for-bit) and flash is opt-in, per v5 conventions. ESMFold2Model declares _supports_sdpa / _supports_flash_attn / _supports_attention_backend and, after construction, attaches its shared config to every SWA3DRoPEAttention (the atom encoders/decoders build them from explicit dims), so dispatch stays live under set_attn_implementation. is_causal=False guards against the interface defaulting to causal when no mask is passed. Pair-bias (AttentionPairBias) and triangular math are left untouched. Validated on CPU vs the pre-refactor forward (random weights): sdpa max|Δ|=0.0 (bit-exact), eager max|Δ|=1.3e-3 (bf16 softmax precision). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Author
Parents
Loading