Adding details to kl.py (#72845)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/72765.
- [x] Improved `NotImplementedError` verbosity.
- [x] Automate the docstring generation process
## Improved `NotImplementedError` verbosity
### Code
```python
import torch

dist = torch.distributions

torch_normal = dist.Normal(loc=0.0, scale=1.0)
torch_mixture = dist.MixtureSameFamily(
    dist.Categorical(torch.ones(5)),
    dist.Normal(torch.randn(5), torch.rand(5)),
)

# No KL is registered for (Normal, MixtureSameFamily), so this raises NotImplementedError.
dist.kl_divergence(torch_normal, torch_mixture)
```
#### Output before this PR
```text
NotImplementedError:
```
#### Output after this PR
```text
NotImplementedError: No KL(p || q) is implemented for p type Normal and q type MixtureSameFamily
```
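The idea behind the more verbose error can be sketched without PyTorch. The snippet below is a minimal illustration, not the code in `kl.py`: the `Normal` and `MixtureSameFamily` classes are empty stand-ins, and `KL_REGISTRY` and `error_message` are hypothetical names. It also uses an exact type lookup, whereas the real dispatch may resolve subclasses differently.

```python
# Hypothetical stand-ins; these are NOT the torch.distributions classes.
class Normal:
    pass


class MixtureSameFamily:
    pass


# Hypothetical registry mapping (type of p, type of q) to a KL function.
KL_REGISTRY = {(Normal, Normal): lambda p, q: 0.0}


def kl_divergence(p, q):
    """Look up a KL function by argument types, or fail with a descriptive error."""
    fn = KL_REGISTRY.get((type(p), type(q)))
    if fn is None:
        # Name both types so the caller can see which pair is missing.
        raise NotImplementedError(
            "No KL(p || q) is implemented for p type {} and q type {}".format(
                type(p).__name__, type(q).__name__
            )
        )
    return fn(p, q)


def error_message(p, q):
    """Return the error text for an unregistered pair, or None if the pair is registered."""
    try:
        kl_divergence(p, q)
    except NotImplementedError as exc:
        return str(exc)
    return None


print(error_message(Normal(), MixtureSameFamily()))
```

Putting the type names in the message lets a user tell at a glance which pair would need to be registered via `register_kl`.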
## Automate the docstring generation process
### Docstring before this PR
```python
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

.. math::

    KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

Args:
    p (Distribution): A :class:`~torch.distributions.Distribution` object.
    q (Distribution): A :class:`~torch.distributions.Distribution` object.

Returns:
    Tensor: A batch of KL divergences of shape `batch_shape`.

Raises:
    NotImplementedError: If the distribution types have not been registered via
        :meth:`register_kl`.
```
### Docstring after this PR
```python
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

.. math::

    KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

Args:
    p (Distribution): A :class:`~torch.distributions.Distribution` object.
    q (Distribution): A :class:`~torch.distributions.Distribution` object.

Returns:
    Tensor: A batch of KL divergences of shape `batch_shape`.

Raises:
    NotImplementedError: If the distribution types have not been registered via
        :meth:`register_kl`.

KL divergence is currently implemented for the following distribution pairs:

* :class:`~torch.distributions.Bernoulli` and :class:`~torch.distributions.Bernoulli`
* :class:`~torch.distributions.Bernoulli` and :class:`~torch.distributions.Poisson`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.Binomial` and :class:`~torch.distributions.Binomial`
* :class:`~torch.distributions.Categorical` and :class:`~torch.distributions.Categorical`
* :class:`~torch.distributions.Cauchy` and :class:`~torch.distributions.Cauchy`
* :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.Dirichlet` and :class:`~torch.distributions.Dirichlet`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Gumbel`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.ExponentialFamily` and :class:`~torch.distributions.ExponentialFamily`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Gumbel`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.Geometric` and :class:`~torch.distributions.Geometric`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Gumbel`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.HalfNormal` and :class:`~torch.distributions.HalfNormal`
* :class:`~torch.distributions.Independent` and :class:`~torch.distributions.Independent`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Laplace`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.LowRankMultivariateNormal` and :class:`~torch.distributions.LowRankMultivariateNormal`
* :class:`~torch.distributions.LowRankMultivariateNormal` and :class:`~torch.distributions.MultivariateNormal`
* :class:`~torch.distributions.MultivariateNormal` and :class:`~torch.distributions.LowRankMultivariateNormal`
* :class:`~torch.distributions.MultivariateNormal` and :class:`~torch.distributions.MultivariateNormal`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Gumbel`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Laplace`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.OneHotCategorical` and :class:`~torch.distributions.OneHotCategorical`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Uniform`
* :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Bernoulli`
* :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Binomial`
* :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Poisson`
* :class:`~torch.distributions.TransformedDistribution` and :class:`~torch.distributions.TransformedDistribution`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Beta`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.ContinuousBernoulli`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Exponential`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Gamma`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Gumbel`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Normal`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Pareto`
* :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Uniform`
```
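One way such a list could be generated is to walk the registry of implemented pairs at import time and append one bullet per pair to the function's docstring. The following is a rough sketch under that assumption, not the PR's actual code: the distribution classes are empty stand-ins, and `KL_REGISTRY` and `append_kl_pairs` are hypothetical names.

```python
# Hypothetical stand-ins; these are NOT the torch.distributions classes.
class Bernoulli:
    pass


class Poisson:
    pass


class Beta:
    pass


# Hypothetical registry keyed by (type of p, type of q); values are unused here.
KL_REGISTRY = {
    (Beta, Beta): None,
    (Bernoulli, Poisson): None,
    (Bernoulli, Bernoulli): None,
}


def kl_divergence(p, q):
    """Compute KL(p || q) between two distributions.\n\n"""
    raise NotImplementedError


def append_kl_pairs(func, registry):
    """Append a sorted bullet list of registered type pairs to func's docstring."""
    rows = ["KL divergence is currently implemented for the following distribution pairs:"]
    # Sort by class names so the generated list is stable across runs.
    for p, q in sorted(registry, key=lambda pair: (pair[0].__name__, pair[1].__name__)):
        rows.append(
            "* :class:`~torch.distributions.{}` and :class:`~torch.distributions.{}`".format(
                p.__name__, q.__name__
            )
        )
    if func.__doc__:  # __doc__ is None when Python runs with -OO
        func.__doc__ += "\n".join(rows)


append_kl_pairs(kl_divergence, KL_REGISTRY)
print(kl_divergence.__doc__)
```

Generating the list from the registry keeps the documentation in step with the registered pairs, so a newly registered pair shows up without a manual docstring edit.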
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72845
Reviewed By: mikaylagawarecki
Differential Revision: D34344551
Pulled By: soulitzer
fbshipit-source-id: 7a603613a2f56f71138d56399c7c521e2238e8c5
(cherry picked from commit 6b2a51c796cd8a16551d629ca368360eec34faef)