jax
af501356 - Fix error when swapping a ref with a trivial indexing transform.

Commit
1 year ago
Fix error when swapping a ref with a trivial indexing transform. Without this fix, the added test case fails with: ``` ... jax/_src/state/discharge.py:416: in _swap_discharge_rule z, x_new = _swap_discharge(x, val, idx, tree) jax/_src/state/discharge.py:421: in _swap_discharge return transform_swap_array(x, transforms, val) jax/_src/state/discharge.py:396: in transform_swap_array result_val = lax_slicing.dynamic_update_slice( jax/_src/lax/slicing.py:215: in dynamic_update_slice start_indices = _dynamic_slice_indices(operand, start_indices) ... AttributeError: 'NoneType' object has no attribute 'ndim' ``` from encountering a None when computing the `result_val`.
Author
Committer
Parents
Loading