pytorch
4c8183f5 - Update on "[WIP] enable static graph training in DDP"

Commit
3 years ago
Update on "[WIP] enable static graph training in DDP" This PR provides enable static graph training when users call _set_static_graph(). This can help support more use cases in DDP without performance regression, also can potentially improve performance when there are unused parameters in the graph. 1. first iteration records graph states like how many times a grad is calculated, whether the grad is used or not. then first iteration queues a delay_all_reduce call back to all reduce grads. 2. Since autograd call back is associated with current target graph task, the delay_all_all call back should be associated with out-most backward graph task. A DDP sink layer is added in DDP forward loop so that we can queue the delay_all_reduce call back in the sink layer. 3. after first iterations, DDP will use the saved graph states to determine whether a grad is used or not. whether a grad is ready for communication. 4. rebuilt bucket is called in second iteration, after graph states are recorded in first iteration. 5. if the graph states change, DDP will throw errors Differential Revision: [D27539964](https://our.internmc.facebook.com/intern/diff/D27539964/) [ghstack-poisoned]
Author
Loading