pytorch
027c0d7f - fixed compilations on xla tensor print (#71147)

Commit
2 years ago
fixed compilations on xla tensor print (#71147) Summary: Fixes multiple compilation on xla tensor print. Please check the conversation here: https://github.com/pytorch/xla/pull/3253 This is done to avoid compilations during tensor printing. Torch performs some tensor operations like slicing to make the tensor readable. These operations result in compilations. Hence to avoid the compilations, copying the tensor to cpu before printing. example: ``` dev = xm.xla_device() def test_linear(input_shape=(8, 1024)): import pdb pdb.set_trace() linear = torch.nn.Linear(in_features=1024, out_features=4096, bias=True).to(dev) inp = torch.randn(*input_shape).to(dev) output = linear(inp) xm.mark_step() return output ``` Returning from this function would have resulted in 63 compiles, since PDB prints the value of the return output. In this case it is a xla tensor. Now with the current change, there is no compilation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/71147 Reviewed By: shunting314 Differential Revision: D33795177 Pulled By: wconstab fbshipit-source-id: 74b53d9a1cb7ef67f9d8b0a32064f3896be449b5 (cherry picked from commit a9e0687fc5c9981fb55ea4dc406c283c80fa20c9)
Author
Committer
Parents
Loading