pytorch
0d5bc541 - Fix interpretation torch -> torch._refs in case of nested torch calls under TorchRefsMode (#80135)

Commit
2 years ago
Fix interpretation torch -> torch._refs in case of nested torch calls under TorchRefsMode (#80135) torch calls inside `TorchRefsMode.__torch_function__` dispatch should be interpreted as refs calls under `TorchRefsMode`. Fixes https://github.com/pytorch/pytorch/issues/80079. In addition, this PR enables two more tests for the nvFuser executor. For example here's the FX trace of `torch._refs.nn.functional.layer_norm` before the proposed change (note the mix of `aten` and `prims`): ```py opcode name target args kwargs ------------- ---------------------- -------------------------- -------------------------------- ----------------- placeholder a_1 a_1 () {} call_function convert_element_type prims.convert_element_type (a_1, torch.float32) {} call_function var prims.var (convert_element_type, [0, 1]) {'correction': 0} call_function broadcast_in_dim prims.broadcast_in_dim (var, [1, 1], []) {} call_function convert_element_type_1 prims.convert_element_type (a_1, torch.float32) {} call_function sum_1 prims.sum (convert_element_type_1, [0, 1]) {} call_function broadcast_in_dim_1 prims.broadcast_in_dim (sum_1, [1, 1], []) {} call_function div prims.div (broadcast_in_dim_1, 9.0) {} call_function add aten.add (broadcast_in_dim, 1e-05) {} call_function rsqrt aten.rsqrt (add,) {} call_function sub aten.sub (a_1, div) {} call_function mul aten.mul (sub, rsqrt) {} call_function convert_element_type_2 prims.convert_element_type (mul, torch.float32) {} output output output (convert_element_type_2,) {} ``` And with this PR: ```py opcode name target args kwargs ------------- ---------------------- -------------------------- -------------------------------- ----------------- placeholder a_1 a_1 () {} call_function convert_element_type prims.convert_element_type (a_1, torch.float32) {} call_function var prims.var (convert_element_type, [0, 1]) {'correction': 0} call_function broadcast_in_dim prims.broadcast_in_dim (var, [1, 1], []) {} call_function convert_element_type_1 prims.convert_element_type (a_1, torch.float32) {} call_function sum_1 prims.sum (convert_element_type_1, [0, 1]) {} call_function broadcast_in_dim_1 prims.broadcast_in_dim (sum_1, [1, 1], []) {} call_function div prims.div (broadcast_in_dim_1, 9.0) {} call_function add prims.add (broadcast_in_dim, 1e-05) {} call_function rsqrt prims.rsqrt (add,) {} call_function broadcast_in_dim_2 prims.broadcast_in_dim (div, [3, 3], [0, 1]) {} call_function sub prims.sub (a_1, broadcast_in_dim_2) {} call_function broadcast_in_dim_3 prims.broadcast_in_dim (rsqrt, [3, 3], [0, 1]) {} call_function mul prims.mul (sub, broadcast_in_dim_3) {} call_function convert_element_type_2 prims.convert_element_type (mul, torch.float32) {} output output output (convert_element_type_2,) {} ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/80135 Approved by: https://github.com/ngimel
Author
Committer
Parents
Loading