[SPMD] Support SPMDShardToFullShape (#6925)
Summary:
This pull request enables SPMDShardToFullShape. The trickiest part is how to get the full shape, and here is a couple of options:
1. Bookkeeping the shape full shape that enters SPMDFullToShardShape. This is not selected given the output could be created on the fly.
2. Constructing the full shape from the local shard and the sharding spec. This is not selected given there is no way to deal with the padding. We can't examine the data during the tracing time.
3. Let users pass the full shape in. This is selected because it's just the most sounded path.
Tes Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_manual_sharding_e2e -k test_spmd_shard_to_full_shape