pytorch
24cc7fe0 - Fix Wishart distribution documentation (#95816)

Commit
1 year ago
Fix Wishart distribution documentation (#95816) This PR fixes the `torch.distributions.wishart.Wishart` example. Running the current example ```python m = Wishart(torch.eye(2), torch.Tensor([2])) m.sample() # Wishart distributed with mean=`df * I` and # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j ``` fails with ``` --------------------------------------------------------------------------- ValueError Traceback (most recent call last) Untitled-1 in [321](untitled:Untitled-1?line=320) # %% ----> [322](untitled:Untitled-1?line=321) m = Wishart(torch.eye(2), torch.Tensor([2])) [323](untitled:Untitled-1?line=322) m.sample() # Wishart distributed with mean=`df * I` and [324](untitled:Untitled-1?line=323) # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j Untitled-1 in __init__(self, df, covariance_matrix, precision_matrix, scale_tril, validate_args) [83](untitled:Untitled-1?line=82) [84](untitled:Untitled-1?line=83) if param.dim() < 2: ---> [85](untitled:Untitled-1?line=84) raise ValueError("scale_tril must be at least two-dimensional, with optional leading batch dimensions") [86](untitled:Untitled-1?line=85) [87](untitled:Untitled-1?line=86) if isinstance(df, Number): ValueError: scale_tril must be at least two-dimensional, with optional leading batch dimensions ``` Is seems that the parameters of `Wishart.__init__()` were re-ordered, but the documentation was not updated. This PR fixes it. Here is the updated behaviour: ```python m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2)) m.sample() ``` ``` Untitled-1:255: UserWarning: Singular sample detected. tensor([[[6.6366, 0.7796], [0.7796, 0.2136]]]) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/95816 Approved by: https://github.com/ngimel, https://github.com/kit1980
Author
Committer
Parents
Loading