Fix torch.arange traced as constant (#25363)
Summary:
torch.arange is always traced as a constant, which makes it impossible to correctly trace TestModel() from the example below:
class TestModel(torch.nn.Module):
    def forward(self, input):
        return torch.arange(input.shape[0])

input = torch.randn(5, 3, 2)
print(torch.jit.trace(TestModel(), input).graph)
Currently the trace of TestModel() looks like:
graph(%self : ClassType<TestModel>,
      %input : Float(5, 3, 2)):
  %11 : int = prim::Constant[value=5]()
  %12 : int = prim::Constant[value=4]()
  %13 : int = prim::Constant[value=0]()
  %14 : Device = prim::Constant[value="cpu"]()
  %15 : bool = prim::Constant[value=0]()
  %16 : Long(5) = aten::arange(%11, %12, %13, %14, %15)
  return (%16)
This PR allows the trace to carry a variable value for %11 instead of a baked-in constant. With this PR's changes, the trace of TestModel() looks like:
graph(%self : ClassType<TestModel>,
      %input : Float(5, 3, 2)):
  %2 : int = prim::Constant[value=0]()
  %3 : int = aten::size(%input, %2)
  %4 : Long() = prim::NumToTensor(%3)
  %11 : Scalar = prim::ImplicitTensorToNum(%4)
  %12 : int = prim::Constant[value=4]()
  %13 : int = prim::Constant[value=0]()
  %14 : Device = prim::Constant[value="cpu"]()
  %15 : bool = prim::Constant[value=0]()
  %16 : Long(5) = aten::arange(%11, %12, %13, %14, %15)
  return (%16)
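A quick way to check the effect (a minimal sketch, assuming a PyTorch build that includes this change) is to trace with one batch size and run with another; the arange length should now follow the runtime input rather than stay frozen at the traced value:

```python
import torch

class TestModel(torch.nn.Module):
    def forward(self, input):
        # The upper bound comes from the runtime batch size, not a constant.
        return torch.arange(input.shape[0])

# Trace with a batch of 5...
traced = torch.jit.trace(TestModel(), torch.randn(5, 3, 2))

# ...then run with a batch of 7. With the fix, the output has length 7;
# before the fix, the constant 5 was baked into the graph.
out = traced(torch.randn(7, 3, 2))
```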
More info: https://github.com/pytorch/pytorch/issues/20075
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25363
Reviewed By: zrphercule
Differential Revision: D17301934
Pulled By: houseroad
fbshipit-source-id: d9907763742cb51d8c761bf63fc2e4918f7b9941