[dynamo] Re-dispatch `torch.Tensor.new` into `torch.Tensor.new_empty` method. (#121075)
Fix: https://github.com/pytorch/xla/issues/6009
This PR adds another case to `TensorVariable.method_new` special case, where it
re-dispatches `new` into `new_empty`.
Since we are using fake tensors, the `new` call doesn't actually gets to the corresponding
backend (e.g. XLA). So, things like the following might happen:
```python
@torch.compile(backend="openxla")
def foo(x):
new_x = x.new(*x.size())
# new_x.device() == "xla"
# x.device() == "xla:0"
return new_x + x
a = torch.arange(10)
foo(a.to(xm.xla_device()))
```
Resulting in the following error:
```python
Traceback (most recent call last):
...
File "torch/_dynamo/utils.py", line 1654, in get_fake_value
ret_val = wrap_fake_exception(
File "torch/_dynamo/utils.py", line 1190, in wrap_fake_exception
return fn()
File "torch/_dynamo/utils.py", line 1655, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "torch/_dynamo/utils.py", line 1776, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "torch/_dynamo/utils.py", line 1758, in run_node
return node.target(*args, **kwargs)
File "torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "torch/_subclasses/fake_tensor.py", line 885, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "torch/_subclasses/fake_tensor.py", line 1224, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
File "torch/_subclasses/fake_tensor.py", line 955, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
File "torch/_subclasses/fake_tensor.py", line 1445, in _dispatch_impl
return self.wrap_meta_outputs_with_default_device_logic(
File "torch/_subclasses/fake_tensor.py", line 1575, in wrap_meta_outputs_with_default_device_logic
return tree_map(wrap, r)
File "torch/utils/_pytree.py", line 900, in tree_map
return treespec.unflatten(map(func, *flat_args))
File "torch/utils/_pytree.py", line 736, in unflatten
leaves = list(leaves)
File "torch/_subclasses/fake_tensor.py", line 1550, in wrap
) = FakeTensor._find_common_device(func, flat_args)
File "torch/_subclasses/fake_tensor.py", line 625, in _find_common_device
merge_devices(arg)
File "torch/_subclasses/fake_tensor.py", line 620, in merge_devices
raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., device='xla', size=(10,), dtype=torch.int64), FakeTensor(..., device='xla:0', size=(10,), dtype=torch.int64)), **{}):
Unhandled FakeTensor Device Propagation for aten.add.Tensor, found two different devices xla, xla:0
```
Using `new_empty`, instead, fixes this error because it uses the device from the source
tensor, instead of inferring from the current dispatch key set.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121075
Approved by: https://github.com/jansel