flax
893a6602 - revise axes_scan to flatten argument pytrees only once

Commit
215 days ago
revise axes_scan to flatten argument pytrees only once A user has a custom pytree node with the unusual behavior that it introduces new arrays when flattening. That is, it's as if we had: ```python # a custom object with two leaf arrays custom_tree_object = SomeObject(jax_arrray1, jax_array2) # convert leaves to ShapedArrays custom_tree_object2 = jax.tree.map(core.typeof, custom_tree_object) # flatten, should only see ShapedArrays, right? leaves, treedef = jax.tree.flatten(custom_tree_object2) print(leaves) # [ShapedArray(...), ShapedArray(...), np.array(...)] ``` This change makes the `flax.nn.scan` function robust to such behavior. Without it, we were passing non-AbstractValues into JAX where JAX required AbstractValues. I don't think we want to support this in general, but this fix seemed like the most expedient way to roll fowrard https://github.com/jax-ml/jax/pull/29273 PiperOrigin-RevId: 768175118
Author
Committer
Parents
Loading