jax
2494e0c3 - Add XLA lowering for xmap

Commit
5 years ago
Add XLA lowering for xmap This should allow us to try out xmap not only in a simulation (i.e. faking the devices using vmap, which we still support), but also on real hardware. Limitations: - No compilation caching yet - Nested xmaps not supported yet - Transforms (AD, vmap, etc.) of xmaps not supported yet Benefits: - An xmap over multiple mesh axes already implements a more efficient lowering than the one used for nested pmaps. The `resources` context-manager is now called `fake_resources`, while real meshes can be defined in a specific context using the `mesh(devices, axis_names)` manager. `devices` is supposed to be an `ndarray` of JAX device objects (e.g. obtained from `jax.devices()`), while `axis_names` should be a tuple of length matching the rank of `devices` and specifying mesh axis names. For concrete examples see the changes in `gmap_tests.py`. In principle the current version of the code should also work in a multi-host setting, but I haven't tested it just yet.
Author
Committer
Parents
Loading