jax
6ffde806 - Implement pmap of sharded_jit (#3144)

Commit
5 years ago
Implement pmap of sharded_jit (#3144) * Implement pmap of sharded_jit * Update jax/interpreters/pxla.py Co-authored-by: James Bradbury <jekbradbury@google.com> * Address comments Co-authored-by: James Bradbury <jekbradbury@google.com>
Author
Loading