vmap fallback kernel (#41943)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41943
If an operator doesn't have a batching rule implemented, we fall back
to this implementation. The fallback only works on out-of-place operators
that return only tensors backed by new memory (i.e., no in-place operators
and no view operations).
The fallback effectively takes all of the BatchedTensors in `stack`,
slices them along their batch dimensions, and runs `op` on each set of
corresponding slices to produce slices of the outputs. The output slices
are then `torch.stack`ed to produce the final returns.
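The per-slice loop described above can be sketched as follows. This is a simplified illustration of what the fallback computes (the hypothetical `fallback` helper and single shared batch dim are assumptions for the sketch; the real kernel works through the dispatcher and handles per-input batch dims):

```python
import torch

def fallback(op, batched_inputs, bdim=0):
    # Slice each batched input along the batch dim, run `op` on each
    # set of corresponding slices, then torch.stack the slice outputs.
    slices = zip(*(t.unbind(bdim) for t in batched_inputs))
    return torch.stack([op(*s) for s in slices], dim=bdim)

x = torch.randn(3, 4)
y = torch.randn(3, 4)
out = fallback(torch.add, [x, y])
# Equivalent to the vectorized op applied across the batch dim.
assert torch.allclose(out, x + y)
```

The extra copy mentioned below comes from the final `torch.stack`, which materializes a new contiguous output tensor.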
The performance of the fallback is not very good because stacking the
sliced outputs introduces an extra copy. Because of this, we prefer
to write batching rules for operators whenever possible.
In the future, I'd like to disable the fallback kernel for random
functions until vmap has a better story for randomness. I will probably
add a blocklist of operators to support that.
Test Plan: - `pytest test/test_vmap.py -v`
Reviewed By: ezyang
Differential Revision: D22764103
Pulled By: zou3519
fbshipit-source-id: b235833f7f27e11fb76a8513357ac3ca286a638b