pytorch
1c30844e - where() function added as a Tensor method as well (#92849)

Commit
1 year ago
where() function added as a Tensor method as well (#92849) Fixes #88470 I added the "method" keyword in `aten/src/ATen/native/native_functions.yaml` for the function `where` with Scalar Overload. This way, you can now use `Tensor.where()` with a scalar parameter the same way `torch.where()` can. I added a test in `test/test_torch.py` as requested. It uses the `where()` method on a tensor and then checks it has the same results as the `torch.where()` function. The test is roughly the same as the one provided by the author of the issue. PS: this is the second PR I make to resolve this issue, the first one is #92747. I had troubles with commit signatures and is therefore closed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/92849 Approved by: https://github.com/albanD
Author
Committer
Parents
Loading