[JAX] Enable Shardy by default in JAX.
See [Shardy JAX Migration](https://docs.jax.dev/en/latest/shardy_jax_migration.html) for more information.
## TL;DR
### What’s going on?
[Shardy](https://openxla.org/shardy) is a new partitioning system co-developed
by GDM Model Scaling (author of [PartIR](https://arxiv.org/abs/2401.11202)) and
XLA/CoreML teams (author of [GSPMD](https://arxiv.org/abs//2105.04663)). Shardy
aims to provide better usability and control to users, and will gradually
replace GSPMD and PartIR.
After the migration is complete in March 2026, Shardy will be the only
partitioner in JAX.
Until then, as a temporary workaround for any problems, Shardy
can be disabled (see below). Please file a
[JAX issue](https://github.com/jax-ml/jax/issues) if you encounter any problem.
### How do I know if Shardy broke my code?
The easiest way to tell if Shardy is responsible for any problems is to disable
Shardy and see if the issues go away.
You can tell that Shardy is enabled by looking for
`Using Shardy for XLA SPMD propagation in the logs`.
### How can I disable Shardy for now?
Until March, 2026 it will be possible to temporarily disable Shardy by:
* setting the shell environment variable `JAX_USE_SHARDY_PARTITIONER` to
something false-like (e.g., 0);
* setting the boolean flag `jax_use_shardy_partitioner` to something
false-like if your code parses flags with absl;
* using this statement in your main file or anywhere before you call
`jax.jit`:
``` python
import jax
jax.config.update('jax_use_shardy_partitioner', False)
```
To debug partitioning with Shardy enabled, you can enable MLIR dumps as follows:
```
--xla_dump_hlo_pass_re=shardy --xla_dump_to=<some_directory>
```
NOTE: Please disable only the specific use cases that are not working as
expected if possible, and file a [bug](https://github.com/jax-ml/jax/issues)
with a reproducer, so we can resolve it asap and re-enable Shardy.
### JAX export backwards compatibility
Enabling Shardy in JAX by default is maintaining the 6 months backwards
compatibility guarantee. This means that you will be able to load a model
exported with Shardy disabled for at least 6 months after Shardy becomes enabled
for your model. That old checkpointed model will run with GSPMD, and only when
re-exporting the model will it start running with Shardy.
However, if you still encounter an issue with loading an old checkpoint, please
contact us or file a [bug](https://github.com/jax-ml/jax/issues).
NOTE: exporting a model with Shardy enabled, then loading it with Shardy
disabled isn’t supported and will fail.
### How do I prepare for Shardy being enabled in March 2026 permanently?
Due to us falling back to GSPMD for any JAX export checkpoint for 6 months, to
help find any potential issues, please re-export any models you have with Shardy
enabled. Then you can see if it runs fine, or there is any bug we need to fix.
PiperOrigin-RevId: 785166956