pytorch
56b4b441 - Batching rule for torch.mul (#39859)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39859

This PR implements the batching rule for `torch.mul` for the (Tensor, Tensor) overload. NB: ~250 lines of this PR are tests, so please don't be scared away by the line count.

It introduces the BroadcastingVmapTransform, which is the VmapTransform one should use for operations that broadcast their inputs. This transform:
- permutes all batch dimensions to the front of the tensors
- aligns the batch dimensions of the tensors, adding extra 1's where necessary
- aligns the non-batch dims of the tensors, adding extra 1's where necessary

Test Plan:
- Test BroadcastingVmapTransform in `./build/bin/vmap_test`.
- Test mul_batching_rule in `./build/bin/vmap_test`.

Differential Revision: D22067337

Pulled By: zou3519

fbshipit-source-id: 5862da8c2b28699b08c7884342a1621581cb2e7f
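The sketch below illustrates in Python the alignment that BroadcastingVmapTransform describes: move each batch dim to the front, then pad with size-1 dims so the per-example dims line up for ordinary broadcasting. This is only an illustration of the idea, not the C++ implementation in this commit, and the helper name `align_for_broadcast` is hypothetical. It assumes a recent PyTorch install.

```python
import torch

def align_for_broadcast(x, x_bdim, y, y_bdim):
    """Move each tensor's batch dim to the front, then insert size-1 dims
    right after it so the non-batch (per-example) dims line up for
    ordinary broadcasting. Hypothetical helper, for illustration only."""
    x = x.movedim(x_bdim, 0)
    y = y.movedim(y_bdim, 0)
    # Logical (per-example) rank of the broadcasted result.
    logical_rank = max(x.dim(), y.dim()) - 1
    while x.dim() - 1 < logical_rank:
        x = x.unsqueeze(1)
    while y.dim() - 1 < logical_rank:
        y = y.unsqueeze(1)
    return x, y

# Batched multiply with batch dims (size 2) in different positions.
x = torch.randn(3, 5, 2)   # batch dim at index 2; per-example shape (3, 5)
y = torch.randn(5, 2)      # batch dim at index 1; per-example shape (5,)
xa, ya = align_for_broadcast(x, 2, y, 1)   # shapes (2, 3, 5) and (2, 1, 5)
out = xa * ya                              # plain broadcasting does the rest
print(out.shape)                           # torch.Size([2, 3, 5])
```

In current PyTorch releases the same batching is exercised through the public `torch.vmap` API, e.g. `torch.vmap(torch.mul, in_dims=(2, 1))(x, y)` produces the same `(2, 3, 5)` result, though that Python-facing API postdates this commit.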