jax
ed4a7bba - #sdy Add JAX backwards compatibility test.

Commit
330 days ago
#sdy Add JAX backwards compatibility test. This tests saving a module with one set of axis names, but loading it with another set of axis names. This does also test the custom calls: - `@Sharding` - `@xla.sdy.GlobalToLocalShape` - `@xla.sdy.LocalToGlobalShape` But note that there are a bunch of other custom calls that will be tested in the Shardy and XLA codebases. The way the testing utils is tested here doesn't allow me to set `out_shardings` for example. So JAX can rely on the existence of those tests as stability guarantees just like for StableHLO. PiperOrigin-RevId: 732893432
Author
Parents
Loading