jax
12975bbc - [pmap] Add support for nested pmaps on multihost platforms via axis_size (#2002)

Commit
5 years ago
[pmap] Add support for nested pmaps on multihost platforms via axis_size (#2002) One issue with nested pmaps on multihost platforms is inferring the global pmap axis size without communication. This commit sidesteps the issue by adding an `axis_size` argument to manually provide this information. This change only enables a single cross-host pmap; all inner pmaps must be single-host. Addressing: #1753
Author
Committer
Parents
Loading