jax
9ab07d85 - support axis_index_groups in psum(const) (#4070)

Commit
5 years ago
support axis_index_groups in psum(const) (#4070) * support axis_index_groups in psum(const) * add test for psum(constant, axis_index_groups) * rm trailing whitespace * Update lax_parallel.py
Author
Parents
Loading