Fix scaling and allgather with `torch.autocast` (#7534)
This PR includes these two fixes:
- Use GradScaler only for FP16 (not for BF16)
- Fix dtype conversion for ZeRO3 allgather
- The reduce hook should be called only once, even when a parameter is
shared across multiple layers (tied parameters).
- Currently, the hook is triggered at each tied layer because we
temporarily set `.data` with a different dtype.
- The fix ensures that the parameter consistently retains the same
dtype.
---------
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Signed-off-by: Stas Bekman <stas@stason.org>
Signed-off-by: jakehemmerle <jakehemmerle@protonmail.com>
Signed-off-by: Qi Bin <qibin0506@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Jake Hemmerle <jakehemmerle@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Qi Bin <qibin0506@users.noreply.github.com>