Implement `MultiBatchVmapTransform::logicalToPhysical(TensorList)` (#41942)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41942
This function:
- permutes all batch dims to the front of the tensors
- aligns the batch dims to the collective set of vmap levels across all the tensors
- expands all of the batch dims such that they are present in each of
the result tensors
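As an illustrative sketch of the semantics (not the actual C++ implementation; the `align_and_expand` helper and the per-tensor `bdims` mapping of vmap level to dim index are hypothetical), the three steps above could be written with plain torch ops:

```python
import torch

def align_and_expand(tensors, bdims_per_tensor):
    # bdims_per_tensor: for each tensor, a dict mapping vmap level -> dim index.
    # Hypothetical helper mimicking MultiBatchVmapTransform::logicalToPhysical.
    all_levels = sorted(set().union(*(set(b) for b in bdims_per_tensor)))

    # Record the size of each level from whichever tensor carries it.
    level_sizes = {}
    for t, bdims in zip(tensors, bdims_per_tensor):
        for lvl, dim in bdims.items():
            level_sizes[lvl] = t.size(dim)

    results = []
    for t, bdims in zip(tensors, bdims_per_tensor):
        # 1) Permute this tensor's batch dims to the front, in level order.
        front = [bdims[lvl] for lvl in sorted(bdims)]
        rest = [d for d in range(t.dim()) if d not in front]
        t = t.permute(front + rest)
        # 2) Insert size-1 dims for levels this tensor lacks (aligning levels).
        for i, lvl in enumerate(all_levels):
            if lvl not in bdims:
                t = t.unsqueeze(i)
        # 3) Expand so every result carries every batch dim.
        shape = [level_sizes[lvl] for lvl in all_levels] + list(t.shape[len(all_levels):])
        results.append(t.expand(shape))
    return results
```

For example, a tensor of shape `(3, 5)` batched at level 1 (dim 0) and a tensor of shape `(5, 4)` batched at level 2 (dim 1) would both come out with shape `(3, 4, 5)`.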
This function is useful for the next diff up on the stack (which is
implementing a fallback kernel for BatchedTensor). It's also useful in
general for implementing batching rules for operators whose inputs each
carry multiple batch dimensions at the front (though we don't have many
of those in PyTorch).
Test Plan: - `./build/bin/vmap_test`
Reviewed By: ezyang
Differential Revision: D22764104
Pulled By: zou3519
fbshipit-source-id: d42cc8824a1bcf258687de164b7853af52852f53