jax
080007ab - Ensure values returned by jax.random.truncated_normal() are in range.

Commit
5 years ago
Ensure values returned by jax.random.truncated_normal() are in range. A user observed -inf values being returned by truncated_normal(), which occur if the uniform random value passed to erfinv() is out of range, e.g., due to rounding. Do more of the computation using jax.random.uniform(), which promises correct behavior in the face of rounding. As an added security measure, also clamp the outputs of the function to the open interval.
Author
Committer
Parents
Loading