jax
ddc67f06 - [mutable-arrays] make one-level remat work with refs

Commit
206 days ago
[mutable-arrays] make one-level remat work with refs Functions that only internally use refs are pure. #31389 added support for those, though it didn't add tests. This PR adds a couple simple tests for remat-decorated functions that use only internal refs. With external refs, passed to a remat-decorated function either as explicit arguments or by closure, the semantics are ambiguous: do we want the effect only to happen once, or every time the function is re-run? Given that ambiguity, we plan to let the user be explicit about what behavior they want. But there's a sensible default behavior: only run the effects the first time. That is, running this program should print "2.0" and not "3.0": ```python @jax.remat def f(y, x_ref): out = y * x_ref[...] x_ref[...] += 1 return out x_ref = jax.new_ref(1.) jax.grad(f)(1., x_ref) print(x_ref) # 2.0 ``` We capture the initial value of the ref as a residual, and then use it to initialize a fresh (internal-only) ref used in the rematerializing pass.
Author
Committer
Parents
Loading