pytorch
2af09393 - `masked_scatter` should accept only bool masks (#97999)

Commit
2 years ago
`masked_scatter` should accept only bool masks (#97999) Modify test_torch to check that assert is raised in this case torch.uint8 usage has been deprecated for a few releases, and errors has been raised for other dtypes on CUDA device, but not on CPU. This PR finally restricts mask to just `torch.bool` See https://github.com/pytorch/pytorch/pull/96594 as an example doing it for `torch.masked_fill` Fixes https://github.com/pytorch/pytorch/issues/94634 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97999 Approved by: https://github.com/ngimel
Author
Committer
Parents
Loading