Restore current streams on dst device after switching streams (#17439)
Summary:
When switching back to `d0` from a stream on a different device `d1`, we need to restore the current streams on both `d0` and `d1`. The current implementation only does that for `d0`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17439
Differential Revision: D14208919
Pulled By: mrshenli
fbshipit-source-id: 89f2565b9977206256efbec42adbd789329ccad8