Avoid DDP race condition with find_unused_parameters=True when all params are used (#53160)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/53159.
See comments for a description of the race condition. Thanks to ptrblck, xwang233, and especially zasdfgbnm for lots of help isolating the problem and discussing the fix.
Opening this PR for discussion. We can try to concoct a dedicated test for the problem if you want. The ingredients are:
- DDP(..., find_unused_parameters=True)
- Use all the DDP-ed model's params in forward such that the "lazy local used work wait()" path will be taken in backward
- Queue up a lot of asynchronous dummy work just before backward(), so stream work gets pushed far into the future relative to CPU work
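A minimal sketch of a repro along those lines (not the actual test from this PR; it assumes a single-process "gloo" group for illustration, and since the race itself is CUDA-stream specific, the dummy-work ingredient is guarded behind a CUDA check):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def run_repro(iters=3):
    # Single-process process group purely for illustration.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    # Ingredient 1: find_unused_parameters=True.
    model = DDP(torch.nn.Linear(8, 8), find_unused_parameters=True)
    losses = []
    for _ in range(iters):
        inp = torch.randn(4, 8)
        # Ingredient 2: every parameter participates in forward, so
        # backward takes the lazy local-used wait() path.
        out = model(inp)
        # Ingredient 3: queue lots of async dummy device work right
        # before backward(), so stream work lags far behind the CPU.
        # (CUDA only; on CPU this loop is skipped.)
        if torch.cuda.is_available():
            a = torch.randn(1024, 1024, device="cuda")
            for _ in range(50):
                a = a @ a  # async kernels pile up on the current stream
        loss = out.sum()
        loss.backward()
        losses.append(loss.item())
    dist.destroy_process_group()
    return losses
```

On a CPU-only machine this still exercises the find_unused_parameters=True path with all parameters used; reproducing the actual race requires CUDA so the queued kernels push stream work into the future relative to the CPU.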
Benchmark:
- BERT model, find_unused_parameters=True, P50 latency per iteration (sec): trunk 1.265, this PR 1.263; with a blocking copy added before the local_used_.fill(i) call, 1.236
- BERT model, find_unused_parameters=False, P50 latency per iteration (sec): trunk 1.000, this PR 1.026
- ResNet50 model: accuracy matches trunk with both find_unused_parameters=True and find_unused_parameters=False
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53160
Reviewed By: albanD
Differential Revision: D26916766
Pulled By: zhaojuanmao
fbshipit-source-id: 3e0ed91b7b5c42e2f2c82e12d4d2940fdc89e023