6341b5c0 - Shard datasets for distributed training. (#997)

Added torch.utils.data.distributed.DistributedSampler usage to our MNIST, CIFAR, and IMAGENET examples, and tested that training accuracy improves with a larger LR and that each shard sees a different loss (data parallelism).

Verified no performance regressions on v3-8/n1-standard-64:

* Baseline (no sharding): https://gist.github.com/jysohn23/8077aab85474ef3e7bfcb54851280d8c
* Sharding (single-shard): https://gist.github.com/jysohn23/e0900d888534598f06b2eec3da598d14

Per-dataset sharded runs (may show lower accuracy due to the larger global batch size; all use real data):

* MNIST: https://gist.github.com/jysohn23/0ba2d58c2da553828c29461e89f8d1da
* CIFAR: https://gist.github.com/jysohn23/0b6b57e50007e57189b35032ed654a4f
* IMAGENET (ResNet50): https://gist.github.com/fb5a50abea01b273fc55599ce6994559
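For readers unfamiliar with the change, here is a minimal sketch of the sampler wiring. It is not the committed example code: the make_loader helper and the dataset path are illustrative, and it assumes the torch_xla xm API of this era (xm.xrt_world_size() / xm.get_ordinal()).

```python
# Hypothetical sketch, not the committed example code: shard MNIST across
# XLA replicas with DistributedSampler.
import torch_xla.core.xla_model as xm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

def make_loader(batch_size=128):  # helper name is illustrative
    dataset = datasets.MNIST(
        '/tmp/mnist', train=True, download=True,
        transform=transforms.ToTensor())
    sampler = None
    if xm.xrt_world_size() > 1:
        # Each replica iterates over a disjoint shard of the dataset
        # (data parallelism); DistributedSampler shuffles by default.
        sampler = DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal())
    # When a sampler is supplied, the DataLoader's own shuffle must stay off.
    return DataLoader(
        dataset, batch_size=batch_size, sampler=sampler,
        shuffle=(sampler is None))
```

In a real training loop you would also call sampler.set_epoch(epoch) at the start of each epoch so the per-shard shuffling differs between epochs.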