Implement batching rules for basic arithmetic ops (#43362)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43362
Batching rules implemented for: addition subtraction division
multiplication.
I refactored the original `mul_batching_rule` into a templated function
so that one can insert arbitrary binary operations into it.
add, sub, rsub, mul, and div all work the same way. However, other
binary operations work slightly differently (I'm still figuring out the
differences and why they're different) so those may need a different
implementation.
Test Plan: - "pytest test/test_vmap.py -v": new tests
Reviewed By: ezyang
Differential Revision: D23252317
Pulled By: zou3519
fbshipit-source-id: 6d36cd837a006a2fd31474469323463c1bd797fc