Add batched grad testing to OpInfo (#50818)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50818

This PR does two things:
1. Add batched grad testing to OpInfo.
2. Improve the error message from `gradcheck` when batched gradient computation fails, so that it suggests workarounds.

To add batched grad testing to OpInfo, this PR:
- adds new `check_batched_grad=True` and `check_batched_gradgrad=True` attributes to OpInfo. These are True by default because we expect most operators to support batched gradient computation.
- If `check_batched_grad=True`, then `test_fn_grad` invokes gradcheck with `check_batched_grad=True`.
- If `check_batched_gradgrad=True`, then `test_fn_gradgrad` invokes gradgradcheck with `check_batched_grad=True`.

The improved gradcheck error message looks like the following when an exception is thrown while computing batched gradients: https://gist.github.com/zou3519/5a0f46f908ba036259ca5e3752fd642f

Future:
- Sometime in the not-near future, we will separate "batched grad testing" out of "gradcheck" for OpInfo purposes, both to make the testing more granular and so that we can test that the vmap fallback does not get invoked (currently, batched gradient testing only checks that the output values are correct).

Test Plan:
- Run tests: `pytest test/test_ops.py -v -k "Gradients"`

Reviewed By: ejguan

Differential Revision: D25997703

Pulled By: zou3519

fbshipit-source-id: 6d2d444d6348ae6cdc24c32c6c0622bd67b9eb7b
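For context, here is a minimal sketch of how these OpInfo flags map onto `gradcheck`/`gradgradcheck` calls. This is not the actual OpInfo test harness; `run_grad_tests` and the sample input are hypothetical, but `check_batched_grad` is a real keyword argument of both `torch.autograd.gradcheck` and `torch.autograd.gradgradcheck`:

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

def run_grad_tests(op, sample, check_batched_grad=True,
                   check_batched_gradgrad=True):
    # Hypothetical helper mirroring what the tests do with the OpInfo flags.
    # Like test_fn_grad: first-order gradcheck; batched (vmap-based)
    # gradients are additionally verified when the flag is True.
    assert gradcheck(op, (sample,), check_batched_grad=check_batched_grad)
    # Like test_fn_gradgrad: second-order check with the same batched flag.
    assert gradgradcheck(op, (sample,),
                         check_batched_grad=check_batched_gradgrad)

# gradcheck needs double precision for reliable numerical comparisons.
x = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
run_grad_tests(torch.sin, x)                            # op supports batched grads
run_grad_tests(torch.sin, x, check_batched_grad=False)  # opt-out: plain gradcheck only
```

An op that does not yet support batched gradients would set `check_batched_grad=False` (and/or `check_batched_gradgrad=False`) on its OpInfo entry rather than failing the test.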