```python
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
```
Can you update these as well?
./src/transformers/models/idefics/modeling_idefics.py:954: # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
./src/transformers/models/idefics/modeling_idefics.py:956: def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
./src/transformers/models/falcon/modeling_falcon.py:954: # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
./src/transformers/models/falcon/modeling_falcon.py:956: def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
If this function is in multiple files, I guess it makes sense to handle the logic in _autoset_attn_implementation
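If the logic is centralized, the decision itself could be factored into a small helper that `_autoset_attn_implementation` calls. A minimal sketch of that logic (the standalone helper and its name are my own illustration, not the transformers API):

```python
def should_fall_back_from_sdpa(attn_implementation: str, device_map, is_hip: bool) -> bool:
    """Return True when the SDPA implementation should be disabled.

    Mirrors the condition from the diff in this PR: on a ROCm build
    (torch.version.hip is not None) with device_map="auto", the FA backend
    of SDPA can misbehave, so we fall back to other backends.
    """
    return is_hip and attn_implementation == "sdpa" and device_map == "auto"


# Example: only the ROCm + sdpa + device_map="auto" combination triggers the fallback.
print(should_fall_back_from_sdpa("sdpa", "auto", is_hip=True))    # True
print(should_fall_back_from_sdpa("sdpa", "auto", is_hip=False))   # False
print(should_fall_back_from_sdpa("eager", "auto", is_hip=True))   # False
```

Keeping this in one helper would avoid the per-model copies of `_check_and_enable_sdpa` drifting apart.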
You changed the signature, hence my suggestion, but I see you have updated it since then.
```python
if torch.version.hip is not None and config._attn_implementation == "sdpa" and device_map == "auto":
    logger.warning_once(
        "Using the `SDPA` attention implementation with `device_map='auto'` on a ROCM device may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
    )
```
Did not have time for this, but ideally we should have a pytorch issue open with a repro without transformers, and link it here
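A transformers-free repro for such an issue could start from a bare `scaled_dot_product_attention` call (a sketch only; the shapes and helper are illustrative, and actually triggering the garbage values would presumably require a multi-GPU ROCm machine where the flash backend is selected):

```python
import torch
import torch.nn.functional as F


def run_sdpa(device: str = "cpu") -> torch.Tensor:
    """Run a single SDPA call with arbitrary shapes: (batch, heads, seq, head_dim)."""
    q = torch.randn(1, 8, 16, 64, device=device)
    k = torch.randn(1, 8, 16, 64, device=device)
    v = torch.randn(1, 8, 16, 64, device=device)
    # On ROCm, the flash backend is where the garbage values were reported;
    # on CPU this simply exercises the math path and shows the call shape.
    return F.scaled_dot_product_attention(q, k, v)


out = run_sdpa()
print(out.shape)  # torch.Size([1, 8, 16, 64])
```

A PyTorch issue with something along these lines, run under the FA backend on ROCm, would make the upstream bug easy to link here.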
I agree. But I'm not sure if it is also bound to `auto` in transformers.
Also, it might be good to put all of this into the AMD documentation, so we can share it with them.
Works for me overall. As @fxmarty mentioned, two places need to be updated as well. I will leave you two to come to a conclusion on that.
Ping me again once you think a final review is necessary (i.e. the changes are big enough compared to the current ones).
Ready for a merge @mht-sharma ?
What does this PR do?
Garbage values may occur during model generation with models like Llama, Mistral, and Mixtral, particularly when utilizing a multi-GPU setup with `device_map="auto"` alongside SDPA and FA. This PR disables FA on SDPA for ROCm devices.
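Outside transformers, the effect of this change can be approximated with PyTorch's global SDP-backend toggles. A hedged sketch (the helper name is mine; `torch.backends.cuda.enable_flash_sdp` is a real toggle and is safe to call on non-ROCm builds, where this function is simply a no-op):

```python
import torch


def disable_flash_sdp_on_rocm() -> bool:
    """On a HIP (ROCm) build of PyTorch, disable the flash-attention backend
    of scaled_dot_product_attention so the math / mem-efficient backends are
    used instead. Returns True if the backend was disabled."""
    if torch.version.hip is not None:
        torch.backends.cuda.enable_flash_sdp(False)
        return True
    return False


# On a CUDA or CPU-only build, torch.version.hip is None and nothing changes.
print(disable_flash_sdp_on_rocm())
```

This is global process-wide state, which is why doing it inside the model-loading path (gated on `device_map="auto"`) rather than asking users to call it themselves seems like the friendlier choice.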