PR #33714: always apply stop_gradient_p, even to exact dtypes
Imported from GitHub PR https://github.com/jax-ml/jax/pull/33714
fixes #33689
Copybara import of the project:
--
25ffbc808b5e05da80e6916e80b4efcd33137874 by Matthew Johnson <mattjj@google.com>:
always apply stop_gradient_p, even to exact dtypes
fixes #33689
Merging this change closes #33714
COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/33714 from mattjj:issue33689 25ffbc808b5e05da80e6916e80b4efcd33137874
PiperOrigin-RevId: 840301807