pytorch
33760505 - fix type promotion for group_norm composite C++ kernel (#86607)

Commit
2 years ago
fix type promotion for group_norm composite C++ kernel (#86607) python decomp for `native_group_norm` is correct in more cases than the C++ composite. Updating the tests to fail properly in this case was more annoying than just fixing the C++ decomp, so I fixed it here. When the input tensor had a dtype with less precision than float32, the C++ decomp would unconditionally set the mean/variance to float32, which was wrong. Pull Request resolved: https://github.com/pytorch/pytorch/pull/86607 Approved by: https://github.com/albanD
Author
Committer
Parents
Loading