Allow consumer ops to sync on GraphRoot's gradient (#45787)
Summary:
Currently, a GraphRoot instance doesn't have an associated stream. Streaming backward synchronization logic assumes the instance ran on the default stream, and tells consumer ops to sync with the default stream. If the gradient the GraphRoot instance passes to consumer backward ops was populated on a non-default stream, we have a race condition.
The race condition can exist even if the user doesn't give a manually populated gradient:
```python
with torch.cuda.stream(side_stream):
# loss.backward() implicitly synthesizes a one-element 1.0 tensor on side_stream
# GraphRoot passes it to consumers, but consumers first sync on default stream, not side_stream.
loss.backward()
# Internally to backward(), streaming-backward logic takes over, stuff executes on the same stream it ran on in forward,
# and the side_stream context is irrelevant. GraphRoot's interaction with its first consumer(s) is the spot where
# the side_stream context causes a problem.
```
This PR fixes the race condition by associating a GraphRoot instance, at construction time, with the current stream(s) on the device(s) of the grads it will pass to consumers. (i think this relies on GraphRoot executing in the main thread, before backward thread(s) fork, because the grads were populated on the main thread.)
The test demonstrates the race condition. It fails reliably without the PR's GraphRoot diffs and passes with the GraphRoot diffs.
With the GraphRoot diffs, manually populating an incoming-gradient arg for `backward` (or `torch.autograd.grad`) and the actual call to `autograd.backward` will have the same stream-semantics relationship as any other pair of ops:
```python
# implicit population is safe
with torch.cuda.stream(side_stream):
loss.backward()
# explicit population in side stream then backward in side stream is safe
with torch.cuda.stream(side_stream):
kickoff_grad = torch.ones_like(loss)
loss.backward(gradient=kickoff_grad)
# explicit population in one stream then backward kickoff in another stream
# is NOT safe, even with this PR's diffs, but that unsafety is consistent with
# stream-semantics relationship of any pair of ops
kickoff_grad = torch.ones_like(loss)
with torch.cuda.stream(side_stream):
loss.backward(gradient=kickoff_grad)
# Safe, as you'd expect for any pair of ops
kickoff_grad = torch.ones_like(loss)
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
loss.backward(gradient=kickoff_grad)
```
This PR also adds the last three examples above to cuda docs and references them from autograd docstrings.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45787
Reviewed By: nairbv
Differential Revision: D24138376
Pulled By: albanD
fbshipit-source-id: bc4cd9390f9f0358633db530b1b09f9c1080d2a3