pytorch
769cc8a6 - [MPS] Add type promotion to `torch.addcmul` (#96164)

Commit
1 year ago
[MPS] Add type promotion to `torch.addcmul` (#96164) Fixes crash while running something like `python -c "import torch;x=torch.rand(3, 3, dtype=torch.float16, device='mps');y=x.addcmul(torch.ones(3, device='mps'), torch.ones(3, device='mps'));print(y)"` Modify `castMPSTensor` to become a no-op if cast is not needed Define `common_dtype` as `c10::promoType` between self, tensor1 and tensor2. Cast to any output type. Add mixed-types test to `TestMPS.test_addcmul`, though it does not cover all the permutations Discovered while looking at https://github.com/pytorch/pytorch/issues/96113 Pull Request resolved: https://github.com/pytorch/pytorch/pull/96164 Approved by: https://github.com/kulinseth
Author
Committer
Parents
Loading