optimize MulGradient for common shapes (#19705)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19705
Optimizing for a case when there's a consecutive dims that are not broadcasted followed by another consecutive dims that are broadcasted.
For example, MulGradient(["dC", "A", "B"], ["dA", "dB"], broadcast=True, axis=0) where A.shape == dC.shape == [9508, 80] and B.shape == [80] .
Test Plan:
In SKL T6,
Running mul_gradient_benchmark without this optimization
Operator #0 (dA, MulGradient) 11.9119 ms/iter
After this optimization,
Operator #0 (dA, MulGradient) 0.672759 ms/iter
Need to land D15291800 before to fix the unit test error
Reviewed By: dmudiger
Differential Revision: D15075415
fbshipit-source-id: 0f97be17cf8f1dacbafa34cd637fb8bc1c5e5387