flax
facbc5d7 - [JAX] Replace uses of deprecated `jax.ops.index_update(x, idx, y)` APIs with their up-to-date, more succinct equivalent `x.at[idx].set(y)`.

Commit
4 years ago
[JAX] Replace uses of deprecated `jax.ops.index_update(x, idx, y)` APIs with their up-to-date, more succinct equivalent `x.at[idx].set(y)`. The JAX operators: jax.ops.index_update(x, jax.ops.index[idx], y) jax.ops.index_add(x, jax.ops.index[idx], y) ... have long been deprecated in lieu of their more succinct counterparts: x.at[idx].set(y) x.at[idx].add(y) ... This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX. The main subtlety is that if `x` is not a JAX array, we must cast it to one using `jnp.asarray(x)` before using the new form, since `.at[...]` is only defined on JAX arrays. PiperOrigin-RevId: 400647852
Author
Committer
Parents
Loading