pytorch
b7f35e41 - [MPS] Fix index_add with non-f32 inputs (#88542)

Commit
3 years ago
[MPS] Fix index_add with non-f32 inputs (#88542) The `multiplicationWithPrimaryTensor` and/or `scatterWithDataTensor` api has issues with handling two f16 tensor inputs, resulting in zeros outputs. With int16 or int64 inputs, there are issues as well. This PR conditionally casts inputs to f32 if they're not and then casts the output back to the source's datatype. Fixes #82645. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88542 Approved by: https://github.com/kulinseth
Author
Committer
Parents
Loading