Ensure proper syncs for out-of-place grad creation (torch.autograd.grad) when backward ops run on side streams (#60127)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/59844.
Streaming backwards collects "leaf streams" from the AccumulateGrad functions that stash or accumulate `.grad` attributes of autograd leaf tensors, and syncs those streams with the ambient stream(s) so later ops can safely consume the grads on the ambient stream(s).
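For context, here's a minimal sketch of the guarantee on the `torch.autograd.backward` side (the tensor shapes and side stream are illustrative, not taken from the PR or its test): after the call, `.grad` attributes are safe to consume on the ambient stream with no manual synchronization.
```python
import torch

# A minimal sketch, assuming a CUDA build; the side stream is illustrative.
if torch.cuda.is_available():
    side = torch.cuda.Stream()
    leaf = torch.ones(5, 5, device="cuda", requires_grad=True)
    with torch.cuda.stream(side):
        loss = (leaf * 2).sum()  # forward work recorded on a side stream
    torch.cuda.current_stream().wait_stream(side)  # ordinary stream hygiene for loss
    loss.backward()  # the engine may run backward ops on side streams
    # The engine syncs each AccumulateGrad's leaf stream with the ambient
    # stream, so reading leaf.grad here needs no manual synchronization.
    print(leaf.grad)
```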
But streaming backwards currently does not collect leaf streams for grads produced out-of-place (i.e., not stashed onto a `.grad` attribute) by `torch.autograd.grad`, because these out-of-place grads are "captured" and returned before they reach an AccumulateGrad function. Some out-of-place grads might not even have an AccumulateGrad function to go to, because `torch.autograd.grad` can be told to make grads for non-leaf temporaries.[1]
The upshot: when streaming backwards runs the ops that produce out-of-place gradients on side streams, no ambient stream is told to sync with those side streams, so `torch.autograd.grad` doesn't offer the same post-call safe-use guarantee for its grads that the leaf accumulation of `torch.autograd.backward` does.
This PR ensures `torch.autograd.grad` gives the same safe-use guarantees as `torch.autograd.backward` by also stashing leaf streams for grads created out-of-place.
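To make the contrast concrete, here's a hedged sketch of the pattern this PR makes safe (stream setup and shapes are illustrative): consuming the grads returned by `torch.autograd.grad` on the ambient stream, with no manual sync, just like `.grad` after `backward`.
```python
import torch

# A minimal sketch, assuming a CUDA build; stream setup is illustrative.
if torch.cuda.is_available():
    side = torch.cuda.Stream()
    leaf = torch.ones(5, 5, device="cuda", requires_grad=True)
    with torch.cuda.stream(side):
        tmp = leaf * 2  # non-leaf temporary: its grad never reaches AccumulateGrad
        loss = tmp.sum()
    torch.cuda.current_stream().wait_stream(side)
    grad_tmp, grad_leaf = torch.autograd.grad(loss, inputs=(tmp, leaf))
    # Before this PR, the side streams that produced these captured grads were
    # never synced with the ambient stream; with the fix, this read is safe.
    print(grad_tmp, grad_leaf)
```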
I augmented a streaming backwards test to include a `torch.autograd.grad` attempt. The test fails on current master[2] and passes with the engine.cpp diffs.
I have no idea whether this bug or its fix matters to distributed autograd. pritamdamania and mrshenli should take a look before this is merged.
[1] Example:
```python
import torch

leaf = torch.ones(5, requires_grad=True)  # autograd leaf
tmp = leaf * 2                            # non-leaf temporary
loss = tmp.sum()
# tmp's grad is "captured" and returned directly; it never reaches an AccumulateGrad node
torch.autograd.grad(loss, inputs=(tmp, leaf))
```
Technically, because `torch.autograd.grad` can be told to produce grads for non-leaf temporaries, these streams might NOT be "leaf streams". Maybe I should rename `leaf_streams`?
[2] The way the test currently fails is fun: it reports
```
AssertionError: False is not true : Tensors failed to compare as equal!With rtol=1.3e-06 and atol=1e-05, found 0 element(s) (out of 25) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.0 (5.0 vs. 5.0), which occurred at index (0, 0).
```
I suspect this [kafka trap](https://en.wiktionary.org/wiki/Kafkatrap) happens because assertEqual runs the comparison on the device, syncs when extracting the boolean result, sees failure, then prints the tensors post-sync, at which point it IS safe to access the values.
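A hedged sketch of that mechanism (illustrative; not assertEqual's actual implementation): the elementwise comparison can read stale, not-yet-synced grad values on the device, but materializing a Python bool from the result synchronizes, so the tensors print their correct, equal values afterwards.
```python
import torch

actual = torch.full((5, 5), 5.0)    # stand-ins for the compared grads
expected = torch.full((5, 5), 5.0)
eq = torch.isclose(actual, expected, rtol=1.3e-06, atol=1e-05)
ok = bool(eq.all())  # on CUDA, extracting a Python bool synchronizes the device
if not ok:
    # By this point the device has synced, so the printed tensors look equal
    # even though the on-device comparison saw a mismatch.
    print(actual, expected)
```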
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60127
Reviewed By: mrshenli
Differential Revision: D29276581
Pulled By: albanD
fbshipit-source-id: a9f797e2fd76e2f884cce5a32ecf5d9b704c88ee