Set streams when invoking UDFs (#58427)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58427
Running the UDF (be it Python or JIT) is the first step of (most?) RPC calls, which is where the inputs are consumed. The lazy stream context contains the streams used by the inputs, thus it must be made current before any UDF call. I opt to do this as "close" as possible to the place the UDF is invoked, to make the relationship as explicit as possible.
ghstack-source-id: 129567052
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28474983
fbshipit-source-id: 358292764d0a6832081c34bf6736f0961475ff3d