Add broadcast_shapes() function and use it in MultivariateNormal (#43935)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/43837
This adds a `torch.broadcast_shapes()` function similar to Pyro's [broadcast_shape()](https://github.com/pyro-ppl/pyro/blob/7c2c22c10dffda8a33ffbd593cc8d58819959e40/pyro/distributions/util.py#L151) and JAX's [lax.broadcast_shapes()](https://jax.readthedocs.io/en/test-docs/_modules/jax/lax/lax.html). This helper is useful e.g. in multivariate distributions that are parameterized by multiple tensors and we want to `torch.broadcast_tensors()` but the parameter tensors have different "event shape" (e.g. mean vectors and covariance matrices). This helper is already heavily used in Pyro's distribution codebase, and we would like to start using it in `torch.distributions`.
- [x] refactor `MultivariateNormal`'s expansion logic to use `torch.broadcast_shapes()`
- [x] add unit tests for `torch.broadcast_shapes()`
- [x] add docs
cc neerajprad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43935
Reviewed By: bdhirsh
Differential Revision: D25275213
Pulled By: neerajprad
fbshipit-source-id: 1011fdd597d0a7a4ef744ebc359bbb3c3be2aadc