Add XLA lowering for xmap
This should allow us to try out xmap not only in a simulation (i.e.
faking the devices using vmap, which we still support), but also on real
hardware.
Limitations:
- No compilation caching yet
- Nested xmaps not supported yet
- Transforms (AD, vmap, etc.) of xmaps not supported yet
Benefits:
- An xmap over multiple mesh axes already implements a more efficient
lowering than the one used for nested pmaps.
The `resources` context-manager is now called `fake_resources`, while
real meshes can be defined in a specific context using the
`mesh(devices, axis_names)` manager. `devices` is supposed to be an
`ndarray` of JAX device objects (e.g. obtained from `jax.devices()`),
while `axis_names` should be a tuple of length matching the rank of
`devices` and specifying mesh axis names.
For concrete examples see the changes in `gmap_tests.py`.
In principle the current version of the code should also work in a
multi-host setting, but I haven't tested it just yet.