pytorch
e4bc785d - randperm: add torch check to ensure generator device = tensor device (#47022)

Commit
4 years ago
randperm: add torch check to ensure generator device = tensor device (#47022) Summary: **BC-breaking Note:** This PR disallows passing in a generator of a different device than the tensor being created during `randperm` execution. For example, the following code which used to work no longer works. ``` > torch.randperm(3, device='cuda', generator=torch.Generator(device='cpu')) tensor([0, 1, 2], device='cuda:0') ``` It now errors: ``` > torch.randperm(3, device='cuda', generator=torch.Generator(device='cpu')) RuntimeError: Expected a 'cuda:0' generator device but found 'cpu' ``` **PR Summary:** Fixes https://github.com/pytorch/pytorch/issues/44714 Also added + ran tests to ensure this functionality. Disclaimer: More work needs to be done with regards to small cuda tensors when a generator is specified, look at the issue thread for more details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47022 Reviewed By: samestep Differential Revision: D24608237 Pulled By: janeyx99 fbshipit-source-id: b83c47219c7816d93f938f7ce86dc8857513961b
Author
Parents
Loading