xla
11590c1c - Update PyTorch and XLA pin. (#9668)

Commit
131 days ago
Update PyTorch and XLA pin. (#9668) This PR updates the following pins: - PyTorch: https://github.com/pytorch/pytorch/commit/928ac57c2ab03f9f79376f9995553eea2e6f4ca8 to https://github.com/pytorch/pytorch/commit/21fec65781bebe867faf209f89bb687ffd236ca4 (v2.9.0-rc5) - OpenXLA: https://github.com/openxla/xla/commit/92f7b5952dd585c5be17c9a5caad27407005b513 to https://github.com/openxla/xla/commit/9a9aa0e11e4fcda8d6a9c3267dca6776ddbdb0ca - `libtpu`: 0.0.21 to 0.0.24 - JAX (and `jaxlib`): 0.7.1 to 0.8.0 **Key Changes:** - `@python` was replaced by `@rules_python` at `BUILD` file (ref: [jax-ml/jax#31709](https://github.com/jax-ml/jax/pull/31709)) - `TF_ATTRIBUTE_NORETURN` was removed in favor of abseil (ref: [openxla/xla#31699](https://github.com/openxla/xla/pull/31699)) - Replaced include of `xla/pjrt/tfrt_cpu_pjrt_client.h` file by `xla/pjrt/cpu/cpu_client.h` in `pjrt_registry.cpp` ([openxla/xla#30936](https://github.com/openxla/xla/pull/30936)) - Moved the old `xla/tsl/platform/default/logging.*` to `torch_xla/csrc/runtime/tsl_platform_logging.*` - They were removed in [openxla/xla#29477](https://github.com/openxla/xla/pull/29477) - Copied them here, temporarily. They should be removed once we update our error throwing macros. - Commented out a few macro definitions, avoiding macro re-definitions **Update (Oct 3):** - Add an OpenXLA patch for fixing `static_assert(false)` for GCC < 13 ([ref](https://gcc.gnu.org/git/?p=gcc.git;a=commit;h=9944ca17c0766623bce260684edc614def7ea761)) - Removed the `flax` pin, since it does not overwrite `jax` anymore - Removed `TPU*` prefix of `jax.experimental.pallas.tpu` components (ref: [jax-ml/jax#29115](https://github.com/jax-ml/jax/pull/29115)) --------- Co-authored-by: Bhavya Bahl <bbahl@google.com>
Author
Parents
Loading