jax
df8e17a5 - [JAX] Enable Shardy by default in JAX.

Commit
283 days ago
[JAX] Enable Shardy by default in JAX.

See [Shardy JAX Migration](https://docs.jax.dev/en/latest/shardy_jax_migration.html) for more information.

## TL;DR

### What’s going on?

[Shardy](https://openxla.org/shardy) is a new partitioning system co-developed by GDM Model Scaling (authors of [PartIR](https://arxiv.org/abs/2401.11202)) and the XLA/CoreML teams (authors of [GSPMD](https://arxiv.org/abs//2105.04663)). Shardy aims to give users better usability and control, and will gradually replace GSPMD and PartIR.

After the migration is complete in March 2026, Shardy will be the only partitioner in JAX. Until then, Shardy can be disabled as a temporary workaround for any problems (see below). Please file a [JAX issue](https://github.com/jax-ml/jax/issues) if you encounter any problem.

### How do I know if Shardy broke my code?

The easiest way to tell whether Shardy is responsible for a problem is to disable Shardy and see if the problem goes away. You can tell that Shardy is enabled by looking for `Using Shardy for XLA SPMD propagation` in the logs.

### How can I disable Shardy for now?

Until March 2026 it will be possible to temporarily disable Shardy by:

* setting the shell environment variable `JAX_USE_SHARDY_PARTITIONER` to something false-like (e.g., 0);
* setting the boolean flag `jax_use_shardy_partitioner` to something false-like if your code parses flags with absl;
* using this statement in your main file or anywhere before you call `jax.jit`:

``` python
import jax
jax.config.update('jax_use_shardy_partitioner', False)
```

To debug partitioning with Shardy enabled, you can enable MLIR dumps as follows:

```
--xla_dump_hlo_pass_re=shardy --xla_dump_to=<some_directory>
```

NOTE: If possible, please disable Shardy only for the specific use cases that are not working as expected, and file a [bug](https://github.com/jax-ml/jax/issues) with a reproducer, so we can resolve it asap and re-enable Shardy.
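
The environment variable and dump flags described above can be combined when launching a job from a script. The following is a minimal sketch, not part of this change: the helper name `shardy_debug_env`, the dump directory, and the `train.py` script are hypothetical, and it assumes the dump flags are passed to XLA through the `XLA_FLAGS` environment variable.

```python
import os


def shardy_debug_env(base_env, dump_dir, use_shardy=True):
    """Return a copy of base_env set up for a Shardy debugging run.

    Hypothetical helper for illustration only.
    """
    env = dict(base_env)
    # Environment variable named in the commit message; "0" is false-like.
    env["JAX_USE_SHARDY_PARTITIONER"] = "1" if use_shardy else "0"
    # Assumption: XLA reads these dump flags from XLA_FLAGS.
    dump_flags = [
        "--xla_dump_hlo_pass_re=shardy",
        f"--xla_dump_to={dump_dir}",
    ]
    # Preserve any flags already present, skipping an empty value.
    existing = env.get("XLA_FLAGS", "")
    env["XLA_FLAGS"] = " ".join(filter(None, [existing] + dump_flags))
    return env


env = shardy_debug_env(os.environ, "/tmp/shardy_dump")
# The environment can then be handed to a child process, for example:
# subprocess.run(["python", "train.py"], env=env)
```

Building a fresh environment for a child process leaves the current process untouched, which makes it easy to run the same job once with Shardy enabled and once with `use_shardy=False` and compare the results.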
### JAX export backwards compatibility

Enabling Shardy by default in JAX maintains the 6-month backwards compatibility guarantee. This means that you will be able to load a model exported with Shardy disabled for at least 6 months after Shardy becomes enabled for your model. That old checkpointed model will run with GSPMD; it will start running with Shardy only when you re-export it. However, if you still encounter an issue with loading an old checkpoint, please contact us or file a [bug](https://github.com/jax-ml/jax/issues).

NOTE: Exporting a model with Shardy enabled and then loading it with Shardy disabled isn’t supported and will fail.

### How do I prepare for Shardy being enabled permanently in March 2026?

Because we fall back to GSPMD for any JAX export checkpoint for 6 months, please re-export any models you have with Shardy enabled to help find potential issues. You can then see whether the model runs fine or whether there is a bug we need to fix.

PiperOrigin-RevId: 785166956