Disallow aliased mutable array arguments to vmap.
This PR prevents programmers from passing the same mutable array to a
vmapped function more than once, when `config.mutable_array_checks` is
enabled. The following example illustrates why we prevent this behavior.
```python
import jax
import jax.numpy as jnp
def f(m1, m2):
m1[...] = m2[...]
# We pass m to f twice, but with different axes. This aliasing leads to
# odd behavior.
x = jnp.arange(9.0).reshape(3, 3)
m = jax._src.core.mutable_array(x)
jax.vmap(f, in_axes=[0, 1])(m, m)
# The final value of m is equal to the transpose of x.
print(m)
print(jnp.array_equal(m[...], jnp.transpose(x)))
# MutableArray([[0., 3., 6.],
# [1., 4., 7.],
# [2., 5., 8.]], dtype=float32)
# True
# However, if you believe that jax.vmap(f)(batch) should be equal to
# jax.stack([f(x) for x in batch]), then you might expect m to have the
# following value:
x = jnp.arange(9.0).reshape(3, 3)
m = jax._src.core.mutable_array(x)
for i in range(3):
m[i] = m[:,i]
print(m)
# MutableArray([[0., 3., 6.],
# [3., 4., 7.],
# [6., 7., 8.]], dtype=float32)
```
Note that this PR doesn't eliminate aliasing. It is still possible by
closing over a mutable array that is also passed as an argument:
```python
x = jnp.arange(9.0).reshape(3, 3)
m1 = jax._src.core.mutable_array(x)
def f(m2):
return m1[0] + m2[...]
jax.vmap(f)(m1)
```
In the future, we can prevent this behavior too. `jax.jit` already
prevents this behavior.
You can test this PR with the following commands:
```
pre-commit run --all
pytest -n auto tests/mutable_array_test.py
```