pytorch
dbf96164 - [MPS] Add suport for casting updatesTensor directly in scatter (#91197)

Commit
2 years ago
[MPS] Add suport for casting updatesTensor directly in scatter (#91197) Fixes copies into slices where the input data type is different than the output dtype. This change removes the cast done before scatter, so we don't have to allocate additional memory to perform the casting. Scatter handles the casting directly now. device = "mps" shape = (4, 4) tensor = torch.randint(10, shape, device=device) tensor_before = tensor.clone() res = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor) torch.testing.assert_close(tensor, tensor_before) Pull Request resolved: https://github.com/pytorch/pytorch/pull/91197 Approved by: https://github.com/razarmehr
Author
Committer
Parents
Loading