xla
0f23514b - [SPMD][PoC] compile & execute with PjRt (#3684)

Commit
2 years ago
[SPMD][PoC] compile & execute with PjRt (#3684)

* Create gspmd test
* Add experimental XLAShardedTensor and mark_sharding API
* Add ShardingSpec annotation to XLA tensor
* Add partitioning test
* Update sharding spec to support full replication & mesh sharding
* Add testing probe for partitioner
* Add spmd partitioner dependency
* Tensor sharding annotation and sharded HLO dumping function.
* Add checks for _xla_mark_sharding
* Compile tensor ops with sharding annotations
* Make SpmdPartitioningPass in tensor.Compile()
* Use sharding custom_call & add more comments
* is_spmd device assignment for xrt_computation
* Disable GPU for SPMD
* Rebasing master with ltc migration changes
* CreateXrtSpmdComputation only if spmd is enabled in HloModule
* Remove xrt changes before landing the experimental feature.
* Refactor experimental support for SPMD changes
* Update sharding spec to support full replication & mesh sharding
* Tensor sharding annotation and sharded HLO dumping function.
* Add checks for _xla_mark_sharding
* Compile tensor ops with sharding annotations
* Make SpmdPartitioningPass in tensor.Compile()
* Rebasing master with ltc migration changes
* CreateXrtSpmdComputation only if spmd is enabled in HloModule
* PjRt compile partitioned graph with SPMD partitioning option
* Introduce ExecuteReplicated in pjrt_computation_client
* Add ShardingUtil::InputHandler for input sharding
* Add `SPMD` XlaDeviceType & GetVirtualDevice()
* Add PjRtComputationClient::PjRtShardedData
* Add more unit tests to XLAShardingTest (C++)
* Add more unit tests to XLAShardingTest (Python)
* Allow `_xla_mark_sharding` to initiate sharded data transfer
* Refactor device transfers to use `BufferFromHostBuffer`
* Replace InlinedVector in TransferToServer
* Remove kEnvSpmdTest flag
* Fix mark_sharding bugs
* Allow XLATensor::SetShardingSpec to receive ShardingSpecPtr
* Use unpartitioned tensor shape and device for PjRtShardedData.
* Disable partial sharding in mark_sharding
* Remove duplicate copies of sharding annotation
* Allow ToHlo to return partitioned HLO if sharded
* Fix lint errors
* Add/expand CreateTensorsData & InputHandler tests
* Add device assignment for SPMD compilation
* [SPMD] Refactor `_xla_partitioning_pass`.
* [SPMD] Refactor `_xla_mark_sharding`.
* [SPMD] Support higher-order mesh topology.
* [SPMD] inherit global tensor requires_grad in XLAShardedTensor
* [SPMD] Disable aliasing if is_spmd.
* [SPMD] Use all devices instead of local
* [SPMD] Define clear_sharding interface
* [SPMD] experiment with IR sharding preservation.
* Rebase master
* Refactor and add comments

Co-authored-by: Will Cromar <wcromar@google.com>
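
For readers unfamiliar with the "full replication & mesh sharding" items above, the following is a minimal pure-Python sketch of the general idea: each tensor dimension is either split evenly along one axis of a logical device mesh or left whole (replicated). It is not code from this commit and does not use torch_xla; the function `shard_slices` and its parameters `tensor_shape`, `mesh_shape`, and `partition_spec` are hypothetical names chosen for illustration only.

```python
from itertools import product


def shard_slices(tensor_shape, mesh_shape, partition_spec):
    """Map each mesh coordinate to the slices of the tensor it would hold.

    Hypothetical helper for illustration; not part of torch_xla.
    partition_spec[i] is the mesh axis that tensor dim i is split over,
    or None if that dim is not split (replicated along every mesh axis).
    """
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition_spec needs one entry per tensor dim")

    shards = {}
    # Visit every device position in the logical mesh.
    for coord in product(*(range(n) for n in mesh_shape)):
        slices = []
        for dim_size, mesh_axis in zip(tensor_shape, partition_spec):
            if mesh_axis is None:
                # Dimension is not split: every device sees all of it.
                slices.append(slice(0, dim_size))
            else:
                parts = mesh_shape[mesh_axis]
                chunk = -(-dim_size // parts)  # ceiling division
                start = min(coord[mesh_axis] * chunk, dim_size)
                stop = min(start + chunk, dim_size)
                slices.append(slice(start, stop))
        shards[coord] = tuple(slices)
    return shards


# An 8x4 tensor tiled over a 2x2 mesh: rows over mesh axis 0, columns over axis 1.
tiled = shard_slices((8, 4), (2, 2), (0, 1))

# Full replication: no dimension is split, so every device holds the whole tensor.
replicated = shard_slices((8, 4), (2, 2), (None, None))
```

In this sketch, a spec such as `(0, None)` splits only the rows, so devices that share a position on mesh axis 0 hold identical shards. The actual sharding annotation in the commit is carried on the XLA tensor and handled by the SPMD partitioner; the helper above only models the resulting data layout.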