jax
4d2808c1 - [mutable-arrays] limit implicit ref_swap dtype promotion

Commit
269 days ago
[mutable-arrays] limit implicit ref_swap dtype promotion fixes #27683 In b7715e279, specifically this line: https://github.com/jax-ml/jax/commit/b7715e279dd96938452a3817564d8671bf681543#diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193 we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in #27683. A repro looks like: ```python import jax.numpy as jnp from jax._src import core v = core.mutable_array(jnp.array([0, 0, 0])) v[...] += 1.0 print(v) # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32) ``` We can't easily just drop this behavior because it seems many GPU x64 tests depend on it. So in this change we're trying to 1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype; 2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and 3. do an ordinary cast rather than a bitcast. I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g. ```python v = core.mutable_array(jnp.array(0, dtype='bfloat16')) v[...] += 1.0 # don't error! ``` But we can do that with special-purpose carve-outs for Python builtin numerictypes. I left one way to do it in a comment. PiperOrigin-RevId: 745198669
Author
Parents
Loading