jax
17d89ad1 - Fix jax.device_put so it doesn't use tree_map for _check_sharding.

Commit
2 years ago
Fix jax.device_put so it doesn't use tree_map for _check_sharding. This causes it to unnecessarily attempt to unflatten the None return values from _check_sharding into the original tree structure, which is a problem for custom datatypes registered with jax.tree_util that don't accept None values in place of jax arrays. PiperOrigin-RevId: 570189648
Author
Committer
Parents
Loading