xla
#sdy Initial set of changes to allow for lowering to the Shardy dialect. #15127
Merged

copybara-service merged 1 commit into main from test_648711460
310 days ago

The OpenXLA project is working on an open-source, MLIR-based, named-axis propagation (and in the future SPMD partitioning) system that will be dialect agnostic (it would work for any dialect: MHLO, StableHLO, YourDialect). We plan to have frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with mhlo.shardings. But we will eventually make Shardy the first pass in the XLA pipeline, running while the program is still StableHLO. Partitioning (the system that adds collectives like all-gathers/all-reduces) will still be handled by the GSPMD partitioner, but next year the Shardy partitioner will be developed, allowing propagation and partitioning to live entirely in MLIR as the first passes in the pipeline. We'd then have the following pipeline (steps 1 and 2 are illustrated in the sketch after the list):

  1. Traced jaxpr
  2. Jaxpr -> StableHLO
  3. StableHLO with Shardy propagation
  4. StableHLO with Shardy partitioning
  5. StableHLO -> HLO
  6. XLA optimizations
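To make steps 1 and 2 concrete, here is a minimal sketch using standard JAX APIs (jax.make_jaxpr and jax.jit(...).lower); the Shardy propagation and partitioning of steps 3 and 4 would then run as passes over the resulting StableHLO:

import jax
import jax.numpy as jnp

def f(x):
  return x * 2

# Step 1: trace f to a jaxpr.
print(jax.make_jaxpr(f)(jnp.arange(4)))

# Step 2: lower to StableHLO (printed as MLIR text).
print(jax.jit(f).lower(jnp.arange(4)).as_text())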

The following test:

# Assumed imports (in the JAX test suite this is a method of a test case, hence self):
from functools import partial
import numpy as np
import jax
from jax.sharding import PartitionSpec as P
from jax._src import test_util as jtu

def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())

outputs:

module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
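Note how the NamedSharding maps directly onto the sdy.sharding attribute: P('x', 'y') becomes #sdy.sharding<@mesh, [{"x"}, {"y"}]>, one axis list per tensor dimension, referencing the sdy.mesh symbol declared at the top of the module. (As an illustration of the same syntax, not taken from this PR: a sharding like P('x', None) would presumably lower to #sdy.sharding<@mesh, [{"x"}, {}]>, leaving the second dimension unsharded.)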

Shardy will initially be hidden behind the jax_use_shardy_partitioner flag, before being enabled by default in the future.
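A minimal sketch of opting in, assuming the flag is exposed through jax.config like other JAX feature flags (producing the sdy output shown above presumably requires it to be set):

import jax

# Opt in to Shardy lowering; off by default when this PR landed.
jax.config.update('jax_use_shardy_partitioner', True)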

Commit 4e56ae4c by bartchr808: #sdy Initial set of changes to allow for lowering to the Shardy dialect.
copybara-service merged 4e56ae4c into main 306 days ago
copybara-service deleted the test_648711460 branch 306 days ago
