pytorch
f1978b18 - add mixed data type support for LayerNorm (#81851)

Commit
2 years ago
add mixed data type support for LayerNorm (#81851) 1. If user uses amp to run bfloat16 models, `torch.autocast` will keep module paramters in acc dtype which will leave `gamma` and`beta` in float while input/output will be in bfloat16. 2. If user explicitly cast the model to bfloat16 such as: ``` x = torch.randn(n, t, c).bfloat16() ln = nn.LayerNorm(c).bfloat16() y = ln(x) ``` The input/output and gamma/beta will all be in bfloat16. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81851 Approved by: https://github.com/ezyang
Author
Committer
Parents
Loading