jax
1c82484c - Start a new TPU interpret mode for Pallas.

Commit
348 days ago
Start a new TPU interpret mode for Pallas. The goal of this interpret mode is to run a Pallas TPU kernel on CPU, while simulating a TPU's shared memory, multiple devices/cores, remote DMAs, and synchronization. The basic approach is to execute the kernel's Jaxpr on CPU, but to replace all load/store, DMA, and synchronization primitives with io_callbacks to a Python functions that simulate these primitives. When this interpret mode is run inside of shard_map and jit, the shards will run in parallel, simulating the parallel execution of the kernel on multiple TPU devices. The initial version in this PR can successfully interpret the examples in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html , but is still missing a lot of functionality, including: - Executing DMAs asynchronously. - Padding in pallas_call. - Propagating source info.
Author
Committer
Parents
Loading