Refactor gradcheck (#53857)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53857
This PR factors much of the logic out of the main gradcheck function into individual helper functions. It aims to avoid any behavior change (though we may not have enough tests to actually verify this). Refactorings that lead to any behavior change are done in the next PR in this stack.
The rationale for this change is 1) to make the main gradcheck function easier to read, and 2) to allow us to reuse the same pieces when we add the fast gradcheck.
This PR may also be a good place to add some tests for gradcheck itself, i.e., tests that gradcheck fails when it should fail, to confirm that we are indeed not changing any logic. These will also help us make sure our fast_gradcheck performs all the necessary checks.
The existing tests so far are:
- test_gradcheck_fail_when_no_differentiable_outputs_and_num_grad_not_zero (test_autograd)
- test_gradcheck_single_input (test_autograd)
- test_gradcheck_sparse_input (test_autograd)
- test_gradcheck_nondeterministic (test_autograd)
- test_gradcheck (test_overrides)
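As a sketch of what such a test looks like (this is not one of the tests listed above; `BadSin` is a hypothetical op invented here purely for illustration), one can check that gradcheck passes for a correct op and returns False for an op with a deliberately wrong backward:

```python
import torch
from torch.autograd import gradcheck

# A correct op: gradcheck compares the analytical and numerical Jacobians and passes.
x = torch.tensor([0.1, 0.7, 1.3, 2.0], dtype=torch.double, requires_grad=True)
ok = gradcheck(torch.sin, (x,))

# A hypothetical op with a deliberately broken backward, so gradcheck should fail.
class BadSin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, t):
        ctx.save_for_backward(t)
        return t.sin()

    @staticmethod
    def backward(ctx, grad_out):
        (t,) = ctx.saved_tensors
        return grad_out * t.sin()  # bug: the derivative of sin is cos, not sin

# With raise_exception=False, gradcheck returns False instead of raising.
bad = gradcheck(BadSin.apply, (x,), raise_exception=False)
```

Each check we want covered would get a variant of this pattern, once with raise_exception=True (expecting a raise) and once with raise_exception=False (expecting a False return).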
Full coverage would potentially require adding the following missing tests (each tested for both raise_exception=True and raise_exception=False). The methodology for the list below is that for every type of error message we emit, we make sure we can hit it:
- complex:
  - when numerical != analytical when tested with an imaginary grad_out
- check_inputs:
  - ~when inputs are not dense, but check_sparse_nnz is false~
  - ~when none of the inputs require grad~
  - ~(warning) when inputs are not double precision~
  - ~when the layout is not mkldnn (i.e., has strides) and an input has a dimension with stride 0~
- check_no_differentiable_outputs:
  - ~when none of the outputs are differentiable, but the numerical gradient is not zero~
- check_outputs:
  - ~when outputs are sparse (always raise)~
  - ~when outputs are mkldnn (always raise)~
- test_batched_grad:
  - ~when a runtime error is encountered while computing batched grad (print big message)~
  - when not allclose (print big message)
- test_backward_mul_by_grad_output:
  - ~when the layout of grad_input is not the same as the input~
  - ~when grad_input is sparse and has incorrect sparse_dim/dense_dim~
  - ~when backward is not multiplied by grad_output (sparse/non-sparse case)~
  - when grad is incorrect type/size
- test_undefined_grad:
  - ~when a runtime error is encountered while running backward~
  - when backward completes but grad inputs (the output of .grad()) are not None
- check_analytical_jacobian_attributes (for both complex/non-complex):
  - when grad input is incorrect dtype/size
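Two of the check_inputs paths above can be exercised as follows. This is a hedged sketch: the exact exception type (ValueError) and the warning wording ("double precision") are assumptions based on current gradcheck behavior and may need adjusting:

```python
import warnings
import torch
from torch.autograd import gradcheck

# Path 1: none of the inputs require grad -- gradcheck raises a ValueError.
x = torch.randn(3, dtype=torch.double)  # requires_grad defaults to False
raised = False
try:
    gradcheck(torch.sin, (x,))
except ValueError:
    raised = True

# Path 2: single-precision input -- gradcheck warns that inputs are not
# double precision (the numeric check itself may or may not pass, so we
# use raise_exception=False and only look at the warning).
y = torch.randn(3, dtype=torch.float32, requires_grad=True)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    gradcheck(torch.sin, (y,), raise_exception=False)
warned = any("double precision" in str(w.message) for w in caught)
```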
Test Plan: Imported from OSS
Reviewed By: heitorschueroff
Differential Revision: D27201571
Pulled By: soulitzer
fbshipit-source-id: 86670a91e65740d57dd6ada7c6b4512786d15962