[MPS] Add suport for casting updatesTensor directly in scatter (#91197)
Fixes copies into slices where the input data type is different than the output dtype.
This change removes the cast done before scatter, so we don't have to allocate additional memory to perform the casting. Scatter handles the casting directly now.
device = "mps"
shape = (4, 4)
tensor = torch.randint(10, shape, device=device)
tensor_before = tensor.clone()
res = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor)
torch.testing.assert_close(tensor, tensor_before)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91197
Approved by: https://github.com/razarmehr