Register shutdown code at import to hopefully get registered before any other atexit callbacks.
`atexit` callbacks are called in a LIFO order, meaning that since Jax currently registers its callback at runtime rather than import time, it gets called before any `atexit` callbacks registered at import time.
PiperOrigin-RevId: 662164776