a6b75bb0 - [MPS] Fix placeholder case for missing gather graph (#83744)

Fixes https://github.com/pytorch/pytorch/issues/82543, https://github.com/pytorch/pytorch/issues/83230

The current Placeholder code relies on finding a gather graph in order to make the data contiguous; otherwise we call into `tensor.contiguous()` directly, which does nothing for sliced elements. E.g., consider the following basic case where we index a 2-element tensor:

```
tensor_list = torch.tensor([1.2, 1.0], device="mps")
for scalar in tensor_list:
    r_mps = torch.ceil(scalar)
    r_cpu = torch.ceil(scalar.to("cpu"))
    self.assertEqual(r_mps.cpu(), r_cpu)
```

The second element, 1.0, is a contiguous view tensor (similar to slicing), but no gather graph is created for it. The Placeholder cannot find the graph, so it falls back to calling `_tensor = src.contiguous();`. For an already contiguous tensor this is a no-op, so we end up creating the NDArray with all the values of the tensor (1.2 and 1.0 instead of just 1.0). Calling `clone()` instead of `contiguous()` actually performs a blit and takes the view's `storage_offset` into account when performing the copy.

Similarly, the following basic case also fails because of this issue:

```
x = torch.tensor([1.0, 0.49], device="mps")
print(x)  # prints 1.0 and 0.0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83744
Approved by: https://github.com/razarmehr
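For context (this sketch is not part of the commit), the clone-vs-contiguous distinction the message relies on can be reproduced on CPU with plain PyTorch calls; the tensor values mirror the example above:

```
import torch

base = torch.tensor([1.2, 1.0])  # CPU stand-in for the MPS tensor in the commit
view = base[1]                   # 0-dim view: contiguous, but storage_offset == 1

print(view.is_contiguous())      # True: a single element is trivially contiguous
print(view.contiguous().data_ptr() == view.data_ptr())  # True: contiguous() returns self, no copy
clone = view.clone()
print(clone.data_ptr() == view.data_ptr())              # False: clone() allocates and blits
print(clone.storage_offset(), clone.item())             # 0 1.0: only the viewed element is copied
```

Because `contiguous()` returns `self` here, a fallback that wraps the result's underlying storage sees both elements of `base`, which is the failure mode the commit message describes.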