[TensorExpr] Make KernelSumMultipleAxes much faster (#43905)
Summary:
Reduce input size, skip the dtype conversion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43905
Test Plan: test_tensorexpr --gtest_filter=TensorExprTest.KernelSum*
Reviewed By: ailzhang
Differential Revision: D23433398
Pulled By: asuhan
fbshipit-source-id: 0d95ced3c1382f10595a9e5745bf4bef007cc913