pytorch
d7dc1c2f - Support zero dimensions in softmax decompositions (#91322)

Commit
2 years ago
Support zero dimensions in softmax decompositions (#91322) The eager implementation of softmax supports computation along zero dimensions, but many of the other implementations did not, including: * decompositions & refs (this was causing dynamo failures) * forward AD for logsumexp * MPS log_softmax_backward This PR handles the `input.numel() == 0` cases separately to avoid running `amax()`, which fails for zero dimensions, and updates opinfos. example of "computation along zero dimensions": ```python # example of where import torch t = torch.rand((4, 0, 0)) print("~") print(torch.nn.functional.softmax(t, dim=-1)) # this passes print("~") torch._refs.softmax(t, dim=-1) # this fails print("~") ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/91322 Approved by: https://github.com/lezcano
Author
Committer
Parents
Loading