Respect dist autograd context in torch.jit._fork. (#34360)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34360
The distributed autograd context sets up a thread-local context id,
which is used to perform the appropriate bookkeeping and autograd recording for
RPC calls in the forward pass.
However, if we use torch.jit._fork within the distributed autograd context, the
code executed by torch.jit._fork loses this context, since it runs on a
separate JIT thread on which the thread-local is not set.
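As an example, this is roughly the pattern that previously lost the context (a minimal sketch: the single-process RPC setup, worker name, and script function are illustrative, not part of this change):

```python
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

@torch.jit.script
def fork_add(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    # The forked task runs on a separate JIT thread.
    fut = torch.jit._fork(torch.add, t1, t2)
    return torch.jit._wait(fut)

# Assumes MASTER_ADDR/MASTER_PORT are set; a single worker keeps the
# sketch self-contained.
rpc.init_rpc("worker0", rank=0, world_size=1)

with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    # Before this change, the JIT thread executing the forked task did not
    # have the thread-local context id, so work done there was not recorded
    # under this distributed autograd context.
    loss = fork_add(t1, t2).sum()
    dist_autograd.backward(context_id, [loss])

rpc.shutdown()
```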
To fix this, we pass the distributed autograd context into torch.jit._fork,
similarly to what we did in
https://github.com/pytorch/pytorch/pull/16101.
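Conceptually, the fix captures the caller's context at fork time and reinstates it on the thread that runs the forked task. The following self-contained Python analogue sketches that idea with plain threads (all names here are illustrative, not PyTorch internals; the real change lives in the C++ fork handling):

```python
import threading

_tls = threading.local()  # stands in for the dist autograd thread-local

def fork(fn, *args):
    # Capture the caller's context id before handing work to another thread.
    ctx_id = getattr(_tls, "context_id", None)
    result = {}

    def run():
        # Reinstate the captured context id on the worker thread so the
        # forked task observes the same distributed autograd context.
        _tls.context_id = ctx_id
        result["value"] = fn(*args)

    t = threading.Thread(target=run)
    t.start()
    return t, result

# Usage: the forked function sees the caller's context id.
_tls.context_id = 42
t, out = fork(lambda: _tls.context_id)
t.join()
assert out["value"] == 42
```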
ghstack-source-id: 100445465
Test Plan: waitforbuildbot
Differential Revision: D20301352
fbshipit-source-id: aa3fffe69c2b40722c66213351a4e0d77484a621