pytorch
70d80fb4 - Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407)

Commit
3 years ago
Fixed type promotion semantics for native_batch_norm and native_layer_norm (#77407) Originally, when these were written, they simply used the naive strategy of "upcast all inputs to floats, and downcast all inputs back". In addition to being... not quite what the kernels did, they also didn't capture some additional semantics. Namely, that the norms (except for layer norm on CPU! cc: @ngimel) return fp32 for the mean and rstd values. Also, folks didn't like that I wrote `native_layer_norm` in terms of `native_batch_norm`. Which is fair - so I refactored the common logic into a `normalize` function. cc: @jansel / @bertmaher , who've been looking at lowering layer norm/batch norm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/77407 Approved by: https://github.com/bertmaher
Author
Committer
Parents
Loading