Add batching rule for torch.expand (#40097)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40097
This is (probably) necessary for the vmap frontend API (coming up after
this PR should be the vmap frontend API).
There is some manual handling of sizes in the `expand_batching_rule`.
In particular, when performing expand(Tensor[B0, 3], [2, 3]), where B0
is a batch dimension and Tensor[B0, 3] is a batched tensor with batch
dimension B0, we can't call expand directly on the physical view and
instead first need to perform a view.
It's possible to add said view as a helper function on `VmapPhysicalView` but
after reading through the operator spreadsheet the conclusion was that
no other operator needs the same manual handling.
Test Plan: - `./build/bin/vmap_test`
Differential Revision: D22070657
Pulled By: zou3519
fbshipit-source-id: 911854b078a1a5c7d5934ef2e17b16673ed9d103