Refactor VmapPhysicalView::newLogicalToPhysical (#49482)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49482
Motivation
==========
Batching rules always invoke newLogicalToPhysical at the very end to turn
a physical tensor into a logical BatchedTensor (an example is below):
```
Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
grad_input.select(physical_dim, index).copy_(grad_physical.tensor());
return grad_physical.newLogicalFromPhysical(grad_input);
}
```
However, albanD noted that this function is confusing and ambiguous
because it's unclear which physical tensor is being turned into the logical
(in this case, grad_physical is a VmapPhysicalView, but we're really transforming
grad_input and returning it).
https://github.com/pytorch/pytorch/pull/44505#discussion_r487144018
I didn't want to make too many changes to the batching rule API because
I think we'll change it even more in the future, but this PR attempts to
remove the ambiguity by applying one of the suggestions in
https://github.com/pytorch/pytorch/pull/44505#discussion_r487144018
This PR
=======
The diagnosis of the problem is that we were conflating
"VmapPhysicalView", which maps logical attributes on a Tensor (like
dimension and shape) to physical attributes, with the reverse
physical-to-logical map. This PR creates a new VmapPhysicalToLogicalMap
object that handles the latter.
Instead of calling `grad_physical.newLogicalFromPhysical(grad_input)`,
an author of batching rules should now retrieve the VmapPhysicalToLogicalMap
object and apply it to their physical input. So the above code becomes:
```
grad_physical.getPhysicalToLogicalMap().apply(grad_input)
```
I've also moved VmapPhysicalView::makeLogicalFromPhysicalListInplace
to VmapPhysicalToLogicalMap::applyInplace.
Test Plan
=========
wait for tests
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D25592645
Pulled By: zou3519
fbshipit-source-id: 9c6ede9901ec6b70e5763193064658a8f91e6d48