pytorch
35b3e160 - [pytorch] Fix torch.nn.functional.normalize to be properly scriptable (#51909)

Commit
3 years ago
[pytorch] Fix torch.nn.functional.normalize to be properly scriptable (#51909) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51909 Several scenarios don't work when trying to script `F.normalize`, notably when you try to symbolically trace through it with using the default argument: ``` import torch.nn.functional as F import torch from torch.fx import symbolic_trace def f(x): return F.normalize(x) gm = symbolic_trace(f) torch.jit.script(gm) ``` which leads to the error ``` RuntimeError: normalize(Tensor input, float p=2., int dim=1, float eps=9.9999999999999998e-13, Tensor? out=None) -> (Tensor): Expected a value of type 'float' for argument 'p' but instead found type 'int'. : def forward(self, x): normalize_1 = torch.nn.functional.normalize(x, p = 2, dim = 1, eps = 1e-12, out = None); x = None ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE return normalize_1 Reviewed By: jamesr66a Differential Revision: D26324308 fbshipit-source-id: 30dd944a6011795d17164f2c746068daac570cea
Author
Brandon Lin
Parents
Loading