ArgumentStash for Scalar arguments (#21931)
Summary:
Scalars are being traced as constants.
This PR is to fix this issue.
The ONNX Graph for Test_Full_op() before and after this change:
def Test_Full_op():
class Test_Full(nn.Module):
def forward(self, x):
return torch.full((3, 4), x, dtype=torch.long)
model = Test_Full()
x = torch.tensor(12)
output = model(x)
Before this change:
graph(%input1 : Long()):
%output1 : Float(3, 4) = onnx::Constant[value=<Tensor>]
return (%output1)
After this change:
graph(%input1 : Long()):
%1 : int[] = onnx::Constant[value= 3 4 [ Variable[CPULongType]{2} ]]
%2 : Tensor = onnx::ConstantOfShape[value={0}]
%output1 : Float(3, 4) = onnx::Add(%2, %input1)
return (%output1)
Similar PR : https://github.com/pytorch/pytorch/pull/12939
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21931
Reviewed By: zrphercule
Differential Revision: D15950066
Pulled By: houseroad
fbshipit-source-id: 3470665d88fa34faa600940ef16b069a06002cd5