Fix `Dirichlet.log_prob()` when x=0 and alpha=1 (#103605)
`Dirichlet.log_prob()` incorrectly returns NaN in the case where $x_i=0$ and $\alpha_i=1$. The Dirichlet PDF is given by:
$$\frac{1}{B(\alpha)} \prod_{i=1}^{K} x_i^{\alpha_i - 1}$$
So this corresponds to the case where one of the terms has the form $0^0=1$. The logarithm of such a term should be 0, but you get NaN if you try to calculate it as `0 * log(0)`.
This PR implements the same algorithm that `scipy.stats.dirichlet` uses to avoid this behavior, namely `xlogy(alpha - 1, x)` instead of `(alpha - 1) * log(x)`. It also adds a test case comparing the pytorch and scipy implementations for this specific case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103605
Approved by: https://github.com/albanD