Add support for batch_norm fusion to the JIT (#15146)
Summary:
We don't support reductions yet, but simply decomposing batch_norm
into a kernel that computes the stats, and the fusing everything else
with ReLU and following pointwise ops provides nice speedups.
Note that this is only limited to inference mode for now, because we
don't support convolutions and batch norm in AD, so the fuser isn't
applied to those parts.
This commit gives us a 7% end-to-end speedup for ResNet50 with batch size 32. Note that this only applies to inference mode at the moment due to lack of AD support for CNN operations (I'll be adding that soon), and not to the standard `torchvision` models, because they use in-place ops which aren't supported by the fuser (we need a way of proving that de-inplacing them is safe).
cc zou3519 zdevito mruberry ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15146
Differential Revision: D13548303
Pulled By: zou3519
fbshipit-source-id: a2e2e5abc383f637fae19bd1b423f20c2cbc056a