[SPMD] Add FSDP sharding for test_train_spmd_linear_model.py (#5299)
Summary:
This diff adds FSDP sharding for test_train_spmd_linear_model.py.
Test Plan:
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_train_spmd_linear_model.py --sharding fsdp