generate device context managers in inductor code (#90934)
Fixes https://github.com/pytorch/torchdynamo/issues/1717, https://github.com/pytorch/torchdynamo/issues/1990
<s>TODO: add test with multiple devices, figure out extra context initialization</s>
Problems:
<s>It still initializes context on 0-th device that it shouldn't, I'll take a look where that happens and fix before landing</s>
It adds a python device context manages, that is absurdly slow and takes ~2.5 us (should be nanoseconds). That's not a problem for real models, because it'll be called just once, but it is a bit of an inconvenience for microbenchmarking, we should make that context manager more performant (won't fix in this PR)
It still can have bugs for graphs that run on multiple devices and can have buffers incorrectly shared between multiple device by memory reuse, if that happens that'll need to be solved separately.
Generated code:
```
def call(args):
arg0_1, arg1_1 = args
args.clear()
with torch.cuda.device(1):
buf0 = empty_strided((4, ), (1, ), device='cuda', dtype=torch.float32)
stream1 = get_cuda_stream(1)
triton_fused_div_0.run(arg0_1, arg1_1, buf0, 4, grid=grid(4), stream=stream1)
del arg0_1
del arg1_1
return (buf0, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90934
Approved by: https://github.com/wconstab