Convert MPS Tensor data using MPSGraph API (#78092)
Fixes #78091
If you are already working on this, simply disregard this or take what may be helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested ~~but is currently only implemented for MPS-to-MPS copy, not MPS-to-X or X-to-MPS, but the same approach could easily be used~~.
Before:
```python
In [5]: pt.full((40,), -10.3, device="mps")
Out[5]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')
In [6]: pt.full((40,), -10.3, device="mps").int()
Out[6]:
tensor([-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
-1054552883, -1054552883, -1054552883, -1054552883, -1054552883],
device='mps:0', dtype=torch.int32)
In [7]: pt.full((40,), -10.3, device="mps").int().float()
Out[7]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')
In [8]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[8]:
tensor([ True, False, False, True, True, False, False, True, True, False,
False, True, True, False, False, True, True, False, False, True,
True, False, False, True, True, False, False, True, True, False,
False, True, True, False, False, True, True, False, False, True],
device='mps:0')
```
After:
```python
In [3]: pt.full((40,), -10.3, device="mps")
Out[3]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
-10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')
In [4]: pt.full((40,), -10.3, device="mps").int()
Out[4]:
tensor([-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10],
device='mps:0', dtype=torch.int32)
In [5]: pt.full((40,), -10.3, device="mps").int().float()
Out[5]:
tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
-10., -10., -10., -10.], device='mps:0')
In [6]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[6]:
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True, True, True,
True, True, True, True], device='mps:0')
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78092
Approved by: https://github.com/kulinseth, https://github.com/malfet