jax
72e5ca95 - [JAX] Fix a small bug if shardings is tuple.

Commit
1 year ago
[JAX] Fix a small bug if shardings is tuple. # Details `jax.tree.map` requests all its arguments to have the same data type. From ```[None] * len(tensorstore_specs) if global_shapes is None else global_shapes```, The data type is already decided to be a list. So if we pass `sharding` or `tspecs` as a tuple, it will fail. Here we add an explicit conversion to a list for sharding and tspecs. PiperOrigin-RevId: 707576866
Author
Parents
Loading