xla
50043515 - Move requires_jax to inner flash_attention functions

Commit
357 days ago
Move requires_jax to inner flash_attention functions Move `requires_jax` decorators to `_fa_custom_forward_single_device` and `_fa_custom_backward_single_device` to avoid dynamo graph breaks when using flash attention kernel.
Author
Parents
Loading