Flash Attention 2 support for ROCm (#27611)
* support FA2
* fix typo
* fix broken tests
* fix more test errors
* left/right
* fix bug
* more tests
* typo
* fix Flash Attention layout for Falcon
* do not support this case
* use allclose instead of equal in tests (see the test sketch after this list)
* fix various bugs with flash attention
* bump
* fix test
* fix mistral
* use skipTest instead of return, which may be misleading (see the test sketch after this list)
* add a fix for the causal argument passed to Flash Attention
* fix copies
* more explicit comment
* still use self.is_causal
* fix causal argument
* comment
* fixes
* update documentation
* add link
* wrong test
* simplify FA2 ROCm requirements
* update opt
* make the flash_attn_uses_top_left_mask attribute private and add a more precise comment (see the causal-mask sketch after this list)
* better error handling
* fix copy & mistral
* Update src/transformers/modeling_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/modeling_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/modeling_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/utils/import_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* use is_flash_attn_greater_or_equal_2_10 instead of is_flash_attn_greater_or_equal_210 (see the version-check sketch after this list)
* fix merge
* simplify
* inline args
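
The causal-argument and flash_attn_uses_top_left_mask commits deal with flash-attn versions below 2.1, which align the causal mask to the top-left corner instead of the bottom-right, so passing causal=True for a single decoded query token would mask out almost all keys. A minimal sketch of the idea (the attribute name follows the commit titles, but the class and surrounding code are illustrative, not the actual transformers implementation):

```python
# Minimal sketch of the causal-flag handling described above; illustrative only,
# not the exact transformers code.
from flash_attn import flash_attn_func


class FlashAttention2Sketch:
    def __init__(self, is_causal: bool, flash_attn_ge_2_1: bool):
        self.is_causal = is_causal
        # flash-attn < 2.1 uses a top-left aligned causal mask; >= 2.1 uses
        # bottom-right alignment, which is what decoding expects.
        self._flash_attn_uses_top_left_mask = not flash_attn_ge_2_1

    def _flash_attention_forward(self, query, key, value, dropout=0.0, softmax_scale=None):
        query_length = query.shape[1]
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # With top-left alignment, causal=True for a single query token
            # would attend to only the first key, so drop it in that case.
            causal = self.is_causal and query_length != 1
        return flash_attn_func(
            query, key, value, dropout, softmax_scale=softmax_scale, causal=causal
        )
```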
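
The allclose and skipTest commits tighten the tests: Flash Attention 2 outputs are numerically close to, but not bit-identical with, the eager path, and returning early from a test silently counts it as passed. A hypothetical test sketch (the stand-in outputs, tolerances, and GPU check are assumptions, not the actual test suite):

```python
# Hypothetical test sketch for the two testing fixes above; values and names
# are placeholders, not the real transformers tests.
import unittest

import torch


class FlashAttention2EquivalenceTest(unittest.TestCase):
    def test_eager_vs_flash_attention_2(self):
        if not torch.cuda.is_available():
            # skipTest reports the test as skipped; a bare `return` would have
            # counted it as passed, which is misleading.
            self.skipTest("Flash Attention 2 requires a GPU")

        logits_eager = torch.randn(1, 8, 32)  # stand-in for the eager output
        logits_fa2 = logits_eager + 1e-4 * torch.randn_like(logits_eager)

        # FA2 is numerically close but not bit-identical to the eager path,
        # so compare with tolerances instead of strict equality.
        self.assertTrue(torch.allclose(logits_eager, logits_fa2, atol=1e-3, rtol=1e-3))
```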
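
The is_flash_attn_greater_or_equal_2_10 rename refers to a version gate on the installed flash-attn package. A rough sketch of such a check, using a hypothetical helper name rather than the library's actual utility:

```python
# Rough sketch of a flash-attn version gate; the helper name is hypothetical.
import importlib.metadata

from packaging import version


def flash_attn_at_least(minimum: str = "2.1.0") -> bool:
    """Return True if the installed flash-attn package meets the minimum version."""
    try:
        installed = importlib.metadata.version("flash_attn")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(minimum)
```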
---------
Co-authored-by: Felix Marty <felix@hf.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>