jax
54603e2e - Allow internal mutable array effect to be DCEd

Commit
22 hours ago
Allow internal mutable array effect to be DCEd Why? Previously we were being conservative about when we DCE effects. Code like this: ``` @jax.jit def f(x, y): @jax.jit def g(y): y_ref = jax.new_ref(y) y_ref[...] += 1 return y_ref[...] _ = g(y) return x ``` would *not* DCE the call to `g` despite it being functionally "pure" (its Ref effects are entirely contained within the function and it is the implementation of `lambda y: y + 1`) because the InternalMutableArrayEffect would block DCE. This change makes it so `g(y)` can get DCEd by allowing DCE to get rid of equations that have internal mutable array effects if the outputs are unused. PiperOrigin-RevId: 867721260
Author
Parents
Loading