jax
21b52fb5 - Update ml_dtypes to 0.5.1 to align with JAX and TensorFlow

Commit
194 days ago
Update ml_dtypes to 0.5.1 to align with JAX and TensorFlow Several "new" python types in XLA types.h use `std::optional<nanobind::object>`. JAX and TF already switched to ml-dtypes 0.5.0+. Now we can remove `std::optional` usage in XLA `types.h` for these types. Because of XLA GitHib CI workflows and JAX GitHub CI we should do XLA types.h refactoring using the following approach. xla/python/types.h uses JAX_IFRT_VERSION_NUMBER to determine how new types are defined: - If `JAX_IFRT_VERSION_NUMBER < 11`, the new types are declared in the old way as `std::optional<nanobind::object>` - If `JAX_IFRT_VERSION_NUMBER >= 11`, the new types are declared directly as `nanobind::object` (aligning with the convention used for other types) Next steps: 1. Update JAX_IFRT_VERSION_NUMBER version from 10 to 11. (After JAX github has this change merged) 2. Once JAX updates XLA_COMMIT_ID to newer version (in 1-2 days) I will open cleanup change to remove `#if JAX_IFRT_VERSION_NUMBER >= 11` statements from xla `types.h` and jaxlib `py_values.cc` PiperOrigin-RevId: 775507036
Author
Parents
Loading