jax
71f9764e - [JAX] Generate more readable error for failed device deserialization in colocated Python

Commit
314 days ago
[JAX] Generate more readable error for failed device deserialization in colocated Python When deserializing a colocated Python function or input/output sharding, we often need to deserialize a device using a device id. This is done by looking up a CPU device map; this lookup can fail if the device id was referring to a non-CPU device. Unfortunately, we would see a simple error message like `KeyError: np.int64(0)` that does not give a context of the problem. This change adds a slightly more context to the exception so that the error is more actionable. PiperOrigin-RevId: 729172296
Author
Parents
Loading