Proper_flex (#36643)
* proper, performant flex attention implementation
* wrapper for flex attention to compile only when triggered (see the lazy-compile sketch after this list)
* attention mask type detection (see the block mask sketch after this list)
* Update src/transformers/integrations/flex_attention.py
* nits
* gemma2 support
* add citation for torchtune
* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update flex_attention.py
* nits
* reset gemma2 modifications
* nits
* licensing
* apply changes to other models
* safe import (see the import guard sketch after this list)
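A minimal sketch of the "compile only when triggered" wrapper, assuming PyTorch >= 2.5; the class name and the compile flags here are illustrative, not necessarily the exact code in src/transformers/integrations/flex_attention.py:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


class WrappedFlexAttention:
    """Compile `flex_attention` on first use and cache the compiled kernel."""

    _compiled = None  # shared, so compilation happens at most once

    def __call__(self, query, key, value, score_mod=None, block_mask=None):
        if WrappedFlexAttention._compiled is None:
            # Deferred compilation: the torch.compile cost is only paid when
            # flex attention is actually triggered by a forward pass.
            WrappedFlexAttention._compiled = torch.compile(flex_attention, dynamic=False)
        return WrappedFlexAttention._compiled(
            query, key, value, score_mod=score_mod, block_mask=block_mask
        )
```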
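A hedged sketch of one way "attention mask type detection" can work: a 2D (batch, seq_len) padding mask is recognized by its shape and converted into a causal BlockMask for flex attention. The helper name is hypothetical and this is only one plausible reading of the commit:

```python
import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask


def to_flex_block_mask(attention_mask: torch.Tensor) -> BlockMask:
    if attention_mask.dim() != 2:
        raise ValueError("Expected a 2D (batch, seq_len) padding mask")
    batch_size, seq_len = attention_mask.shape

    def causal_padding_mod(b, h, q_idx, kv_idx):
        # Causal constraint combined with "key position is not padding".
        return (q_idx >= kv_idx) & (attention_mask[b, kv_idx] > 0)

    return create_block_mask(
        causal_padding_mod, B=batch_size, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
        device=attention_mask.device,
    )
```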
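And a small sketch of the "safe import" idea: only import flex attention when the installed torch actually ships it, so older environments keep importing cleanly. The exact version gate and flag name are assumptions:

```python
from packaging import version
import torch

# flex attention landed in PyTorch 2.5; guard the import behind a version check.
_is_flex_available = version.parse(torch.__version__) >= version.parse("2.5.0")

if _is_flex_available:
    from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
```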
---------
Co-authored-by: Sung Ching Liu <sunny19981005@outlook.com>
Co-authored-by: Sung Ching Liu <22844540+bursteratom@users.noreply.github.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>