jax
c7f3d1c0 - fix breakage from #28318

Commit
349 days ago
fix breakage from #28318 when using a chex.dataclass(mapapble_dataclass=False), we can't tree_map over those instances, even though they might be passed as args to jitted functions. PiperOrigin-RevId: 754084506
Author
Parents
Loading