jax
ae8da833 - Shmallas, a.k.a. allow lowering shard_map + run_state to a pallas_call.

This allows code like this:

```python
def f(x):
  mesh = pltpu.create_tensorcore_mesh('core')
  y = jnp.zeros_like(x)

  @state_discharge.run_state
  def inner(refs):
    x_ref, y_ref = refs

    def kernel():
      def alloc(sem):
        pltpu.async_copy(x_ref, y_ref, sem).wait()
      pltpu.run_scoped(alloc, pltpu.SemaphoreType.DMA)

    shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
                        check_rep=False)()

  _, y = inner((x, y))
  return y
```

Why? `pallas_call` as an API has a lot of responsibilities:

1. Creating Refs out of Arrays
2. Parallelizing execution over cores (via `dimension_semantics` and `grid`)
3. Pipelining
4. Allocating scratch spaces
5. Scalar prefetch

This change allows you to express `pallas_call` *compositionally* using existing APIs:

1. Creating Refs out of Arrays -> `run_state`
2. Parallelizing execution over cores -> `shard_map` with a special mesh
3. Pipelining -> `emit_pipeline`
4. Allocating scratch spaces -> `run_scoped` (which we could generalize to `run_state`)
5. Scalar prefetch -> `run_scoped` + a DMA

The hope is that this allows Pallas to generalize to more backends beyond TPU while becoming more intuitive to write and explain. For now, this lowering path is experimental and not officially exposed, but we want to make sure it is possible to support.

PiperOrigin-RevId: 655320587
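For contrast with the compositional path above, here is a minimal sketch of the same copy expressed through the monolithic `pallas_call` API, where Ref creation and output allocation are handled by `pallas_call` itself. The function and kernel names are illustrative, not part of this change, and the sketch assumes the standard `jax.experimental.pallas` import.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def copy_kernel(x_ref, y_ref):
  # pallas_call hands the kernel Refs directly: responsibility 1
  # (creating Refs out of Arrays) lives inside pallas_call itself.
  y_ref[...] = x_ref[...]


def copy_with_pallas_call(x):
  # out_shape tells pallas_call what output buffer to allocate.
  return pl.pallas_call(
      copy_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x)
```

The point of the commit is that each of these built-in responsibilities can instead be picked up by a separate, existing API when lowering through the new path.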