pytorch
b8679ee1 - fix conv+bn folding issue when bn hasn't running states (#71259)

fix conv+bn folding issue when bn doesn't have running stats (#71259)

Summary: When folding conv+bn for a bn module that has no running stats (`track_running_stats=False`), both the JIT and FX paths raise an error:

```python
import torch
import torch.nn as nn
import torch.fx.experimental.optimization as optimization

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.conv = nn.Conv2d(32, 64, 3, stride=2)
        self.bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

x = torch.randn([1, 32, 50, 50])
model = M().eval()

'''
# jit path
with torch.no_grad():
    traced = torch.jit.trace(model, x).eval()
    traced = torch.jit.freeze(traced)
'''

# FX path
fused_model = optimization.fuse(model)
```

Observed errors:

1. JIT path

```
Traceback (most recent call last):
  File "bn_test.py", line 27, in <module>
    traced = torch.jit.freeze(traced)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/jit/_freeze.py", line 119, in freeze
    run_frozen_optimizations(out, optimize_numerics, preserved_methods)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/jit/_freeze.py", line 167, in run_frozen_optimizations
    torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics)
RuntimeError: Expected Tensor but got None
```

2. FX path

```
Traceback (most recent call last):
  File "bn_test.py", line 31, in <module>
    model = optimization.fuse(model, inplace=True)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/fx/experimental/optimization.py", line 71, in fuse
    fused_conv = fuse_conv_bn_eval(conv, bn)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/nn/utils/fusion.py", line 11, in fuse_conv_bn_eval
    fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/nn/utils/fusion.py", line 23, in fuse_conv_bn_weights
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
TypeError: unsupported operand type(s) for +: 'NoneType' and 'float'
```

This PR fixes the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71259

Reviewed By: anjali411

Differential Revision: D33595049

Pulled By: davidberard98

fbshipit-source-id: 0fe56bb2bb25d6d54ebc53789d2ad22458da9012
(cherry picked from commit 5672c083784585e6e1ec5657f02bd3051afb2b50)
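For context, the crash comes from the conv+bn folding math dereferencing `bn.running_mean` / `bn.running_var`, which are `None` when `track_running_stats=False`. Below is a minimal standalone sketch of the folding arithmetic with a guard for the `None` case; this is illustrative only (the function name mirrors `torch.nn.utils.fusion.fuse_conv_bn_weights`, but the `None` handling shown here, substituting identity statistics of mean 0 and variance 1, is an assumption about one reasonable fix, not a verbatim copy of the PR's change):

```python
import torch

def fuse_conv_bn_weights_sketch(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    """Fold eval-mode BatchNorm parameters into conv weight/bias.

    Sketch only: if the BN module was built with track_running_stats=False,
    bn_rm/bn_rv are None; here we substitute identity statistics (mean 0,
    variance 1) so the fold no longer crashes -- an assumed strategy, not
    necessarily what PR #71259 does.
    """
    out_channels = conv_w.shape[0]
    if bn_rm is None:
        bn_rm = torch.zeros(out_channels)   # assumed: treat missing mean as 0
    if bn_rv is None:
        bn_rv = torch.ones(out_channels)    # assumed: treat missing var as 1
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)

    # BN in eval mode computes y = (x - rm) / sqrt(rv + eps) * w + b,
    # which folds into the conv as a per-output-channel scale and shift.
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
    scale = bn_w * bn_var_rsqrt
    fused_w = conv_w * scale.reshape([-1] + [1] * (conv_w.dim() - 1))
    fused_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
    return fused_w, fused_b
```

With real running stats, the fused conv reproduces `bn(conv(x))` exactly; with `None` stats it at least no longer raises the `TypeError` shown in the FX traceback above.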