[PyTorch][Dist] Trigger pre/post hooks of output function nodes under distributed autograd (#34501)
Summary:
# Goals
Do the following things during a distributed backward pass.
1. Accumulate the gradient of a variable into the RPC context as soon as the gradient is ready, instead of at the very end of the backward pass.
2. Run the pre/post hooks installed on `AccumulateGrad` nodes once the gradient is ready for the variable. Currently, the hooks on `AccumulateGrad` are not executed, simply because the local engine never evaluates the `AccumulateGrad` function itself.
3. Make it extensible to support post hooks installed by DDP's reducer.
# Introduce GradCapturePreHook
## Why do we need this?
### Root issue:
* The dist engine drives the vanilla engine through its autograd.grad-like API and then populates the context with the gradients in a Future callback. This is a poor emulation of a `.backward()` call on the vanilla engine.
### Practical issue:
* The leaf's hooks are not called (because they are attached to `AccumulateGrad` nodes, which the autograd.grad-like API never invokes). Modules like DDP rely on these hooks.
* The Future is marked as completed before the context is actually populated with the grads, leading to unexpected behavior on the user side.
* The Future callback fires only at the very end of the backward pass, which is too late for DDP if it wants to overlap computation with gradient transfer.
### Proposed solution:
* Provide hooks in the autograd.grad-like API that allow the distributed engine to populate the context and call the hooks, better emulating the `.backward()` call.
## Who can install a grad capture pre-hook?
This will be an internal hook at the C++ level and it won't be exposed to Python code. Only call-sites directly interacting with the local engine can install such hooks.
## Signature
The returned `grad` will be captured.
```
virtual torch::Tensor operator()(const torch::Tensor& grad) = 0;
```
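For illustration, here is a minimal sketch of the kind of hook the distributed engine could install to meet goals 1 and 2 above. The class name, the `DistAutogradContext::accumulateGrad()` call, and the post-hook invocation convention are assumptions based on this description, not verbatim code from the PR:
```
// Minimal sketch, assuming the GradCaptureHook interface declared above.
// The context type and its accumulateGrad() method are illustrative.
struct DistAccumulateGradCaptureHook : GradCaptureHook {
  DistAccumulateGradCaptureHook(
      std::shared_ptr<torch::autograd::AccumulateGrad> accumulate_grad,
      std::shared_ptr<DistAutogradContext> context)
      : accumulate_grad_(std::move(accumulate_grad)),
        context_(std::move(context)) {}

  torch::Tensor operator()(const torch::Tensor& grad) override {
    // Goal 1: accumulate the grad into the RPC context as soon as it is
    // ready, instead of at the very end of the backward pass.
    context_->accumulateGrad(accumulate_grad_->variable, grad);

    // Goal 2: run the post hooks installed on the AccumulateGrad node,
    // which the local engine never evaluates in a grad()-style run.
    for (const auto& post_hook : accumulate_grad_->post_hooks()) {
      (*post_hook)(/*outputs=*/{}, /*inputs=*/{grad});
    }
    return grad; // the unchanged grad is what gets captured
  }

  std::shared_ptr<torch::autograd::AccumulateGrad> accumulate_grad_;
  std::shared_ptr<DistAutogradContext> context_;
};
```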
## Where are hooks installed?
Grad capture pre-hooks are installed in `GraphTask::ExecInfo::Capture`. `ExecInfo` is per node. Every backward run has its own `GraphTask` instance, so hooks do not leak across runs.
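A rough picture of the layout, with field names assumed from this description rather than copied from the code:
```
// Sketch: inside GraphTask::ExecInfo (per node), each capture entry
// records which input grad to capture and carries its own hook list.
struct Capture {
  int input_idx;   // which of the node's input grads is captured
  int output_idx;  // slot in the captured-grads output vector
  std::vector<std::unique_ptr<GradCaptureHook>> hooks; // run in order
};

// A call-site interacting directly with the local engine installs a hook
// before kicking off the backward pass, e.g.:
//   capture.hooks.push_back(std::make_unique<DistAccumulateGradCaptureHook>(
//       accumulate_grad, context));
```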
## When/How will hooks be called?
When the local engine captures the grads for a node, all grad capture pre-hooks are called one by one in the order they were added. Each hook's output grad replaces the input to the next hook, and the output of the last hook is what is finally captured.
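In pseudocode, the capture step then looks roughly like this (reusing the illustrative names from the sketches above):
```
// Sketch of the capture step inside the local engine.
torch::Tensor run_capture_hooks(
    const Capture& capture,
    const std::vector<torch::Tensor>& input_grads) {
  torch::Tensor grad = input_grads[capture.input_idx];
  // Hooks run one by one, in insertion order; each hook sees the output
  // of the previous one, and its result replaces the grad.
  for (const auto& hook : capture.hooks) {
    grad = (*hook)(grad);
  }
  // The last hook's output is the value that is actually captured.
  return grad;
}
```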
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34501
Test Plan:
All existing tests should pass.
```
python setup.py develop
python test/distributed/rpc/test_dist_autograd_spawn.py DistAutogradTestWithSpawn.test_post_hooks
```
Differential Revision: D20953673
Pulled By: hczhu
fbshipit-source-id: 543b3844823330ea9f9856bab7c5cb2679290a53