jax
e62a50cd - #sdy add JAX Shardy support for shard_map.

Commit
1 year ago
#sdy add JAX Shardy support for shard_map. For example the following JAX program: ```py devices = np.array(jax.devices()[:8]) mesh = Mesh(devices, axis_names=('x')) a = jax.device_put( jnp.arange(8 * 8).reshape((8, 8)), jax.sharding.NamedSharding(mesh, P('x', None))) @jax.jit @partial( shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) ) def fwd(a): axis_size = lax.psum(1, 'x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(a, 'x', perm=perm) print(jax.jit(fwd).lower(a).as_text()) ``` prints: ```cpp module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { sdy.mesh @mesh = <["x"=8]> func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32> return %0 : tensor<8x8xi32> } func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) { %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) { %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32> sdy.return %1 : tensor<1x8xi32> } : (tensor<8x8xi32>) -> tensor<8x8xi32> return %0 : tensor<8x8xi32> } } ``` PiperOrigin-RevId: 679165100
Author
Parents
Loading