pytorch
76324840 - Add some batched gradient tests (#44494)

Commit
5 years ago
Add some batched gradient tests (#44494) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44494 These tests check (most) operations that are useful for bayesian logistic regression (BLR) models. Said operators are basically those found in the log_prob functions of Distributions objects. This PR is not a general, structured solution for testing batched gradients (see "Alternative solution" for that), but I wanted to test a small subset of operations to confirm that the BLR use case works. There will be follow-up PRs implementing support for some missing operations for the BLR use case. Alternative solution ===================== Ideally, and in the future, I want to autogenerate tests from common_method_invocations and delete all of the manual tests introduced by this PR. However, if we were to do this now, we would need to store the following additional metadata somewhere: - operator name, supports_batched_grad, allow_vmap_fallback_usage We could store that metadata as a separate table from common_method_invocations, or add two columns to common_method_invocations. Either way that seems like a lot of work and the situation will get better once vmap supports batched gradients for all operators (on the fallback path). I am neutral between performing the alternative approach now v.s. just manually writing out some tests for these operations, so I picked the easier approach. Please let me know if you think it would be better to pursue the alternative approach now. Test Plan: - `pytest test/test_vmap.py -v -k "BatchedGrad"` Reviewed By: anjali411 Differential Revision: D23650408 Pulled By: zou3519 fbshipit-source-id: 2f26c7ad4655318a020bdaab5c767cd3956ea5eb
Author
Parents
Loading