jax
14b2c90f - Allow GSPMDSharding constructor to take in device_list (xc.DeviceList) as input along with `Sequence[jax.Device]`. This prevents extremely slow `tuple(devices) -> DeviceList` conversion in the GSPMDSharding constructor.

Loading