Add download-jax-rocm-wheels composite action
Create GitHub Actions composite action to download JAX, jaxlib, and ROCm
plugin wheels for CI testing. The action supports two modes:
1. head (default): Downloads JAX and jaxlib from GCS bucket, and ROCm
plugins (PJRT and plugin) from the latest release (including pre-releases)
of ROCm/jax-plugin repository using gh CLI.
2. pypi_latest: Downloads all packages (jaxlib and ROCm plugins) from PyPI.
Features:
- Automatic ROCm version mapping (6.x -> "60", 7.x -> "7") for wheel naming
- Python version handling including free-threaded builds (3.13t, 3.14t)
- Platform-specific filtering (OS, architecture)
- GCS authentication disabled for ROCm runners
- Robust pre-release detection using gh release list with publishedAt sorting
Default configuration:
- ROCm version: 7.2.0
- jaxlib version: head
- GCS URI: gs://general-ml-ci-transient/jax-github-actions/jax/...