[primtorch] Added `native_group_norm` decomp (#78029)

cc: @jansel @bertmaher

More or less identical in spirit to the layer norm and batch norm ones. One annoying thing about all three of these is that `layer_norm` has slightly different `mean/var` semantics than batch norm and group norm: after normalization, `layer_norm` keeps them unsqueezed (so they're something like `[1, 5, 1, 1]`), while batch norm and group norm squeeze out the 1-dims.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78029
Approved by: https://github.com/bertmaher
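To make the shape difference concrete, here is a minimal sketch (not part of the commit) that calls the underlying aten ops directly. The input shape `[1, 5, 4, 4]` and the printed shapes are illustrative assumptions based on the description above:

```python
import torch

x = torch.randn(1, 5, 4, 4)  # [N, C, H, W]

# native_layer_norm over the trailing [H, W] dims: mean/rstd stay unsqueezed,
# broadcastable against the input (e.g. [1, 5, 1, 1]).
_, ln_mean, ln_rstd = torch.ops.aten.native_layer_norm(x, [4, 4], None, None, 1e-5)
print(ln_mean.shape)  # torch.Size([1, 5, 1, 1])

# native_group_norm with 5 groups: mean/rstd come back squeezed to [N, groups].
_, gn_mean, gn_rstd = torch.ops.aten.native_group_norm(
    x, None, None, 1, 5, 16, 5, 1e-5  # N=1, C=5, HxW=16, groups=5
)
print(gn_mean.shape)  # torch.Size([1, 5])

# native_batch_norm in training mode: saved mean/invstd are squeezed to [C].
_, bn_mean, bn_rstd = torch.ops.aten.native_batch_norm(
    x, None, None, None, None, True, 0.1, 1e-5
)
print(bn_mean.shape)  # torch.Size([5])
```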