jax
be6b77cc - Update shard_map(jit) to properly set manual_axes on in_shardings and out_shardings of the nested jit. This avoids a problem where the jit returns {manaual} and then this gets passed to ShardToFull (manual is already considered a full sharding).

Commit
1 year ago
Update shard_map(jit) to properly set manual_axes on in_shardings and out_shardings of the nested jit. This avoids a problem where the jit returns {manaual} and then this gets passed to ShardToFull (manual is already considered a full sharding). PiperOrigin-RevId: 655719254
Author
Committer
Parents
Loading