vmap: added fallback for in-place operators (#46191)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46191
This PR adds a fallback for in-place operators to vmap. We define an
in-place operator to be an operator that operators in-place on its first
argument and returns the first argument.
The "iteration over batch" logic is mostly copied from the out-of-place
vmap fallback. I wanted to try to not copy this but the iteration logic
is pretty entangled with the rest of the logic; one alternative was to
use if/else statements inside batchedTensorForLoopFallback but then
there are ~3-4 different sites where we would need that.
When in-place operations are not possible
=========================================
Sometimes, an in-place operation inside of vmap is not possible. For
example, `vmap(Tensor.add_, (None, 0))(torch.rand(3), torch.rand(B0, 3))`
is not possible because the tensor being written to in-place has size
[3] and the other tensor has size [B0, 3].
We detect if this is the case and error out inside the in-place
fallback.
Test Plan
=========
Added some new tests to `test_vmap.py`.
Test Plan: Imported from OSS
Reviewed By: malfet
Differential Revision: D24335240
Pulled By: zou3519
fbshipit-source-id: 1f60346059040dc226f0aeb80a64d9458208fd3e