Support undefined grads in vmap fallback (#46671)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46671
Previously, the vmap fallback failed whenever it encountered an undefined
tensor. For each sample in a batch, the fallback runs the operator
and then stacks the per-sample outputs to produce the batched output.
Undefined tensors can occur as outputs while computing batched gradients
with vmap.
This PR updates the vmap fallback to handle undefined tensors, which can
appear in backward formulas:
- if the output is undefined for every sample in a batch, the vmap
fallback returns an undefined tensor
- if the output is defined for every sample in a batch, the vmap
fallback stacks the defined tensors together
- if the output is defined for some samples in a batch and undefined for
others, we error out.
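
The three cases can be sketched in plain Python as a rough illustration.
This is not the fallback's actual code: `stack_sample_outputs` is a
hypothetical name, `None` stands in for an undefined tensor, and nested
lists stand in for tensors.

```python
from typing import List, Optional, Sequence


def stack_sample_outputs(
    outputs: Sequence[Optional[List[float]]],
) -> Optional[List[List[float]]]:
    """Combine per-sample outputs, where None models an undefined tensor.

    Hypothetical helper for illustration only; the names and types here
    are assumptions, not PyTorch APIs.
    """
    defined = [out is not None for out in outputs]

    # Case 1: every per-sample output is undefined -> undefined result.
    if not any(defined):
        return None

    # Case 2: every per-sample output is defined -> stack along a new
    # leading "batch" dimension (modeled here as an outer list).
    if all(defined):
        return [list(out) for out in outputs]

    # Case 3: a mix of defined and undefined outputs is an error.
    raise ValueError(
        "expected per-sample outputs to be all defined or all undefined"
    )
```

For example, `stack_sample_outputs([None, None])` returns `None`, while
`stack_sample_outputs([[1.0], None])` raises `ValueError`.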
Test Plan: new tests
Reviewed By: ezyang
Differential Revision: D24454909
Pulled By: zou3519
fbshipit-source-id: d225382fd17881f23c9833323b68834cfef351f3