Batching rule for torch.mul (#39859)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39859
This PR implements the batching rule for `torch.mul` for the (Tensor,
Tensor) overload.
NB: ~250 lines of this PR are tests, so please don't be scared away by
the line count.
It introduces the BroadcastingVmapTransform, which is the VmapTransform
one should use for operations that broadcast their inputs. This
transform:
- permutes all batch dimensions to the front of the tensors,
- aligns the batch dimensions of the tensors, inserting extra size-1
  dimensions where necessary, and
- aligns the non-batch dims of the tensors, inserting extra size-1
  dimensions where necessary.
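The steps above can be sketched in NumPy for illustration (the real
implementation is C++ inside PyTorch's vmap code; the function name,
signature, and single-batch-dim assumption here are hypothetical):

```python
import numpy as np

def broadcasting_vmap_transform(x, x_bdim, y, y_bdim):
    """Hypothetical sketch of the BroadcastingVmapTransform steps,
    assuming each tensor carries exactly one batch dimension."""
    # Step 1: permute each tensor's batch dimension to the front.
    x = np.moveaxis(x, x_bdim, 0)
    y = np.moveaxis(y, y_bdim, 0)
    # Steps 2-3: align ranks by inserting size-1 dimensions right after
    # the batch dim so the non-batch dims line up for broadcasting.
    max_rank = max(x.ndim, y.ndim)
    while x.ndim < max_rank:
        x = np.expand_dims(x, 1)
    while y.ndim < max_rank:
        y = np.expand_dims(y, 1)
    return x, y

# After the transform, a plain broadcasted multiply computes the
# per-example results, matching an explicit loop over the batch dim:
x = np.arange(5.0)                  # shape (5,), batch dim 0
y = np.arange(15.0).reshape(5, 3)   # shape (5, 3), batch dim 0
xb, yb = broadcasting_vmap_transform(x, 0, y, 0)
out = xb * yb                       # shape (5, 3)
```

The point of aligning shapes up front is that the batching rule for
`mul` then reduces to a single ordinary broadcasted multiply.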
Test Plan:
- Test BroadcastingVmapTransform in `./build/bin/vmap_test`.
- Test mul_batching_rule in `./build/bin/vmap_test`.
Differential Revision: D22067337
Pulled By: zou3519
fbshipit-source-id: 5862da8c2b28699b08c7884342a1621581cb2e7f