pytorch
98a9235d - Fix prelu ref when a.ndim < 2 (#89809)

Commit
2 years ago
Fix prelu ref when a.ndim < 2 (#89809) Fixes https://github.com/pytorch/pytorch/issues/89560 Previously the test case for "input is 1-D or scalar + weight is not scalar" did not exist; adding it introduced some failures: - forward AD (fixed in this PR) - vmap (filed https://github.com/pytorch/pytorch/issues/89895) - ref/meta (fixed this PR, though this also regresses nvFuser support) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89809 Approved by: https://github.com/ngimel
Author
Committer
Parents
Loading