pytorch
b7cb4eae - Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560)

Commit
2 years ago
Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560) On functorch, we started seeing [embedding forward mode fail](https://github.com/pytorch/functorch/pull/816). From looking at it, we figured out that recently [embedding got forward mode support enabled](https://github.com/pytorch/pytorch/commit/369d9f4137a8bfc20e6a4e1d6ab35eeae4e9b345) and then doing forward mode with embedding and [max_norm doesn't work with gradcheck](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L8877-L8881), so it's not checked. What was happening is that `embedding_renorm` was setting `torch.no_grad()` which only turns off the backwards mode AD so functorch's jvp tests were still using forward mode AD during the `embedding_renorm` call. This makes it so that we don't use forward mode during the embedding_renorm call Pull Request resolved: https://github.com/pytorch/pytorch/pull/78560 Approved by: https://github.com/soulitzer, https://github.com/albanD
Author
samdow
Committer
Parents
Loading