Update PyTorch and XLA pin. (#9668)
This PR updates the following pins:
- PyTorch:
https://github.com/pytorch/pytorch/commit/928ac57c2ab03f9f79376f9995553eea2e6f4ca8
to
https://github.com/pytorch/pytorch/commit/21fec65781bebe867faf209f89bb687ffd236ca4
(v2.9.0-rc5)
- OpenXLA:
https://github.com/openxla/xla/commit/92f7b5952dd585c5be17c9a5caad27407005b513
to
https://github.com/openxla/xla/commit/9a9aa0e11e4fcda8d6a9c3267dca6776ddbdb0ca
- `libtpu`: 0.0.21 to 0.0.24
- JAX (and `jaxlib`): 0.7.1 to 0.8.0
**Key Changes:**
- `@python` was replaced by `@rules_python` at `BUILD` file (ref:
[jax-ml/jax#31709](https://github.com/jax-ml/jax/pull/31709))
- `TF_ATTRIBUTE_NORETURN` was removed in favor of abseil (ref:
[openxla/xla#31699](https://github.com/openxla/xla/pull/31699))
- Replaced include of `xla/pjrt/tfrt_cpu_pjrt_client.h` file by
`xla/pjrt/cpu/cpu_client.h` in `pjrt_registry.cpp`
([openxla/xla#30936](https://github.com/openxla/xla/pull/30936))
- Moved the old `xla/tsl/platform/default/logging.*` to
`torch_xla/csrc/runtime/tsl_platform_logging.*`
- They were removed in
[openxla/xla#29477](https://github.com/openxla/xla/pull/29477)
- Copied them here, temporarily. They should be removed once we update
our error throwing macros.
- Commented out a few macro definitions, avoiding macro re-definitions
**Update (Oct 3):**
- Add an OpenXLA patch for fixing `static_assert(false)` for GCC < 13
([ref](https://gcc.gnu.org/git/?p=gcc.git;a=commit;h=9944ca17c0766623bce260684edc614def7ea761))
- Removed the `flax` pin, since it does not overwrite `jax` anymore
- Removed `TPU*` prefix of `jax.experimental.pallas.tpu` components
(ref: [jax-ml/jax#29115](https://github.com/jax-ml/jax/pull/29115))
---------
Co-authored-by: Bhavya Bahl <bbahl@google.com>