Support no sharding config
This PR supports a no-sharding config (NO_SHARD) so that FSDP can behave like the DDP algorithm.
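As a minimal sketch (not code from this PR), the DDP-like behavior is selected through FSDP's sharding_strategy argument; this assumes the process group has already been initialized and uses a toy model as a stand-in:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# Assumes torch.distributed.init_process_group() has already been called
# and each rank has set its CUDA device.
model = torch.nn.Linear(1024, 1024).cuda()

# NO_SHARD keeps full parameters and gradients on every rank and uses a
# single all-reduce per wrapped unit instead of reduce-scatter, which is
# the DDP-like behavior described in this PR.
ddp_like_model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)
```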
Benchmarked this version of DDP against PyTorch C++ DDP: the perf gap is small and depends on the wrapping strategy. The results below are for a 1GB BERT model on 32 GPUs connected by a slow network (see the configuration sketch after these results).
If the bucket size is small (e.g., 25MB for PT DDP) and the wrapping min size is around 28MB for this version of DDP, this version of DDP has around a 6% perf regression. The difference is because this version of DDP's wrapping needs one more all-reduce, not because of the Python context switch.
If the bucket size is large (e.g., 40MB for PT DDP) and the wrapping min size is around 40MB for this version of DDP, this version of DDP has around a 17% perf regression, because the last FSDP unit in this wrapping has a large delay before it kicks off the first all-reduce. The difference is not due to the Python context switch.
If the bucket size is larger than the model size, both PT DDP and this version of DDP issue a single all-reduce and have similar performance, which again shows that the Python context switch is not a big concern.
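For context, a rough sketch of the two configurations being compared; build_bert_model is a hypothetical helper for the ~1GB model, and min_num_params is only an approximate translation of the wrapping min size into an fp32 parameter count:

```python
import functools

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Assumes one process per GPU with the process group and CUDA device set up.

# PT C++ DDP: gradients are coalesced into ~25MB buckets for all-reduce.
pt_ddp_model = DDP(build_bert_model().cuda(), bucket_cap_mb=25)

# This version of DDP: each auto-wrapped FSDP unit issues its own all-reduce,
# so the wrapping min size plays the role of the DDP bucket size.
min_wrap_bytes = 28 * 1024 * 1024
no_shard_model = FSDP(
    build_bert_model().cuda(),
    sharding_strategy=ShardingStrategy.NO_SHARD,
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=min_wrap_bytes // 4,  # ~28MB of fp32 parameters
    ),
)
```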
Overall, if the wrapping in this version of DDP is done well and aligned with PT DDP's bucketing order, the performance should be comparable. As the results show, we still need to improve our auto wrapping policy for both the NO_SHARD and sharded FSDP strategies.
Once the auto wrapping issues are resolved for this API, it is promising to make it backward compatible with PT C++ DDP and merge the two in the long run.
This API also provides the flexibility to mix DDP with other data parallelism strategies within a module, e.g. FSDP(FSDP(submodule1, sharding_strategy=FULL_SHARD), FSDP(submodule2, sharding_strategy=NO_SHARD), FSDP(submodule3, sharding_strategy=HSDP), submodule4, sharding_strategy=NO_SHARD).
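A minimal manual-wrapping sketch of that kind of mixing (module names are placeholders, the HSDP case is left out, and this is illustrative rather than code from this PR):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.submodule1 = nn.Linear(1024, 1024)
        self.submodule2 = nn.Linear(1024, 1024)
        self.submodule3 = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.submodule3(self.submodule2(self.submodule1(x)))


# Assumes the process group and CUDA device are already set up.
model = Model().cuda()

# Inner FSDP units choose their own sharding strategy.
model.submodule1 = FSDP(model.submodule1, sharding_strategy=ShardingStrategy.FULL_SHARD)
model.submodule2 = FSDP(model.submodule2, sharding_strategy=ShardingStrategy.NO_SHARD)

# The outer wrapper handles the remaining parameters (submodule3) with the
# DDP-like NO_SHARD strategy.
model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)
```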
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76736
Approved by: https://github.com/rohan-varma