xla
16fa65f4 - [DDP] Add a correctness test case (#3980)

Commit
3 years ago
[DDP] Add a correctness test case (#3980) Summary: This pull request adds a correctness test case that compare the parameters after each training step between a DDP + TPU setup and a single CPU setup. If the inputs and labels are sharded correctly, the stepped parameters should match each. Test Plan: PJRT_DEVICE=TPU python test/pjrt/test_ddp.py TestPjRtDistributedDataParallel.test_ddp_correctness
Author
Parents
Loading