[mutable-arrays] make one-level remat work with refs
Functions that only internally use refs are pure. #31389 added support for
those, though it didn't add tests. This PR adds a couple simple tests for
remat-decorated functions that use only internal refs.
With external refs, passed to a remat-decorated function either as explicit
arguments or by closure, the semantics are ambiguous: do we want the effect
only to happen once, or every time the function is re-run? Given that
ambiguity, we plan to let the user be explicit about what behavior they want.
But there's a sensible default behavior: only run the effects the first time.
That is, running this program should print "2.0" and not "3.0":
```python
@jax.remat
def f(y, x_ref):
out = y * x_ref[...]
x_ref[...] += 1
return out
x_ref = jax.new_ref(1.)
jax.grad(f)(1., x_ref)
print(x_ref) # 2.0
```
We capture the initial value of the ref as a residual, and then use it to
initialize a fresh (internal-only) ref used in the rematerializing pass.