Improve calling backward() and grad() inside vmap error messages (#42876)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42876
Previously, the error messages were uninformative. This PR adds clear
error messages for the following cases:
- user attempts to call `.backward()` inside vmap for any reason
whatsoever
- user attempts to call `autograd.grad(outputs, inputs, grad_outputs)`
where `outputs` or `inputs` are being vmapped over (i.e., they are
BatchedTensors).
The case we do support is calling autograd.grad(outputs, inputs,
grad_outputs) where `grad_outputs` is being vmapped over. This is the
case for batched gradient support (e.g., user passes in a batched
grad_output).
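A minimal sketch of the supported pattern, assuming a PyTorch build where `torch.vmap` is available: vmapping over `grad_outputs` computes a batch of vector-Jacobian products, which can be used to build a full Jacobian.

```python
import torch

# outputs (y) and inputs (x) are ordinary tensors; only grad_outputs
# is vmapped over, which is the supported batched-gradient case.
x = torch.randn(3, requires_grad=True)
y = x.sin()

def vjp(grad_output):
    # One vector-Jacobian product per (batched) grad_output.
    return torch.autograd.grad(y, x, grad_output, retain_graph=True)[0]

# Batch of one-hot grad_outputs -> rows of the Jacobian dy/dx.
jacobian = torch.vmap(vjp)(torch.eye(3))
```

Since `y = sin(x)` is elementwise, the result is the diagonal matrix `diag(cos(x))`.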
Test Plan: - new tests: `pytest test/test_vmap.py -v`
Reviewed By: ezyang
Differential Revision: D23059836
Pulled By: zou3519
fbshipit-source-id: 2fd4e3fd93f558e67e2f0941b18f0d00d8ab439f