[CUDA graphs] Avoid sync errors when graph capturing cudnn rnn calls that use cudnn dropout (#56433)
Summary:
Cudnn rnn calls that use cudnn dropout maintain a "state" buffer across calls. [DropoutState](https://github.com/pytorch/pytorch/blob/fe3f6f2da2cb2ddde1a277cd5e99f898933a3c5d/aten/src/ATen/native/cudnn/RNN.cpp#L1388-L1402)'s lock() and unlock() ensure the current call's use of the state buffer syncs with the end of the previous call's use of the state buffer (in case the previous call was on a different stream).
Telling a capturing stream to wait on an event recorded in a non-capturing stream is an error (1). Telling a non-capturing stream to wait on an event recorded during capture is also an error (2). So DropoutState's flow can error in either of two simple use cases:
```python
rnn = nn.LSTM(512, 512, 2, dropout=0.5).cuda()
out1 = rnn(in1)
# calling cudnn rnn with dropout in capture after calling it uncaptured triggers 1
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
    graph.capture_begin()
    out2 = rnn(in2)
    graph.capture_end()
torch.cuda.current_stream().wait_stream(capture_stream)
# calling cudnn rnn with dropout uncaptured after calling it in capture triggers 2
out3 = rnn(in3)
```
This PR fixes both cases by telling `DropoutState::lock()`: "if the most recent end-of-usage event was recorded in a different capture state (i.e., we crossed a capturing<->noncapturing border) or in a different capture, don't sync on it." While considering the fix I had two assumptions in mind:
- only one capture using the RNN can be underway at a time in this process
- no noncapturing ops in this process are issuing RNN calls while the capture using the RNN is underway.
That second assumption seems brittle if, for example, someone wants to capture an internal region of the forward method of a model wrapped with DataParallel: multiple threads could be issuing RNN calls with some currently capturing and some not. We should talk about whether that use case seems realistic.
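For illustration, the rule given to `DropoutState::lock()` can be modeled in a few lines of pure Python. This is a minimal sketch of the decision only, not the C++ in RNN.cpp: `CaptureId`, `should_sync_on_last_use`, and `ToyDropoutState` are hypothetical names, and representing "not capturing" as `None` is an assumption of the sketch. It uses no real streams, events, or mutexes.

```python
from typing import Optional

# Hypothetical stand-in for a capture id: None means "not capturing",
# an int identifies one particular capture.
CaptureId = Optional[int]


def should_sync_on_last_use(last_use: CaptureId, current: CaptureId) -> bool:
    """Return True if the current call may wait on the previous end-of-usage event."""
    # Both uncaptured, or both inside the same capture: waiting is legal.
    # Anything else crosses a capture border or mixes two captures: skip the wait.
    return last_use == current


class ToyDropoutState:
    """Toy model that tracks only capture ids, with no real events or streams."""

    def __init__(self) -> None:
        self.has_event = False
        self.last_use: CaptureId = None

    def lock(self, current: CaptureId) -> bool:
        # Report whether this call would wait on the previous end-of-usage event.
        return self.has_event and should_sync_on_last_use(self.last_use, current)

    def unlock(self, current: CaptureId) -> None:
        # Remember the capture state in which the end-of-usage event was recorded.
        self.has_event = True
        self.last_use = current
```

In this model, an uncaptured call followed by a captured one (case 1) and a captured call followed by an uncaptured one (case 2) both skip the wait, while back-to-back calls inside the same capture, or back-to-back uncaptured calls, still sync. The skipped waits are only safe under the two assumptions listed above.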
(Bigger-picture thoughts: I don't know if forcing calls to serialize on using the shared state buffer is the best design. And if we want to do it that way, we might as well run all cudnn rnns with dropout on a dedicated side stream synced with the surrounding stream (capturing or not), in which case I don't think this PR's event-handling diffs would be needed.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56433
Reviewed By: heitorschueroff
Differential Revision: D27966444
Pulled By: ezyang
fbshipit-source-id: fe0df843c521e0d48d7f2c81a17aff84c5497e20