[SPMD][PoC] compile & execute with PjRt (#3684)
* Create GSPMD test
* Add experimental XLAShardedTensor and mark_sharding API (usage sketch after this list)
* Add ShardingSpec annotation to XLA tensor
* Add partitioning test
* Update sharding spec to support full replication & mesh sharding
* Add testing probe for partitioner
* Add spmd partitioner dependency
* Tensor sharding annotation and sharded HLO dumping function.
* Add checks for _xla_mark_sharding
* Compile tensor ops with sharding annotations
* Make SpmdPartitioningPass in tensor.Compile()
* Use sharding custom_call & add more comments
* is_spmd device assignment for xrt_computation
* Disable GPU for SPMD
* Rebase master with LTC migration changes
* CreateXrtSpmdComputation only if SPMD is enabled in HloModule
* Remove xrt changes before landing the experimental feature.
* Refactor experimental support for SPMD changes
* PjRt compile partitioned graph with SPMD partitioning option
* Introduce ExecuteReplicated in pjrt_computation_client
* Add ShardingUtil::InputHandler for input sharding
* Add `SPMD` XlaDeviceType & GetVirtualDevice()
* Add PjRtComputationClient::PjRtShardedData
* Add more unit tests to XLAShardingTest (C++)
* Add more unit tests to XLAShardingTest (Python)
* Allow `_xla_mark_sharding` to initiate sharded data transfer
* Refactor device transfers to use `BufferFromHostBuffer`
* Replace InlinedVector in TransferToServer
* Remove kEnvSpmdTest flag
* Fix mark_sharding bugs
* Allow XLATensor::SetShardingSpec to receive ShardingSpecPtr
* Use unpartitioned tensor shape and device for PjRtShardedData.
* Disable partial sharding in mark_sharding
* Remove duplicate copies of sharding annotation
* Allow ToHlo to return partitioned HLO if sharded
* Fix lint errors
* Add/expand CreateTensorsData & InputHandler tests
* Add device assignment for SPMD compilation
* [SPMD] Refactor `_xla_partitioning_pass`.
* [SPMD] Refactor `_xla_mark_sharding`.
* [SPMD] Support higher-order mesh topology (mesh sketch after this list).
* [SPMD] Inherit global tensor requires_grad in XLAShardedTensor
* [SPMD] Disable aliasing if is_spmd.
* [SPMD] Use all devices instead of local
* [SPMD] Define clear_sharding interface (sketch after this list)
* [SPMD] Experiment with IR sharding preservation.
* Rebase master
* Refactor and add comments
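
For reference, a minimal usage sketch of the `mark_sharding` / `XLAShardedTensor` flow the commits above introduce. This assumes the module layout the experimental SPMD work later settled on (`torch_xla.experimental.xla_sharding` with a `Mesh` helper, gated behind SPMD mode in later releases); exact names and signatures at the time of this PR may differ:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Assumption: 4 addressable devices arranged as a 2x2 mesh.
mesh = xs.Mesh(np.arange(4), mesh_shape=(2, 2))

t = torch.randn(8, 8).to(xm.xla_device())

# Shard dim 0 across mesh axis 0 and dim 1 across mesh axis 1. The
# annotation is attached to the underlying XLA tensor, and the SPMD
# partitioner pass rewrites the compiled HLO; the return value is an
# XLAShardedTensor wrapping the global tensor.
st = xs.mark_sharding(t, mesh, partition_spec=(0, 1))
```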
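
The higher-order mesh topology and full-replication commits use the same annotation; a hedged sketch under the same assumptions as above:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Assumption: 8 addressable devices arranged as a 2x2x2 mesh.
mesh3d = xs.Mesh(np.arange(8), mesh_shape=(2, 2, 2))

u = torch.randn(16, 16).to(xm.xla_device())
# Shard dim 0 across mesh axis 0 and dim 1 across mesh axis 2; mesh
# axis 1 is left unused, which implies replication across it.
xs.mark_sharding(u, mesh3d, partition_spec=(0, 2))

# Full replication: no tensor dim is mapped to any mesh axis.
v = torch.randn(16).to(xm.xla_device())
xs.mark_sharding(v, mesh3d, partition_spec=(None,))
```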
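
`clear_sharding` is the inverse of `mark_sharding`; a minimal sketch, assuming it lives in the same module and strips the annotation so the tensor compiles as unsharded:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

mesh = xs.Mesh(np.arange(4), mesh_shape=(2, 2))
t = torch.randn(8, 8).to(xm.xla_device())
xs.mark_sharding(t, mesh, partition_spec=(0, 1))

# Drop the sharding annotation; subsequent compilations treat the
# tensor as unannotated (replicated by default).
xs.clear_sharding(t)
```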
Co-authored-by: Will Cromar <wcromar@google.com>