Add some batched gradient tests (#44494)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44494
These tests check (most of the) operations that are useful for Bayesian
logistic regression (BLR) models. Said operators are essentially those
found in the log_prob functions of Distributions objects. This PR is not
a general, structured solution for testing batched gradients (see
"Alternative solution" for that), but I wanted to test a small subset of
operations to confirm that the BLR use case works.
There will be follow-up PRs implementing support for some missing
operations for the BLR use case.
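The actual tests live in test/test_vmap.py, but the pattern they follow
can be sketched without any framework: a batched gradient is considered
correct if computing gradients for the whole batch at once matches
computing the gradient of each sample independently in a loop. The names
below (`f`, `grad_f`, `batched_grad`, `check_batched_grad`) are
hypothetical stand-ins, not identifiers from this PR:

```python
def f(x):
    # Scalar function standing in for a single log_prob-style term.
    return x * x

def grad_f(x):
    # Analytic per-sample gradient of f (d/dx of x**2 is 2x).
    return 2.0 * x

def batched_grad(xs):
    # "Batched" gradient: one pass over the whole batch at once.
    # In the real tests this role is played by a vmap'd backward pass.
    return [2.0 * x for x in xs]

def check_batched_grad(xs, atol=1e-9):
    # Reference values: each sample's gradient computed independently,
    # in a plain Python loop.
    expected = [grad_f(x) for x in xs]
    actual = batched_grad(xs)
    assert len(actual) == len(expected)
    for a, e in zip(actual, expected):
        assert abs(a - e) <= atol, (a, e)
    return True
```

For example, `check_batched_grad([0.5, -1.0, 3.0])` passes because the
batched result agrees elementwise with the per-sample loop.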
Alternative solution
====================
Ideally, and in the future, I want to autogenerate tests from
common_method_invocations and delete all of the manual tests
introduced by this PR. However, if we were to do this now,
we would need to store the following additional metadata somewhere:
- operator name
- supports_batched_grad
- allow_vmap_fallback_usage
We could store that metadata as a separate table from
common_method_invocations, or add two columns to
common_method_invocations. Either way, that seems like a lot of work, and
the situation will get better once vmap supports batched gradients for
all operators (on the fallback path).
I am neutral between pursuing the alternative approach now vs. just
manually writing out some tests for these operations, so I picked the
easier approach. Please let me know if you think it would be better to
pursue the alternative approach now.
Test Plan: - `pytest test/test_vmap.py -v -k "BatchedGrad"`
Reviewed By: anjali411
Differential Revision: D23650408
Pulled By: zou3519
fbshipit-source-id: 2f26c7ad4655318a020bdaab5c767cd3956ea5eb