pytorch
e8836759 - [export] Add effect token to export (#121424)

Commit
302 days ago
[export] Add effect token to export (#121424) Following the creation of effect tokens (https://github.com/pytorch/pytorch/pull/120296), we want to now add support for these tokens in export because the calling/returning convention has changed. The inputs are now `(tokens, params, buffers, constants, user_inputs)` and the outputs are `(tokens, buffer_mutations, user_mutations, user_outputs)`. The graph looks something like: ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %attr : [num_users=2] = placeholder[target=attr] %arg1_1 : [num_users=2] = placeholder[target=arg1_1] %with_effects : [num_users=2] = call_function[target=torch._higher_order_ops.effects.with_effects](args = (%arg0_1, _TorchScriptTesting.takes_foo.default, %attr, %arg1_1), kwargs = {}) %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {}) %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 1), kwargs = {}) %with_effects_1 : [num_users=2] = call_function[target=torch._higher_order_ops.effects.with_effects](args = (%getitem, _TorchScriptTesting.takes_foo.default, %attr, %getitem_1), kwargs = {}) %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects_1, 0), kwargs = {}) %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects_1, 1), kwargs = {}) %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, %getitem_3), kwargs = {}) return (getitem_2, add) ``` During unlifting, we will first remove the tokens and with_effect calls using the `remove_effect_tokens` pass. (cc @SherlockNoMad on the pass to remove tokens). This is so that this won't change the calling conventions when retracing. The graph after unlifting looks something like: ``` graph(): %attr_1 : [num_users=2] = get_attr[target=attr] %arg1_1 : [num_users=2] = placeholder[target=arg1_1] %takes_foo_default_1 : [num_users=1] = call_function[target=torch.ops._TorchScriptTesting.takes_foo.default](args = (%attr_1, %arg1_1), kwargs = {}) %takes_foo_default : [num_users=1] = call_function[target=torch.ops._TorchScriptTesting.takes_foo.default](args = (%attr_1, %takes_foo_default_1), kwargs = {}) %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, %takes_foo_default), kwargs = {}) return (add,) ``` Serialization support will be added in a followup. Note: tokens only affect custom ops that take in ScriptObjects, not ScriptObject methods yet. Differential Revision: [D54639390](https://our.internmc.facebook.com/intern/diff/D54639390) Pull Request resolved: https://github.com/pytorch/pytorch/pull/121424 Approved by: https://github.com/tugsbayasgalan
Author
Committer
Parents
Loading