Disable functorch modes in testing's freeze_rng_state() (#81006)
freeze_rng_state() is a helper we use to test random operations in
OpInfos: it ensures that every time the op is called, the RNG state is
the same.
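
For context, here is a minimal sketch of what a freeze_rng_state()-style helper does (hedged: the real helper has more detail; this only shows the save/restore shape):

```python
import contextlib
import torch

@contextlib.contextmanager
def freeze_rng_state_sketch():
    # Save the CPU (and, if available, CUDA) RNG state on entry...
    cpu_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    try:
        yield
    finally:
        # ...and restore it on exit, so every call made under this context
        # sees the same random stream.
        if cuda_state is not None:
            torch.cuda.set_rng_state(cuda_state)
        torch.set_rng_state(cpu_state)
```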
Unfortunately this doesn't work with functorch, because:
- torch.cuda.set_rng_state() clones a Tensor and then grabs its data_ptr, and
- functorch's modes cause functorch wrappers to be emitted on the
.clone() call (even if the thing being cloned is a regular Tensor).
Tensor subclasses also had this problem. This PR applies the same
solution we used for torch_dispatch before: we simply disable
functorch dispatch while setting the RNG state.
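
Roughly, the fix wraps the RNG-state set (and get) calls in a guard that keeps functorch from interposing. A minimal sketch, assuming the torch._C._DisableFuncTorch guard exposed by current PyTorch (the exact guard used in the PR is an assumption here):

```python
import torch

def set_cuda_rng_state_without_functorch(state):
    # Keep functorch from interposing on the internal .clone() / data_ptr()
    # that torch.cuda.set_rng_state() performs. Guard name is assumed:
    # torch._C._DisableFuncTorch acts as a context manager for this purpose.
    with torch._C._DisableFuncTorch():
        torch.cuda.set_rng_state(state)
```

The same guard would wrap the corresponding get_rng_state() calls on entry to freeze_rng_state().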
In the long run, torch_dispatch should probably have an option to
interpose on torch.cuda.set_rng_state or generator.set_state... but I
didn't want to think very hard about that right now.
Test Plan:
- Tested with the functorch tests (those tests were previously being
skipped; now some of them can be unskipped).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81006
Approved by: https://github.com/samdow