Fast cuda layer norm (#67977)
Summary:
This adds an apex-inspired fast layer norm forward kernel to PyTorch (it is a significant rewrite, though).
It is much faster than the current implementation: for a typical transformer size (32*196, 1024), time goes down from ~180 us to ~49 us on Volta. Compared to apex, it also produces bitwise-identical results between float inputs representable in fp16 and fp16 inputs. It produces slightly different results than the current implementation, though, because Welford summation is implemented differently.
It is slower than lightSeq (~37 us), but lightSeq uses an inaccurate variance approximation and does not guarantee float-to-fp16 bitwise accuracy.
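As a rough illustration of the Welford summation mentioned above, here is a minimal scalar sketch in Python (an assumption for exposition, not the actual CUDA kernel, which parallelizes the update with pairwise combines across threads):

```python
def welford_mean_var(xs):
    """One-pass Welford accumulation of mean and population variance."""
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for n, x in enumerate(xs, start=1):
        delta = x - mean
        mean += delta / n          # update running mean
        m2 += delta * (x - mean)   # update using old and new mean
    count = len(xs)
    var = m2 / count if count else float("nan")
    return mean, var
```

Layer norm then normalizes each element as (x - mean) / sqrt(var + eps); because the one-pass update order differs from the current implementation's reduction, the accumulated rounding differs slightly, which is the source of the small numerical discrepancy noted above.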
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67977
Reviewed By: mruberry
Differential Revision: D32285331
Pulled By: ngimel
fbshipit-source-id: a8b876a9cf3133daacfe0ce3a37e3ad566f4b6a8