Try to make at::cat in mm_tree_reduction operate on contiguous tensors (#18816)
Summary:
at::cat sometimes receives transposed (non-contiguous) inputs and falls onto a slow path; this change tries to hand it contiguous tensors instead. Also make the jit_premul LSTM benchmark add the bias to the whole input tensor, which avoids separate reduction kernels in the backward pass.
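As a rough illustration of why a transposed input is no longer contiguous, here is a minimal plain-Python sketch of row-major strides. It is only a model under stated assumptions, not PyTorch's implementation: the helper names `contiguous_strides`, `is_contiguous`, and `transpose` are hypothetical and exist only for this example.

```python
# Illustrative model of tensor strides (in elements). Not PyTorch code.

def contiguous_strides(shape):
    """Row-major (C-order) strides for a dense tensor of the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """True if the strides match the dense row-major layout for the shape.

    Size-1 dimensions are skipped because their stride never affects addressing.
    """
    expected = contiguous_strides(shape)
    return all(size == 1 or s == e for size, s, e in zip(shape, strides, expected))


def transpose(shape, strides):
    """Transposing a 2-D view swaps sizes and strides; the data is not moved."""
    return (shape[1], shape[0]), (strides[1], strides[0])


shape = (3, 4)
strides = contiguous_strides(shape)             # dense layout for a 3x4 tensor
t_shape, t_strides = transpose(shape, strides)  # same memory, viewed as 4x3

print(is_contiguous(shape, strides))      # dense input
print(is_contiguous(t_shape, t_strides))  # transposed view
```

In PyTorch itself, the corresponding checks are `Tensor.is_contiguous()` and `Tensor.contiguous()`; the latter copies into a dense layout only when needed.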
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18816
Differential Revision: D15013576
Pulled By: wanchaol
fbshipit-source-id: bcfa1cf44180b11b05b0f55f034707012f66281a
Author
Natalia Gimelshein