pytorch
df887936 - Fix transpose batching rule (#47628)

Commit
4 years ago
Fix transpose batching rule (#47628) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47628 Pytorch has a special case where scalar_tensor.transpose(0, 0) works and returns the scalar tensor. If the following happens: ```py >>> x = torch.randn(B0) # the per-examples are all scalars >>> vmap(lambda x: x.transpose(0, 0), x) ``` then we replicate this behavior Test Plan: - new tests Reviewed By: anjali411 Differential Revision: D24843658 Pulled By: zou3519 fbshipit-source-id: e33834122652473e34a18ca1cecf98e8a3b84bc1
Author
Parents
Loading