pytorch
4bf22fcf - add mixed data type support for GroupNorm (#81852)

Commit
2 years ago
add mixed data type support for GroupNorm (#81852) 1. If user uses amp to run bfloat16 models, `torch.autocast` will keep module paramters in acc dtype which will leave `gamma` and`beta` in float while input/output will be in bfloat16. 2. If user explicitly cast the model to bfloat16, the input/output and gamma/beta will all be in bfloat16. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81852 Approved by: https://github.com/jgong5, https://github.com/malfet
Author
Committer
Parents
Loading