Fix transpose batching rule (#47628)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47628
PyTorch has a special case where `scalar_tensor.transpose(0, 0)` works and
returns the scalar tensor. If the following happens:
```py
>>> x = torch.randn(B0)  # each per-example is a scalar
>>> vmap(lambda x: x.transpose(0, 0))(x)
```
then we replicate this behavior: each 0-dim per-example is returned unchanged.
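A minimal sketch of the behavior being preserved (assuming a PyTorch build where `torch.vmap` is available; at the time of this PR the vmap API was still a prototype):

```python
import torch

# Special case in eager mode: transposing a 0-dim (scalar) tensor
# with transpose(0, 0) is a no-op that returns the scalar tensor.
s = torch.randn(())
assert s.transpose(0, 0).dim() == 0

# Under vmap, each per-example of a 1-D batched input is a scalar,
# so the transpose batching rule must replicate the same no-op.
B0 = 3
x = torch.randn(B0)
out = torch.vmap(lambda t: t.transpose(0, 0))(x)
assert out.shape == (B0,)
assert torch.equal(out, x)
```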
Test Plan: - new tests
Reviewed By: anjali411
Differential Revision: D24843658
Pulled By: zou3519
fbshipit-source-id: e33834122652473e34a18ca1cecf98e8a3b84bc1