[SPMD] Support SPMDFullToShardShape (#6922)
Summary:
This pull request supports SPMDFullToShardShape which is a custom op that opens a region for non-partitioned graph in SPMD program. It will stop SPMD auto sharding and partition in that region and therefore allows manual sharding like cc ops.
To implement it, this pull request expands CustomSharding node to accept a new type. To be notice, the output shape of the op needs to be the shard shape of the input, and the node needs to have manual sharding annotation.
Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_spmd_full_to_shard_shape