pytorch
beaa5c59 - [MPS] View fixes (#95323)

Commit
2 years ago
[MPS] View fixes (#95323) * [MPS] Fix the uint8 type issue with View ops kernels (#95145) This should fix the problem in Resnet model with image artifacts due to saturation on int8 type and also the incorrect class recognition reported in #86954. Fixes #86954 Pull Request resolved: https://github.com/pytorch/pytorch/pull/95145 Approved by: https://github.com/kulinseth, https://github.com/DenisVieriu97 * [MPS] Fix tensor with non-zero storage offset graph gathering (#91071) Previously, the "can slice" flag in Placeholder constructor in `OperationUtils.mm` is conditioned on whether the numbers of dimensions of base shape and view shape are the same. This doesn't consider the situation that a view tensor could be the base tensor's sliced and then unsqueezed version, resulting in different num of dims. For example, if we want to stack `y_mps` and `x_mps` on the last dim: ``` t_mps = torch.tensor([1, 2, 3, 4], device="mps") x_mps = t_mps[2:] # [3, 4] y_mps = t_mps[:2] # [1, 2] res_mps = torch.stack((y_mps, x_mps), dim=-1) ``` the kernel will unsqueeze both of them on the last dim and then concatenate them, which is equivalent to: ``` res_mps = torch.cat((y_mps.unsqueeze(-1), x_mps.unsqueeze(-1)), dim=-1) ``` `x_mps.unsqueeze(-1)` is an unsqueezed and contiguous tensor with a storage offset, this kind of tensors should be sliceable without cloning its storage. Fixes #87856 Fixes #91065 Pull Request resolved: https://github.com/pytorch/pytorch/pull/91071 Approved by: https://github.com/kulinseth * [MPS] Fix fill_ where input tensor has a storage offset (#95113) Fixes #94390 Apart from fixing the issue above, this PR also fixes a bug that when an input tensor can be sliced, a sliced array view is created. This array view seems to be not writable or have a different storage from the original tensor, causing incorrect results with the in-place `fill`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/95113 Approved by: https://github.com/kulinseth * [MPS] Fix view op slicing for 2nd dim in case of 0 offset (#95381) * Fix view op slicing for 2nd dim in case of 0 offset Pull Request resolved: https://github.com/pytorch/pytorch/pull/95381 Approved by: https://github.com/razarmehr --------- Co-authored-by: Ramin Azarmehr <razarmehr@apple.com> Co-authored-by: Li-Huai (Allan) Lin <qqaatw@gmail.com> Co-authored-by: Denis Vieriu <104024078+DenisVieriu97@users.noreply.github.com>
Author
Parents
Loading