jax
ef1045bd - Add support for nested xmaps

Commit
5 years ago
Add support for nested xmaps With the caveat that the outer xmap cannot use the `vectorize` resource mapping, because the batching rule for xmap is not implemented yet (this is going to be a follow up). But as long as the outer xmap only uses real devices the nesting should work fine. Note that we don't try to hard to handle perfectly nested `xmap`s efficiently. But this shouldn't be an issue, because it is usually trivial to flatten them on the user side. Very similar logic can now be used to implement a `sharded_jit` version that composes well with `xmap` (although it is unclear if it's an API that we want to have in the long run).
Author
Committer
Parents
Loading