Fix final callbacks for reentrant backwards (#35066)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35066
Closes #24965
Prior to this commit, final_callbacks_ are cleared on exit of ANY
backward. When using reentrant backward, the last backward would
remove all callbacks from the engine. However, this might lead to
unexpected behavior. For example, the application could install
a final callback after forward, and expecting this callback to fire
when all gradients are ready. If there is a renentrant backward on
a subgraph, it would fire the callback and delete it on exit,
meaning that when fired, not all gradients are ready.
**Failed Attempt**
The 1st attempt was trying to move the callback to the GraphTask
in engine::execute(). However, this failed because more callbacks
could be installed during backward pass.
**Current Solution**
Final callbacks are stored as a member variable in the GraphTask.
* Insertion: use the thread_local current_graph_task to find the
target GraphTask, and append final callback.
* Deletion: final callbacks have the same lifetime as a GraphTask
* Execution: Use the GraphTask provided in the argument to find
final callbacks.
Test Plan: Imported from OSS
Differential Revision: D20546474
Pulled By: mrshenli
fbshipit-source-id: d3f3449bb5af9f8703bcae63e6b52056cd535f11