transformers
Disable the FA backend for SDPA on AMD GPUs
#30850
Merged

mht-sharma commented 363 days ago

What does this PR do?

Garbage values may occur during generation with models such as Llama, Mistral, and Mixtral, particularly on a multi-GPU setup with `device_map='auto'` when SDPA selects the FlashAttention (FA) backend.

This PR disables the FA backend for SDPA on ROCm devices.
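For reference, a minimal sketch of how the flash backend of SDPA can be turned off at the PyTorch level on ROCm. This uses PyTorch's public backend toggle and is not the PR's actual diff:

    import torch

    # On ROCm builds of PyTorch, torch.version.hip is set.
    # Fall back to the math / memory-efficient SDPA backends instead of flash attention.
    if torch.version.hip is not None:
        torch.backends.cuda.enable_flash_sdp(False)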

mht-sharma: disable fa (64664da8)
mht-sharma: disable fa (53c438e6)
mht-sharma: update warning (64bd948d)
mht-sharma requested a review from fxmarty 363 days ago
mht-sharma marked this pull request as ready for review 363 days ago
fxmarty commented on 2024-05-16
src/transformers/modeling_utils.py

     @classmethod
-    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
+    def _check_and_enable_sdpa(
fxmarty 363 days ago

Can you update these as well?

./src/transformers/models/idefics/modeling_idefics.py:954:    # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
./src/transformers/models/idefics/modeling_idefics.py:956:    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
./src/transformers/models/falcon/modeling_falcon.py:954:    # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
./src/transformers/models/falcon/modeling_falcon.py:956:    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
mht-sharma 363 days ago

If this function is duplicated in multiple files, I guess it makes sense to handle the logic in `_autoset_attn_implementation` instead.

fxmarty 363 days ago (edited 363 days ago)

You had changed the signature, hence my suggestion, but I see you have updated it since then.

src/transformers/modeling_utils.py

         if not hard_check_only:
             config._attn_implementation = "sdpa"
+
+        if torch.version.hip is not None and config._attn_implementation == "sdpa" and device_map == "auto":
fxmarty 363 days ago

Other device_map values may fail as well, no? E.g. when manually splitting the layers across several GPUs.

mht-sharma 363 days ago

Yes, it could happen with other types of device_map that place the model on multiple devices. Updated to check `torch.cuda.device_count() > 1` instead.
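Roughly, the revised guard could look like the sketch below. `config` and `logger` come from the surrounding `_autoset_attn_implementation` context in modeling_utils.py, and the warning text is paraphrased rather than the merged wording:

    import torch

    # Sketch: disable the flash SDPA backend whenever more than one GPU is
    # visible on a ROCm device, instead of only for device_map="auto".
    if (
        torch.version.hip is not None
        and config._attn_implementation == "sdpa"
        and torch.cuda.device_count() > 1
    ):
        logger.warning_once(
            "Using the `SDPA` attention implementation on multiple AMD GPUs may lead to "
            "performance issues due to the FA backend. Disabling it to use alternative backends."
        )
        torch.backends.cuda.enable_flash_sdp(False)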

src/transformers/modeling_utils.py

+        if torch.version.hip is not None and config._attn_implementation == "sdpa" and device_map == "auto":
+            logger.warning_once(
+                "Using the `SDPA` attention implementation with `device_map='auto'` on a ROCM device may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
+            )
fxmarty 363 days ago

I did not have time for this, but ideally we should open a PyTorch issue with a repro that does not involve transformers, and link it here.
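As a sketch, a transformers-free repro could compare the flash and math SDPA backends directly on a ROCm device. The shapes are illustrative, and the reported garbage values involve a multi-GPU setup, so a full repro would also shard the inputs across devices:

    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

    # Run SDPA once with only the flash backend enabled, once with only the math backend.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out_flash = F.scaled_dot_product_attention(q, k, v)
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        out_math = F.scaled_dot_product_attention(q, k, v)

    # Large differences here would point at the flash backend rather than transformers.
    print((out_flash - out_math).abs().max())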

ydshieh 363 days ago

I agree. But I'm not sure whether it is also tied to `device_map='auto'` in transformers.

ydshieh 363 days ago

Also, it might be good to put all of this into the AMD documentation, so we can share it with them.

HuggingFaceDocBuilderDev commented 363 days ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ydshieh approved these changes on 2024-05-16
ydshieh 363 days ago

Works for me overall. As @fxmarty mentioned, two places need to be updated too. I will leave you two to come to a conclusion on that.

Ping me again once you think a final review is necessary (i.e. if the changes become significantly bigger than the current ones).

mht-sharma: update warning (d90d45b1)
fxmarty approved these changes on 2024-05-16
ydshieh 363 days ago

Ready for a merge, @mht-sharma?

ydshieh merged 0753134f into main 363 days ago
