pytorch
2b9d9bcb - Deprecate non-bool masks in masked_fill (#96594)

Commit
1 year ago
Deprecate non-bool masks in masked_fill (#96594) __What?__ Per discussion at #94634, deprecate `masked_fill` with non-bool masks. Deprecation warnings were previously added by #22261, but not for Apple MPS. I can revert the MPS changes if deprecation warnings are wanted first tho. See also #96112. Fixes #85063 and #89320. __Further Development?__ - Fixed the mask dtype checking for the cuda dispatch for `masked_fill` in `aten/src/ATen/native/cuda/Indexing.cu` Pull Request resolved: https://github.com/pytorch/pytorch/pull/96594 Approved by: https://github.com/malfet, https://github.com/ngimel
Author
Committer
Parents
Loading