pytorch
86ae14de - [MPS] Fix MPSGraph casting issue to MPSDataTypeBool in masked_fill op (#94263)

Commit
1 year ago
[MPS] Fix MPSGraph casting issue to MPSDataTypeBool in masked_fill op (#94263) Fixes TestConsistency masked_fill for bool data type. Casting a tensor > 1 to MPSDataTypeBool will result in 0 instead of 1. This change manually casts the scalar to a value of 0 or 1 when casting a non-boolean tensor to a boolean tensor: ``` (inputDataType == MPSDataTypeBool) ? !!value.to<double>() : value.to<double>() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/94263 Approved by: https://github.com/razarmehr
Author
Committer
Parents
Loading