jax
5f3e9502 - Disallow aliased mutable array arguments to vmap.

Commit
228 days ago
Disallow aliased mutable array arguments to vmap. This PR prevents programmers from passing the same mutable array to a vmapped function more than once, when `config.mutable_array_checks` is enabled. The following example illustrates why we prevent this behavior. ```python import jax import jax.numpy as jnp def f(m1, m2): m1[...] = m2[...] # We pass m to f twice, but with different axes. This aliasing leads to # odd behavior. x = jnp.arange(9.0).reshape(3, 3) m = jax._src.core.mutable_array(x) jax.vmap(f, in_axes=[0, 1])(m, m) # The final value of m is equal to the transpose of x. print(m) print(jnp.array_equal(m[...], jnp.transpose(x))) # MutableArray([[0., 3., 6.], # [1., 4., 7.], # [2., 5., 8.]], dtype=float32) # True # However, if you believe that jax.vmap(f)(batch) should be equal to # jax.stack([f(x) for x in batch]), then you might expect m to have the # following value: x = jnp.arange(9.0).reshape(3, 3) m = jax._src.core.mutable_array(x) for i in range(3): m[i] = m[:,i] print(m) # MutableArray([[0., 3., 6.], # [3., 4., 7.], # [6., 7., 8.]], dtype=float32) ``` Note that this PR doesn't eliminate aliasing. It is still possible by closing over a mutable array that is also passed as an argument: ```python x = jnp.arange(9.0).reshape(3, 3) m1 = jax._src.core.mutable_array(x) def f(m2): return m1[0] + m2[...] jax.vmap(f)(m1) ``` In the future, we can prevent this behavior too. `jax.jit` already prevents this behavior. You can test this PR with the following commands: ``` pre-commit run --all pytest -n auto tests/mutable_array_test.py ```
Author
Committer
Parents
Loading