Fix Device Idx Setting (#97399)
We weren't always setting the device indices, which led to a StopIteration Exception on next(iter(device_idxs))
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97399
Approved by: https://github.com/yanboliang, https://github.com/ngimel