jax
332a9ba1 - Fix axis_index inside nested pmaps

Commit
5 years ago
Fix axis_index inside nested pmaps The previous translation rule has assumed that `axis_index` is always taken over the outermost axis in the `axis_env`, and was always producing the same output, no matter which axis has been specified. This fixes the translation rule to start taking the `axis_name` into account. Additionally, this adds support for querying the index along multiple axes, which will be useful for `gmap`.
Author
Committer
Parents
Loading