pytorch
f5af97ef - [autograd.Function] add nice error message for incorrect usage of vmap (#92023)

Commit
2 years ago
[autograd.Function] add nice error message for incorrect usage of vmap (#92023) This PR: - adds a nice error message if the user doesn't follow the API of the vmap staticmethod correctly. That is, the user must return two arguments from the vmap staticmethod API: (outputs, out_dims), and out_dims must be a PyTree with either the same structure as `outputs` our be broadcastable to the same structure as `outputs`. - Fixes an edge case for out_dims=None. out_dims is allowed to be None, but wrap_outputs_maintaining_identity was treating "None" as "This is not the vmap case" Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/92023 Approved by: https://github.com/soulitzer
Author
Committer
Parents
Loading