[TPU] Support PyTorch/XLA FSDP via SPMD (#28949)
* Initial commit
* Add guards for the global mesh (see the mesh sketch below)
* Address more comments
* Move the dataloader into integrations/tpu.py (see the dataloader sketch below)
* Fix linters
* Make the kwarg more explicit
* Remove the device-move logic
* Fix the CI
* Fix linters
* Re-enable checkpointing
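
For context, a minimal sketch of how FSDP via SPMD could be switched on from the Trainer side. The `fsdp_config` keys shown (`xla`, `xla_fsdp_v2`, `xla_fsdp_grad_ckpt`) are assumptions for illustration, not a verbatim copy of this PR:

```python
# Hedged sketch: enabling XLA FSDP via SPMD through TrainingArguments.
# The fsdp_config keys below are assumed for illustration only.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    fsdp="full_shard",
    fsdp_config={
        "xla": True,                  # route FSDP through torch_xla
        "xla_fsdp_v2": True,          # SPMD-based FSDPv2 path (assumed key)
        "xla_fsdp_grad_ckpt": True,   # gradient checkpointing inside the FSDP wrap (assumed key)
    },
)
```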
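
The global-mesh guard could look roughly like the sketch below: SPMD is only enabled when torch_xla is importable, and a single 1-D "fsdp" mesh over all devices is registered globally. The helper name and axis names are illustrative, not the exact PR code:

```python
# Hedged sketch: guard SPMD mesh setup behind the torch_xla import.
import numpy as np

try:
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    IS_XLA_AVAILABLE = True
except ImportError:
    IS_XLA_AVAILABLE = False


def setup_global_mesh():
    """Enable SPMD and register a 1-D FSDP mesh spanning all devices."""
    if not IS_XLA_AVAILABLE:
        return None
    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()
    # One "fsdp" axis over every device; the "tensor" axis is kept at size 1.
    mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ("fsdp", "tensor"))
    xs.set_global_mesh(mesh)
    return mesh
```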
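
And a sketch of what a dataloader helper in integrations/tpu.py might do: wrap the vanilla dataloader in an MpDeviceLoader whose `input_sharding` shards the batch dimension over the "fsdp" mesh axis. The helper name and partition spec are illustrative, and the sketch assumes the global mesh above has already been set:

```python
# Hedged sketch: shard dataloader inputs along the "fsdp" axis of the global mesh.
from torch.utils.data import DataLoader

try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.spmd as xs

    IS_XLA_AVAILABLE = True
except ImportError:
    IS_XLA_AVAILABLE = False


def tpu_spmd_dataloader(dataloader: DataLoader):
    """Wrap a dataloader so each batch is sharded across the global SPMD mesh."""
    if not IS_XLA_AVAILABLE:
        return dataloader
    # Shard the batch dimension over "fsdp"; leave the remaining dimension replicated.
    sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None))
    return pl.MpDeviceLoader(dataloader, xm.xla_device(), input_sharding=sharding_spec)
```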