pytorch
98c85501 - Fix Triplet Margin Loss Opinfo (#110302)

Commit
1 year ago
Fix Triplet Margin Loss Opinfo (#110302) Triplet Margin Loss takes in a Callable `distance_function` parameter which is not supported as an argument on the fx graph. See previous error: > File "/scratch/eellison/work/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function self.push(fn.call_function(self, args, kwargs)) File "/scratch/eellison/work/pytorch/torch/_dynamo/variables/torch.py", line 723, in call_function *proxy_args_kwargs(args, kwargs), File "/scratch/eellison/work/pytorch/torch/_dynamo/utils.py", line 504, in proxy_args_kwargs f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}" File "/scratch/eellison/work/pytorch/torch/_dynamo/exc.py", line 143, in unimplemented raise Unsupported(msg) torch._dynamo.exc.Unsupported: call_function args: TensorVariable() TensorVariable() TensorVariable() ConstantVariable(float) NNModuleVariable() This is fixable by just inlining into `triplet_margin_loss` and continuing to compile it. This required support for `has_torch_function_variadic`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/110302 Approved by: https://github.com/mlazos
Author
Committer
Parents
Loading