jax
ae835b74 - Add jax.devices() and friends, and add `devices` arg to pmap.

Commit
6 years ago
Add jax.devices() and friends, and add `devices` arg to pmap. This change adds the following APIs: * jax.devices(). This returns a list of available Device subclass instances. * jax.host_id(). Currently always 0, but will be useful on multi-host platforms. * jax.local_device_count(). Currently always equal to jax.device_count(), but will be useful on multi-host platforms. * Optional `devices` argument to pmap. This can be used to specify which devices should be used in the replicated computation.
Author
Committer
Parents
Loading