[DDP] Add a correctness test case (#3980)
Summary:
This pull request adds a correctness test case that compare the parameters
after each training step between a DDP + TPU setup and a single CPU setup.
If the inputs and labels are sharded correctly, the stepped parameters
should match each.
Test Plan:
PJRT_DEVICE=TPU python test/pjrt/test_ddp.py TestPjRtDistributedDataParallel.test_ddp_correctness