Fix mixed fused layer norm to mimic nn.LayerNorm for torch>1.11 #281
If pytorch>1.11 is available, we can use nn.LayerNorm instead of MixedLay…
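A minimal sketch of the gate this description implies, assuming a hypothetical factory helper (`make_layer_norm` is a name invented here) and an assumed import path for the repo's fused class:

```python
import torch
from packaging import version

def make_layer_norm(hidden_size, eps=1e-5):
    # Sketch: prefer stock nn.LayerNorm on newer torch; keep the fused
    # kernel for older releases. Both the helper name and the import
    # path below are assumptions, not the PR's actual code.
    if version.parse(torch.__version__) > version.parse("1.11"):
        return torch.nn.LayerNorm(hidden_size, eps=eps)
    from megatron.model.fused_layer_norm import MixedFusedLayerNorm  # assumed path
    return MixedFusedLayerNorm(hidden_size, eps=eps)
```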
thomasw21 force-pushed from 7a3a32ac to 26bd3d2b 3 years ago
Add MixedFusedLayerNorm fix (79922d4a)
Woops (27991b8d)
Convert weight/bias only once (c4b05ee9)
Revert "Convert weight/bias only once" (5db805c3)
Turns out LayerNorm for bf16 is slower using torch==1.11 (824b9c5f)
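A regression like this surfaces with a quick micro-benchmark; a rough illustration (not the PR's actual measurement), assuming a CUDA device and bfloat16 inputs:

```python
import time
import torch

def bench(module, x, iters=100):
    # Warm up, then average the wall time of `iters` forward passes.
    for _ in range(10):
        module(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        module(x)
    torch.cuda.synchronize()
    return (time.time() - start) / iters

x = torch.randn(16, 1024, 1024, dtype=torch.bfloat16, device="cuda")
ln = torch.nn.LayerNorm(1024).to(device="cuda", dtype=torch.bfloat16)
print(f"nn.LayerNorm bf16: {bench(ln, x) * 1e3:.3f} ms/iter")
```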
Woops (febce3ca)
Rewrite if condition (400ec42f)
stas00 commented on 2022-04-30
Use version package instead (f7d4e779)
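The reason for the packaging dependency: comparing torch.__version__ as a raw string mis-orders releases ("1.9.0" > "1.11.0" lexically) and trips over local build tags like "+cu113". A sketch of the parsed comparison:

```python
import torch
from packaging import version

# packaging.version compares numeric components correctly and tolerates
# build suffixes such as "1.11.0+cu113".
use_stock_layer_norm = version.parse(torch.__version__) >= version.parse("1.11")
```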
Test for LayerNorm (2510497d)
Improve test to use torch_assert_equal + minor fixes (99867bd0)
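What such a test plausibly checks, as a sketch: the fused module must reproduce nn.LayerNorm exactly. torch_assert_equal is the repo's own strict-comparison helper; torch.testing.assert_close with zero tolerances stands in for it here, and make_layer_norm is the hypothetical factory sketched earlier:

```python
import torch

def test_fused_matches_stock(hidden=1024, dtype=torch.bfloat16):
    torch.manual_seed(0)
    x = torch.randn(8, 128, hidden, dtype=dtype)
    stock = torch.nn.LayerNorm(hidden).to(dtype)
    fused = make_layer_norm(hidden).to(dtype)  # hypothetical factory from above
    fused.load_state_dict(stock.state_dict())  # identical weight/bias
    # Exact match required, mirroring a strict torch_assert_equal.
    torch.testing.assert_close(fused(x), stock(x), rtol=0, atol=0)
```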
Force bfloat16 (d84fc727)
Woops (723489bb)
thomasw21 changed the title from "Remove mixed fused layer norm in favor of nn.LayerNorm" to "Fix mixed fused layer norm to mimic nn.LayerNorm for torch>1.11" 3 years ago
stas00 commented on 2022-05-02
stas00 approved these changes on 2022-05-03
Fix torch version comparison (37500d97)
thomasw21 merged 908dc9cb into main 3 years ago
thomasw21 deleted the thomas/remove_mixed_fused_layer_norm branch 3 years ago