xla
37067917 - [Pallas] Support programmatically extracting the payload (#6696)

Commit
1 year ago
[Pallas] Support programmatically extracting the payload (#6696) Summary: This pull request introduces _extract_backend_config to extrac the Mosaic payload from the custom call programmatically such that we don't need to copy & paste. In order to run the test, JAX dependencies are added to the CI. The JAX version is using the nightly on the same day as our libtpu for the best compatibility. However, we need to figure out a way to update that automatically when we are updating our open-xla pin. Test Plan: python test/test_operations.py -v -k test_tpu_custom_call_pallas_extract_add_payload
Author
Parents
Loading