change with_callable_args to return a fresh _PartialWrapper (#63374)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63326
Currently `get_callable_args` has the side effect of mutating the input `_PartialWrapper`. When that input is one of the global defaults, this causes all sorts of lifetime issues. (Details in the linked issue.) As far as I can tell, all we need is a constructor that is module- (and, by extension, device-) aware, so returning a fresh `_PartialWrapper` should have the same effect without leaking the previous call's module.
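
For context, here is a minimal sketch of the mutate-vs-copy difference. This is not the actual `torch/quantization/observer.py` source; the class and method names mirror the real ones, but the bodies are simplified illustrations.

```python
from functools import partial


class _PartialWrapper:
    def __init__(self, p):
        self.p = p
        # extra kwargs whose values are callables, resolved when the wrapper is called
        self.callable_args = {}

    def __call__(self, *args, **kwargs):
        # resolve each callable arg lazily, letting explicitly passed kwargs win
        for name, fn in self.callable_args.items():
            if name not in kwargs:
                kwargs[name] = fn()
        return self.p(*args, **kwargs)

    # Before the fix (problematic): mutates self, so a shared global default
    # wrapper keeps a reference to whatever module/device the last call supplied.
    def with_callable_args_mutating(self, **kwargs):
        self.callable_args = {**self.callable_args, **kwargs}
        return self

    # After the fix: build a fresh wrapper, so the shared default is never touched.
    def with_callable_args(self, **kwargs):
        result = _PartialWrapper(p=self.p)
        result.callable_args = {**self.callable_args, **kwargs}
        return result


# Usage sketch: the module-level default stays pristine; only the returned copy
# carries the per-call callable args.
default_observer = _PartialWrapper(partial(dict))
specialized = default_observer.with_callable_args(device=lambda: "cuda:0")
assert default_observer.callable_args == {}  # global default untouched
assert specialized(dtype="qint8") == {"device": "cuda:0", "dtype": "qint8"}
```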
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63374
Test Plan: The repro in https://github.com/pytorch/pytorch/issues/63326 now reports no leaked Tensors, and all quantization tests pass locally.
Reviewed By: HDCharles
Differential Revision: D30359360
Pulled By: robieta
fbshipit-source-id: aef33261ac49952d8d90da868a57ab063dfc456e