jax
5f953a9d - remove use of nanobind for _export pass with shardy

Commit
242 days ago
remove use of nanobind for _export pass with shardy right now using these nanobind methods requires going to bytecode and back. This creates extra overhead for these simple operations. Moving these to pure python removes extra lag. lowered_with_shardy and get_mesh are 1 to 1 reproduced in Python. sdy_round_trip_import_shardings is used from xla_sdy_capi with change gain 18% improvement on export of jitted function. PiperOrigin-RevId: 775439256
Author
Parents
Loading