Add batched grad testing to OpInfo (#50818)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50818
This PR does two things:
1. Add batched grad testing to OpInfo
2. Improve the error message from `gradcheck` if batched gradient
computation fails to include suggestions for workarounds.
To add batched grad testing to OpInfo, this PR:
- adds new `check_batched_grad=True` and `check_batched_gradgrad=True`
attributes to OpInfo. These are True by default because we expect most
operators to support batched gradient computation.
- If `check_batched_grad=True`, then `test_fn_grad` invokes gradcheck
with `check_batched_grad=True`.
- If `check_batched_gradgrad=True`, then `test_fn_gradgradgrad` invokes
gradgradcheck with `check_batched_grad=True`.
The improved gradcheck error message looks like the following when an
exception is thrown while computing batched gradients:
https://gist.github.com/zou3519/5a0f46f908ba036259ca5e3752fd642f
Future
- Sometime in the not-near future, we will separate out "batched grad
testing" from "gradcheck" for the purposes of OpInfo to make the
testing more granular and also so that we can test that the vmap
fallback doesn't get invoked (currently batched gradient testing only
tests that the output values are correct).
Test Plan: - run tests `pytest test/test_ops.py -v -k "Gradients"`
Reviewed By: ejguan
Differential Revision: D25997703
Pulled By: zou3519
fbshipit-source-id: 6d2d444d6348ae6cdc24c32c6c0622bd67b9eb7b