Route ESMFold2 plain self-attention through the v5 attention interface

SWA3DRoPEAttention's plain softmax(QKᵀ)V core now dispatches through
ALL_ATTENTION_FUNCTIONS or a local eager_attention_forward, keyed on
config._attn_implementation, with the sliding window expressed as an
additive attention mask. The custom flash-attention path (native
bidirectional window_size plus varlen for packed inputs) is kept as an
opt-in backend, gated on _attn_implementation == "flash_attention_2"
instead of being auto-selected whenever flash-attn is importable. The
default is therefore sdpa, which matches the fork's SDPA fallback
bit-for-bit, and flash is opt-in, per v5 conventions.
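
For illustration, a minimal pure-Python sketch of a bidirectional
sliding window written as an additive mask. The helper names and the
half-width "window" parameter are assumptions made for this example;
the real code works on tensors and may use a different window
convention.

```python
import math


def sliding_window_additive_mask(seq_len, window):
    """Additive mask: 0.0 where |i - j| <= window, -inf elsewhere.

    `window` is a hypothetical half-width (keys on either side of the
    query), chosen here only for illustration.
    """
    return [
        [0.0 if abs(i - j) <= window else -math.inf for j in range(seq_len)]
        for i in range(seq_len)
    ]


def masked_softmax(scores, mask):
    """Row-wise softmax of (scores + mask); masked keys get weight 0."""
    rows = []
    for score_row, mask_row in zip(scores, mask):
        logits = [s + m for s, m in zip(score_row, mask_row)]
        peak = max(logits)  # finite: the diagonal is never masked
        exps = [math.exp(x - peak) for x in logits]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows


# Uniform scores make the effect of the mask easy to read off.
mask = sliding_window_additive_mask(seq_len=5, window=1)
weights = masked_softmax([[0.0] * 5 for _ in range(5)], mask)
```

Adding -inf before the softmax drives the out-of-window weights to
exactly zero, which is why a backend that accepts an additive mask needs
no native notion of a window.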

ESMFold2Model declares _supports_sdpa / _supports_flash_attn /
_supports_attention_backend and, after construction, attaches its shared
config to every SWA3DRoPEAttention. The atom encoders/decoders build
these modules from explicit dims rather than from a config, so without
this step they would not follow set_attn_implementation; with it,
dispatch stays live. is_causal=False is passed explicitly to guard
against the interface defaulting to causal when no mask is passed.
Pair-bias (AttentionPairBias) and triangular math are left untouched.
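
A toy sketch of why attaching the shared config keeps dispatch live. The
ToyAttention / ToyModel classes, the backend() method and the "eager"
fallback are hypothetical stand-ins, and the sketch assumes that
switching backends amounts to updating _attn_implementation on the
shared config object.

```python
from types import SimpleNamespace


class ToyAttention:
    """Stand-in for an attention block built from explicit dims."""

    def __init__(self, dim):
        self.dim = dim
        self.config = None  # attached by the parent after construction

    def backend(self):
        # Read at call time, so later changes to the config are seen.
        if self.config is None:
            return "eager"  # assumed fallback, for illustration only
        return self.config._attn_implementation


class ToyModel:
    """Stand-in parent that owns the config and the attention blocks."""

    def __init__(self, config):
        self.config = config
        self.blocks = [ToyAttention(32), ToyAttention(64)]
        for block in self.blocks:
            block.config = config  # same object, not a copy


config = SimpleNamespace(_attn_implementation="sdpa")
model = ToyModel(config)
before = [block.backend() for block in model.blocks]

# Changing the shared config is visible to every block immediately.
config._attn_implementation = "flash_attention_2"
after = [block.backend() for block in model.blocks]
```

Because each block holds a reference to the same config object and reads
it on every call, no block needs to be rebuilt when the backend changes.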

Validated on CPU against the pre-refactor forward (random weights):
sdpa max|Δ|=0.0 (bit-exact), eager max|Δ|=1.3e-3 (bf16 softmax
precision).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>