jax
8f4ba7e6 - Allow specifying both `devices` and `axis_size` to pmap. (#3475)

Commit
5 years ago
Allow specifying both `devices` and `axis_size` to pmap. (#3475) This allows providing custom device assignments to nested pmaps or pmap-of-sharded_jit when running on a multi-host platform.
Author
Parents
Loading