`matrix_exp`: Make sure `_compute_linear_combinations` result preserves dim of the input. (#81330)
Fixes https://github.com/pytorch/pytorch/issues/80948.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81330
Approved by: https://github.com/Lezcano, https://github.com/mruberry