transformers
cc15f3cd - Let kernel modules declare their preferred mask function

Committed 21 days ago
`load_and_register_attn_kernel` hardcodes the mask function to `flash_attention_2` for all custom attention kernels. This is incorrect for kernels that need a different mask type (e.g., SDPA-style masks).

Add support for a `MASK_FUNCTION` module-level attribute on kernel packages. If present, it specifies which mask type to use (e.g., "sdpa", "eager"); when the attribute is absent, it falls back to "flash_attention_2" for backward compatibility.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
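A minimal sketch of the lookup this commit describes: read an optional module-level `MASK_FUNCTION` attribute and fall back to `"flash_attention_2"` when it is absent. The helper name `resolve_mask_function` and the stand-in kernel modules are illustrative assumptions, not the actual transformers API.

```python
# Hypothetical sketch of the MASK_FUNCTION lookup; resolve_mask_function
# and the stand-in modules below are assumptions, not the real API.
import types


def resolve_mask_function(kernel_module):
    """Return the kernel's declared mask type, defaulting for compatibility."""
    return getattr(kernel_module, "MASK_FUNCTION", "flash_attention_2")


# A kernel package that declares an SDPA-style mask:
sdpa_kernel = types.ModuleType("my_sdpa_kernel")
sdpa_kernel.MASK_FUNCTION = "sdpa"

# A legacy kernel with no declaration keeps the old behavior:
legacy_kernel = types.ModuleType("legacy_kernel")

print(resolve_mask_function(sdpa_kernel))    # sdpa
print(resolve_mask_function(legacy_kernel))  # flash_attention_2
```

Using `getattr` with a default keeps older kernel packages working unchanged, which matches the backward-compatibility behavior the commit message calls out.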