[autograd.Function] add nice error message for incorrect usage of vmap (#92023)
This PR:
- adds a nice error message if the user doesn't follow the API of the
vmap staticmethod correctly. That is, the user must return two
arguments from the vmap staticmethod API: (outputs, out_dims), and
out_dims must be a PyTree with either the same structure as `outputs`
our be broadcastable to the same structure as `outputs`.
- Fixes an edge case for out_dims=None. out_dims is allowed to be None,
but wrap_outputs_maintaining_identity was treating "None" as "This is
not the vmap case"
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92023
Approved by: https://github.com/soulitzer