pytorch
26582056 - Disable functorch modes in testing's freeze_rng_state() (#81006)

Commit
2 years ago
Disable functorch modes in testing's freeze_rng_state() (#81006) freeze_rng_state() is this thing we use to test random operations in OpInfos: it ensures that everytime the op is called the rng state is the same. Unfortunately this doesn't work with functorch, because - torch.cuda.set_rng_state() clones a Tensor and then grabs its data_ptr - functorch's modes cause functorch wrappers to get emitted on the .clone() call (even if the thing being cloned a regular Tensor). Tensor subclasses also had this problem. This PR applies the same solution as torch_dispatch did before: we're just going to disable functorch dispatch when setting the rng state. In the long run, torch_dispatch should probably have an option to interpose on torch.cuda.set_rng_state or generator.set_state... but I didn't want to think very hard right now. Test Plan: - tested with functorch tests (those tests were previously being skipped, now I can unskip some of them). Pull Request resolved: https://github.com/pytorch/pytorch/pull/81006 Approved by: https://github.com/samdow
Author
Committer
Parents
Loading