pytorch
84d3df80 - Fast cuda layer norm (#67977)

Fast cuda layer norm (#67977)

Summary: This adds an Apex-inspired fast layer norm forward kernel to PyTorch (it is a significant rewrite, though). It is much faster than the current implementation: for a typical transformer size (32*196, 1024), the time goes down from ~180 us to ~49 us on Volta. Compared to Apex, it also produces bitwise-accurate results between float inputs representable in fp16 and fp16 inputs. It does produce slightly different results than the current implementation, however, because Welford summation is implemented differently. It is slower than LightSeq (~37 us), but LightSeq uses an inaccurate variance approximation and does not guarantee float/fp16 bitwise accuracy.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67977
Reviewed By: mruberry
Differential Revision: D32285331
Pulled By: ngimel
fbshipit-source-id: a8b876a9cf3133daacfe0ce3a37e3ad566f4b6a8
Author: Natalia Gimelshein
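The commit attributes the small numerical differences versus the old kernel to a different implementation of Welford summation. As context, below is a minimal CUDA sketch of Welford's online mean/variance update; the kernel name `welford_row_stats`, the one-thread-per-row layout, and the launch shape are illustrative assumptions, not the actual PyTorch kernel (which parallelizes within each row and combines partial statistics across threads).

```cuda
// Minimal sketch of Welford's online algorithm for per-row mean/variance,
// the summation scheme named in the commit message. One thread handles one
// row serially here purely for clarity; this is NOT the PyTorch kernel.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void welford_row_stats(const float* x, int rows, int cols,
                                  float* mean, float* var) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= rows) return;
    float m  = 0.f;  // running mean
    float m2 = 0.f;  // running sum of squared deviations from the mean
    for (int i = 0; i < cols; ++i) {
        float v = x[row * cols + i];
        float delta = v - m;
        m  += delta / (i + 1);   // update mean with the new element
        m2 += delta * (v - m);   // update M2 using the *updated* mean
    }
    mean[row] = m;
    var[row]  = m2 / cols;  // biased variance, as layer norm uses
}

int main() {
    const int rows = 2, cols = 1024;
    float *x, *mean, *var;
    cudaMallocManaged(&x, rows * cols * sizeof(float));
    cudaMallocManaged(&mean, rows * sizeof(float));
    cudaMallocManaged(&var, rows * sizeof(float));
    for (int i = 0; i < rows * cols; ++i) x[i] = (i % 7) * 0.5f;
    welford_row_stats<<<1, 32>>>(x, rows, cols, mean, var);
    cudaDeviceSynchronize();
    printf("row 0: mean=%f var=%f\n", mean[0], var[0]);
    cudaFree(x); cudaFree(mean); cudaFree(var);
    return 0;
}
```

Because each update folds one element into running statistics rather than summing x and x^2 separately, Welford avoids the catastrophic cancellation of the naive two-sum approach, but any change in the order or grouping of updates (e.g. a different parallel reduction tree) yields slightly different rounding, which is why the new kernel's results differ from the old implementation's.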