Vectorize LowerCholeskyTransform (#24131)
Summary:
Removes older `torch.stack`-based logic in favor of `torch.diagonal()` and `torch.diag_embed()`.
I see 100x speedup in my application, where my batched matrix has shape `(800, 32 ,32)`.
```py
import torch
from torch.distributions import constraints, transform_to
x = torch.randn(800, 32, 32, requires_grad=True)
# Before this PR:
%%timeit
transform_to(constraints.lower_cholesky)(x).sum().backward()
# 579 ms ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# After this PR:
%%timeit
transform_to(constraints.lower_cholesky)(x).sum().backward()
# 4.5 ms ± 241 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24131
Differential Revision: D16764035
Pulled By: ezyang
fbshipit-source-id: 170cdb0d924cdc94cd5ad3b75d1427404718d437