jax
3830b172 - More numerically stable `jax.experimental.jet._integer_pow_taylor`.

Commit
1 year ago
More numerically stable `jax.experimental.jet._integer_pow_taylor`. The previous version of `jax.experimental.jet._integer_pow_taylor` became numerically unstable when we evaluated the derivative jet of `x**n` around `x = 0` for integer `n > 2`. This was especially true in the higher order derivatives in `float32`. Beforehand there was a complicated expression for approximating the jet of `x**n`. In this CL we replace integer powers with explicit multiplication, e.g. `x**2` becomes `x*x`. For higher powers of `x`, we use the exponentiation by squaring trick, which evaluates expressions such as `x**4` as `(x * x) * (x * x)`, which uses log(n) multiplies instead of `n-1`. PiperOrigin-RevId: 617205359
Author
Committer
Parents
Loading