vmap fallback: gracefully error out when vmap over dim of size 0 (#46846)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46846
Previously, this would crash with a floating point error. If the user vmaps
over a dimension of size 0, ideally we would return a tensor with a
batch dim of size 0 and the correct output shape. However, this isn't
possible without a shape-checking API. This PR changes the vmap fallback
to error out gracefully if it sees vmap occuring over a dimension of
size 0.
If we want to support vmapping over dimension of size 0 for a specific
op, then the guidance is to implement a batching rule for that op that
handles 0-sized dims.
Test Plan: - new test
Reviewed By: ezyang
Differential Revision: D24539315
Pulled By: zou3519
fbshipit-source-id: a19c049b46512d77c084cfee145720de8971f658