Pin `flax` and skip C++ test `SiLUBackward`. (#9660)
Since https://github.com/pytorch/pytorch/pull/162659 was merged again,
we observed that `SiLUBackward` C++ test was crashing with a
segmentation fault #9561. Not only that, but TPU tests started failing
because `flax` 0.12.0 (old: 0.11.2) started pulling a newer `jax` 0.7.2
(old: 0.7.1).
- Old CI build:
[link](https://github.com/pytorch/xla/actions/runs/17931468317/job/51089906800)
- Recent broken CI build:
[link](https://github.com/pytorch/xla/actions/runs/18008717023/job/51550125217?pr=9655)
Therefore, in this PR:
- Pin `flax` to version 0.11.2
- Skip `SiLUBackward` C++ test
Additionally, it also installs `jax` and `libtpu` using the CI
PyTorch/XLA wheels metadata instead of using PyPI wheels metadata. This
should avoid other version compatibilities.