[reland] Set streams when invoking UDFs (#59210)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59210
Reland of https://github.com/pytorch/pytorch/pull/58427
Running the UDF (be it Python or JIT) is the first step of (most?) RPC calls, which is where the inputs are consumed. The lazy stream context contains the streams used by the inputs, thus it must be made current before any UDF call. I opt to do this as "close" as possible to the place the UDF is invoked, to make the relationship as explicit as possible.
ghstack-source-id: 130202847
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D28623889
fbshipit-source-id: ed38242f813dac075d162685d52ae89f408932f9