jax
061f4df8 - Make `device_put` work with inputs which are host local and the sharding is global sharding i.e. sharding spanning across multiple hosts.

Commit
1 year ago
Make `device_put` work with inputs which are host local and the sharding is global sharding i.e. sharding spanning across multiple hosts. Use `multihost_utils.assert_equal` to check if the input is the same across all hosts. Do some formatting fixes too ;) PiperOrigin-RevId: 647711853
Author
Committer
Parents
Loading