fixed compilations on xla tensor print (#71147)
Summary:
Fixes multiple compilations when printing an XLA tensor. Please check the conversation here: https://github.com/pytorch/xla/pull/3253
Torch performs tensor operations such as slicing to make the tensor readable, and on an XLA device each of these operations triggers a compilation. To avoid these compilations, the tensor is now copied to the CPU before printing.
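The idea can be sketched roughly as follows (a minimal sketch, not the actual diff; the helper name `_print_ready` is hypothetical):
```
import torch

def _print_ready(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper illustrating the approach: formatting code slices
    # and summarizes the tensor, and each such op would compile on an XLA
    # device, so materialize a CPU copy first and format that instead.
    if t.device.type == "xla":
        t = t.to("cpu")
    return t
```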
example:
```
import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()

def test_linear(input_shape=(8, 1024)):
    import pdb
    pdb.set_trace()
    # Build a linear layer and its input on the XLA device.
    linear = torch.nn.Linear(in_features=1024, out_features=4096, bias=True).to(dev)
    inp = torch.randn(*input_shape).to(dev)
    output = linear(inp)
    xm.mark_step()
    return output
```
Previously, returning from this function under pdb resulted in 63 compilations, since pdb prints the returned value, which in this case is an XLA tensor. With this change, printing triggers no compilation.
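As a quick sanity check, one can compare XLA's compile count before and after printing. A sketch, assuming torch_xla's debug metrics API; the `CompileTime` metric name and the tuple layout returned by `met.metric_data` are the assumptions here:
```
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

dev = xm.xla_device()
t = torch.randn(8, 1024, device=dev)
xm.mark_step()  # materialize the tensor so the baseline compile happens here

def compile_count():
    # Assumed layout: met.metric_data returns (count, accumulator, samples),
    # or None if the metric has not been recorded yet.
    data = met.metric_data("CompileTime")
    return data[0] if data else 0

before = compile_count()
print(t)  # with the fix, formatting the tensor should add no compilations
assert compile_count() == before
```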
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71147
Reviewed By: shunting314
Differential Revision: D33795177
Pulled By: wconstab
fbshipit-source-id: 74b53d9a1cb7ef67f9d8b0a32064f3896be449b5
(cherry picked from commit a9e0687fc5c9981fb55ea4dc406c283c80fa20c9)