jax
96c01299 - Fix false positive `debug_nans` error caused by NaNs that are properly handled in `jax.scipy.stats.gamma`

Commit
1 year ago
Fix false positive `debug_nans` error caused by NaNs that are properly handled in `jax.scipy.stats.gamma` As reported in https://github.com/jax-ml/jax/issues/24939, even though the implementation of `jax.scipy.stats.gamma.logpdf` handles invalid inputs (e.g. `x < loc`) by returning `-inf`, the existing implementation incorrectly triggers the NaN checks introduced by JAX's debug NaNs mode. This change updates the implementation to no longer produce internal NaNs. Fixes https://github.com/jax-ml/jax/issues/24939 PiperOrigin-RevId: 698833589
Author
dfm dfm
Parents
Loading