jax
8370f433 - Catch RepError in shard_map_transpose so that we raise a better error message. This is necessary because users can now annotate vma's at their level in custom_vjps.

Commit
206 days ago
Catch RepError in shard_map_transpose so that we raise a better error message. This is necessary because users can now annotate vma's at their level in custom_vjps. Before: ``` jax._src.shard_map._RepError: [<jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, <jax._src.shard_map.NoFail object at 0x30173c4f2f30>, frozenset({'q_seq', 'heads'}), frozenset({'q_seq', 'heads'})] ``` After: ``` ValueError: shard_map applied to the function 'transpose(f)' was given out_specs which require replication which can't be statically inferred given the mesh: The mesh given has shape (2, 2) with corresponding axis names ('heads', 'q_seq'). * out_specs[16] is PartitionSpec('heads', None) which implies that the corresponding output value is replicated across mesh axis 'q_seq', but could not infer replication over any axes * out_specs[17] is PartitionSpec('heads', None) which implies that the corresponding output value is replicated across mesh axis 'q_seq', but could not infer replication over any axes Check if these output values are meant to be replicated over those mesh axes. If not, consider revising the corresponding out_specs entries. If so, consider disabling the check by passing the check_vma=False argument to `jax.shard_map`. ``` PiperOrigin-RevId: 778977467
References
Author
Parents
Loading