accelerate
Allow FSDP to use with `torch.autocast` for bfloat16 mixed precision
#2033
Merged

brcps12 commented 1 year ago

What does this PR do?

FSDP supports mixed precision through its `MixedPrecision` class, so it does not need the forward function wrapped with `torch.autocast`.
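For reference, a minimal sketch of such an FSDP `MixedPrecision` policy (assuming an already-initialized process group; `MyModel` is a placeholder module, not part of this PR):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# FSDP's MixedPrecision policy casts parameters, gradient communication, and
# buffers by itself, so the FSDP boundary does not need a torch.autocast context.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameter shards used for compute in bf16
    reduce_dtype=torch.bfloat16,  # gradient reduce-scatter / all-reduce in bf16
    buffer_dtype=torch.bfloat16,  # module buffers kept in bf16
)

model = FSDP(MyModel(), mixed_precision=bf16_policy)  # MyModel: placeholder module
```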

The code that skipped this wrapping was added in accelerate v0.22.0, but was removed in v0.23.0.

Related PRs are:

I can't find any information about why it was added or removed.

In fact, mixed precision works well even without `torch.autocast`, and even in the cases where autocast is needed, it does not work properly in the current version.

So, I think one of the following two options should be applied:

  1. Add `self.distributed_type != DistributedType.FSDP` to the condition so that `torch.autocast` is not used with FSDP
  2. Add `DistributedType.FSDP` in this file

The reason for 2 is that when FSDP is used, the distributed_type field is replaced with DistributedType.FSDP in this line, so I think it needs to be added there to support FSDP as well (a rough sketch of what this check might look like is shown below).
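A purely hypothetical sketch of option 2 (names and structure are illustrative, not the actual accelerate source):

```python
from enum import Enum

class DistributedType(str, Enum):
    # Illustrative stand-in for accelerate's DistributedType enum
    NO = "NO"
    MULTI_GPU = "MULTI_GPU"
    FSDP = "FSDP"

def should_wrap_with_autocast(distributed_type: DistributedType, mixed_precision: str) -> bool:
    # Option 2: include FSDP among the distributed types for which the bf16
    # torch.autocast wrapper (native AMP) is applied to the model's forward.
    allowed = {DistributedType.NO, DistributedType.MULTI_GPU, DistributedType.FSDP}
    return mixed_precision == "bf16" and distributed_type in allowed
```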

As a related issue, the MPT model posted on the Hugging Face Hub uses the LPNorm class, but when training with FSDP + bfloat16, the dtype changes before and after the norm. This occurs in version v0.23.0.
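For illustration, a hypothetical debugging helper (not from the MPT code) to observe such a dtype change around a norm layer:

```python
import torch

def report_norm_dtypes(norm: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Print the input and output dtypes of a norm layer to spot the
    # bf16 -> fp32 drift described above.
    y = norm(x)
    print(f"norm input dtype: {x.dtype}, output dtype: {y.dtype}")
    return y
```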

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

brcps12: Ignore native_amp when FSDP is used (e0ffd4dd)
muellerzr requested a review from pacman100 1 year ago
HuggingFaceDocBuilderDev commented 1 year ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

pacman100 commented 1 year ago

FSDP supports mixed precision through its `MixedPrecision` class, so it does not need the forward function wrapped with `torch.autocast`.

It is required, otherwise FP16 mixed precision fails with the error `RuntimeError: expected scalar type Half but found Float`. See the failing tests because of which this was required: https://github.com/huggingface/accelerate/actions/runs/6079386828/job/16491799916
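For illustration, a small stand-alone repro of this kind of dtype mismatch (using CPU bfloat16 rather than the CUDA fp16 of the failing tests; the exact error message varies by PyTorch version):

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)  # low-precision activation
linear = torch.nn.Linear(8, 8)               # fp32 weights, e.g. outside any FSDP unit

try:
    linear(x)  # dtype mismatch between bf16 input and fp32 weight raises RuntimeError
except RuntimeError as e:
    print(e)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = linear(x)  # autocast casts the op's inputs to a consistent dtype
print(out.dtype)     # torch.bfloat16
```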

pacman100 commented 1 year ago

How about going with the option 2 you mentioned? Does that solve the issue with MPT?

brcps12: Rollback condition (6407ffb3)
brcps12: Fix mixed precision of bfloat16 for FSDP (6bf41ae7)
brcps12 commented 1 year ago

It is required, otherwise FP16 mixed precision fails with the error `RuntimeError: expected scalar type Half but found Float`. See the failing tests because of which this was required: https://github.com/huggingface/accelerate/actions/runs/6079386828/job/16491799916

Looking at the problem, it seems that FSDP's MixedPrecision only handles casting at FSDP module boundaries, not individual torch operators (e.g., softmax -> matmul). Thank you for sharing!

How about going with the option 2 you mentioned? Does that solve the issue with MPT?

Yes, the issue is fixed. So I'm going to go with option 2 and change the PR title.
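A rough sketch of the combined setup this change allows (illustrative only; assumes CUDA and an initialized process group, with `MyModel` and `batch` as placeholders):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

model = FSDP(MyModel().cuda(), mixed_precision=MixedPrecision(param_dtype=torch.bfloat16))

# FSDP's policy handles casting at FSDP module boundaries; the autocast context
# additionally covers plain torch ops (e.g. a softmax -> matmul outside wrapped modules).
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(batch).sum()
loss.backward()
```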

brcps12 changed the title Ignore torch.autocast for mixed precision when FSDP is used → Fix mixed precision of bfloat16 when using FSDP 1 year ago
brcps12 changed the title Fix mixed precision of bfloat16 when using FSDP → Add FSDP allowed to wrap with `torch.autocast` for bfloat16 mixed precision 1 year ago
brcps12 changed the title Add FSDP allowed to wrap with `torch.autocast` for bfloat16 mixed precision → Allow FSDP to use with `torch.autocast` for bfloat16 mixed precision 1 year ago
pacman100 approved these changes on 2023-10-06

Thank you @brcps12 for fixing the bug wrt bf16 autocasting when using FSDP, LGTM! 🤗

pacman100 merged 5ae61111 into main 1 year ago
