flax
e637bb39 - [pmap no rank reduce cleanup]: When flipping the

Commit
1 year ago
[pmap no rank reduce cleanup]: When flipping the jax_pmap_no_rank_reduction flag imagenet/train_test.py fails because it ooms while checkpointing. Just change the test to keep all parameters replicated always. PiperOrigin-RevId: 662659895
Author
Committer
Parents
Loading