pytorch
a52bfe2c - Convert MPS Tensor data using MPSGraph API (#78092)

Commit
2 years ago
Convert MPS Tensor data using MPSGraph API (#78092) Fixes #78091 If you are already working on this, simply disregard this or take what may be helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested ~~but is currently only implemented for MPS-to-MPS copy, not MPS-to-X or X-to-MPS, but the same approach could easily be used~~. Before: ```python In [5]: pt.full((40,), -10.3, device="mps") Out[5]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [6]: pt.full((40,), -10.3, device="mps").int() Out[6]: tensor([-1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883, -1054552883], device='mps:0', dtype=torch.int32) In [7]: pt.full((40,), -10.3, device="mps").int().float() Out[7]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [8]: pt.full((40,), -10.3, device="mps").int().float().bool() Out[8]: tensor([ True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, True, False, False, True], device='mps:0') ``` After: ```python In [3]: pt.full((40,), -10.3, device="mps") Out[3]: tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0') In [4]: pt.full((40,), -10.3, device="mps").int() Out[4]: tensor([-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10], device='mps:0', dtype=torch.int32) In [5]: pt.full((40,), -10.3, device="mps").int().float() Out[5]: tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.], device='mps:0') In [6]: pt.full((40,), -10.3, device="mps").int().float().bool() Out[6]: tensor([True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True], device='mps:0') ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/78092 Approved by: https://github.com/kulinseth, https://github.com/malfet
Author
Committer
Parents
Loading