Fix TensorPipeAgent shutdown to ensure it drains all outstanding work. (#40060)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40060
As part of debugging https://github.com/pytorch/pytorch/issues/39855,
I noticed that TensorPipeAgent's ThreadPool was still executing tasks while the
Python interpreter was shutting down. This caused issues with
pybind11::gil_scoped_acquire, which can't be called while the interpreter is
shutting down, and resulted in a crash.
The reason for this was that TensorPipeAgent was calling waitWorkComplete and
then shutting down the listeners. This meant that after waitWorkComplete
returned, there could still be a race where an RPC call gets enqueued before we
shut down the listeners.
To avoid this situation, I've moved the call to waitWorkComplete to the end of
shutdown (similar to ProcessGroupAgent).
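
The ordering can be illustrated with a minimal Python sketch. This is not the
TensorPipeAgent code: ToyAgent, on_incoming_rpc, stop_listeners and
wait_work_complete are hypothetical names chosen for illustration, and a
standard-library thread pool stands in for the agent's ThreadPool. The point it
shows is only that once the listeners are stopped first, nothing new can be
enqueued after the drain returns.

```python
import threading
from concurrent.futures import ThreadPoolExecutor, wait


class ToyAgent:
    """Hypothetical stand-in for an RPC agent (not the real TensorPipeAgent)."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=2)
        self._lock = threading.Lock()
        self._listening = True
        self._futures = []
        self.completed = 0

    def on_incoming_rpc(self):
        """Simulates a listener callback; enqueues work only while listening."""
        with self._lock:
            if not self._listening:
                return False  # request rejected, nothing reaches the pool
            self._futures.append(self._pool.submit(self._handle))
            return True

    def _handle(self):
        # Placeholder for the real request handling.
        with self._lock:
            self.completed += 1

    def stop_listeners(self):
        with self._lock:
            self._listening = False

    def wait_work_complete(self):
        # Snapshot the outstanding futures, then block until all are done.
        with self._lock:
            pending = list(self._futures)
        wait(pending)

    def shutdown(self):
        # Stop accepting new work first, then drain as the last step.
        # Draining first would leave a window in which a request could
        # still be enqueued before the listeners stop.
        self.stop_listeners()
        self.wait_work_complete()
        self._pool.shutdown(wait=True)


agent = ToyAgent()
accepted = [agent.on_incoming_rpc() for _ in range(5)]
agent.shutdown()
late = agent.on_incoming_rpc()  # arrives after shutdown, so it is rejected
```

With this order, every request accepted before shutdown has finished by the time
shutdown returns, and a request that arrives afterwards is dropped instead of
running on the pool.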
Closes: https://github.com/pytorch/pytorch/issues/39855
ghstack-source-id: 105926653
Test Plan:
1) Ran test_backward_node_failure
(__main__.TensorPipeAgentDistAutogradTestWithSpawn) 100 times to verify the
fix.
2) waitforbuildbot
Differential Revision: D22055708
fbshipit-source-id: 2cbe388e654b511d85ad416e696f3671bd369372