pytorch
12dc410f - Fix nvFuser's where(tensor, python_scalar, tensor) type promotion (#80347)

Commit
2 years ago
Fix nvFuser's where(tensor, python_scalar, tensor) type promotion (#80347) This PR modifies the type promotion logic for nvFuser's `where` function when one of the arguments is a scalar. With the proposed change behavior now matches with ATen's type promotion. The following script fails on master and passes with this PR: ```py import torch import torch._refs from torch._prims.executor import make_traced a = torch.ones(3, 3, dtype=torch.bool, device='cuda') b = torch.randn(3, 3, device='cuda') func = lambda a, b: torch._refs.where(a, 0.0, b) assert make_traced(func)(a, b, executor="nvfuser").dtype == torch.float32 ``` This PR allows to unskip nvFuser tests for `_refs.log_softmax`, it was failing with a dtype mismatch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/80347 Approved by: https://github.com/ngimel
Author
Committer
Parents
Loading