[jit] fix traced training attribute (#47211)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47211
The attribute is getting shadowed by the default one set on all modules,
and the __setattr__ on the TracedModule object prevents setting it correctly.
import torch
inp = torch.zeros(1, 3, 224, 224)
model = torch.hub.load('pytorch/vision:v0.6.0', 'mobilenet_v2', pretrained=True)
model.eval()
print(model.training)
with torch.no_grad():
traced = torch.jit.trace(model, inp)
print(traced.training)
traced.eval()
print(traced.training)
traced.training = False
print(traced.training)
torch.jit.freeze(traced)
Test Plan: Imported from OSS
Reviewed By: suo
Differential Revision: D24686690
Pulled By: zdevito
fbshipit-source-id: 9c1678dc68e9bf83176e9f5a20fa8f6bff5d69a0