Updates autograd engine to respect streams set in forward (#8354)
Summary:
This PR addresses issue https://github.com/pytorch/pytorch/issues/7601.
Currently models that use streams explicitly in forward have to do a lot of extra work to make backwards respect those streams. This PR extends the (recently added) input tracing (see TypeAndShape) to record the devices and streams of inputs. The autograd engine then uses this metadata to enact the expected stream parallelism without extra work from the user.
For example, a model with forward declared like (original example courtesy of ngimel):
```
def forward(self,x):
x0 = x.clone()
torch._C._cuda_setStream(self.stream1._cdata)
y0 = self.fc1(x0)
self.event1.record(stream = torch.cuda.current_stream())
torch._C._cuda_setStream(self.stream2._cdata)
y1 = self.fc2(x)
self.event2.record(stream = torch.cuda.current_stream())
self.stream2.wait_event(self.event1)
return y0 + y1
```
currently will backward on a single stream. With this change the kernels will go on the streams they are assigned in forward and both forward and backward will (for appropriate sizes) run the fc1 and fc2 kernels simultaneously.
The crux of this change is, as mentioned, an expansion of the TypeAndShape tracing and a relatively simple change to the autograd engine to use cuda events for stream synchronization. To make this efficient I also added a new AutoGPUAndStream class, exposed getting and setting streams on devices, and removed InputBuffer's AutoGPU (it's now redundant). While making these modifications I also fixed AutoGPU to check before setting the GPU when it's destroyed and to use THCudaCheck instead of its custom error handler. These changes mean that an often excessive cudaSetDevice() is not being called when inputs are added to a buffer.
In addition to allowing users to easily set and use streams that are respected in both forward and backward, this change may encourage modules to do the same and the expanded tracing might allow further optimizations in the autograd engine. (apaszke, for example, now after initial enumeration we know the number of devices that will be used by a graph task, which might help provide a sense of the "level of parallelism" we should expect.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8354
Test Plan: Two tests were added specifically for this behavior.
Differential Revision: D17275980
Pulled By: mruberry
fbshipit-source-id: 92bd50ac782ffa973b159fcbbadb7a083802e45d