jax
25262edb - Make pmap lax.psum(1, 'i') and pxla.axis_index('i') work

Commit
6 years ago
Make pmap lax.psum(1, 'i') and pxla.axis_index('i') work The implementation mechanism is to use a bit of dynamic context to model the axis name environment at trace time, and for the environment to track how an axis name maps to an axis size and the corresponding trace (i.e. the JaxprTrace instance). With that information, we can lift special primitives that take axis_name parameters into the trace as needed without having a data dependence on the input.
Author
Committer
Parents
Loading