support load and save checkpoint in torchax (#9616)
This PR supports checkpointing with torchax:
1. load a checkpoint file in torch tensors and convert to Jax arrays; Or
load a checkpoint file in Jax arrays
2. save a checkpoint file in Jax arrays.
This support single worker now.