pytorch
aa828bf0 - Support undefined grads in vmap fallback (#46671)

Commit
4 years ago
Support undefined grads in vmap fallback (#46671) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46671 Previously, the vmap fallback would choke whenever it saw an undefined tensor. For each sample in a batch, the fallback runs an operator and then stacks together outputs to get the actual output. Undefined tensors can occur as outputs while computing batched gradients with vmap. This PR updates the vmap fallback to handle undefined tensors which can appear in backward formulas: - if for each sample in a batch the output was undefined, then the vmap fallback returns an undefined tensor - if for each sample in a batch the output is defined, then the vmap fallback stacks together the defined tensors - if for some samples in a batch the output is defined/undefined, then we error out. Test Plan: - new tests Reviewed By: ezyang Differential Revision: D24454909 Pulled By: zou3519 fbshipit-source-id: d225382fd17881f23c9833323b68834cfef351f3
Author
Parents
Loading