jax
ae792499 - Improve error message when collective APIs are called without a shard_map

Commit
216 days ago
Improve error message when collective APIs are called without a shard_map Before: `unbound axis name: x` After: `Found an unbound axis name: x. To fix this, please call psum under jax.shard_map` PiperOrigin-RevId: 778632500
Author
Parents
Loading