pytorch
ae62cf7c - [MPS] Revamp copy_to_mps_ implementation (#86956)

Commit
2 years ago
[MPS] Revamp copy_to_mps_ implementation (#86956) Tensor's view in linear storage is represented by the following parameters: `.shape`, `.stride()` and `.storage_offset()`. Only tensors that are representable as 1d-views can be copied from host to device (and vice versa) using single [`copy(from:sourceOffset:to:destinationOffset:size:)`](https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc) call. Modify `copy_to_mps_` function to do the following steps: - Cast `src` tensor to dst data type if needed - Expand `src` tensor to `dst` tensor shape - Clone `src` tensor if it is not stride contiguous (i.e. can not be represented by `src.view(src.numel())`) - Create an empty tensor if `dst` is not stride-contiguous or if its strides are different then potentially cloned `src` strides - Do 1d copy for `src` to (potentiall temp) `dst` - Finally do re-striding/copy on MPS if needed Add test to cover cases where stide-contiguous permuted tensor is copied to MPS, non-stride-contiguous tensor is copied to MPS and if permuted CPU tensor is copied to differently permuted MPS tensor Fixes https://github.com/pytorch/pytorch/issues/86954 Pull Request resolved: https://github.com/pytorch/pytorch/pull/86956 Approved by: https://github.com/kulinseth
Author
Committer
Parents
Loading