Make CUDAFuture remember and restore current device in callback (#48789)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48789
CUDAFuture aims to "capture" the current state of CUDA-related stuff when the future is marked complete (e.g., by looking at current streams and recording events on them) and then "replicate" a similar state when users synchronize with the result of the future (by synchronizing the current streams with these events).
However, one "contextual" aspect of CUDA that we weren't capturing/replicating was the current device. This diff tries to fix that. I must mention that we can only do this for callbacks, while we cannot do it for the wait() method. I don't know if such a discrepancy between the two actually makes the overall behavior _worse_. I'd love to hear people's opinions on this.
ghstack-source-id: 118081338
Test Plan: Unit tests
Reviewed By: mrshenli
Differential Revision: D25210335
fbshipit-source-id: 1d1a3f80b1cc42e5114bc88554ed50617f1aaa90