jax
a8246ea6 - Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.

Commit
1 year ago
Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs. For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case. In a future release of JAX, this behavior will become an error. PiperOrigin-RevId: 641690427
Author
Committer
Parents
Loading