Improve mem efficiency of constant folding (#108421)
A couple of changes to make constant folding more memory efficient:
- Because we are replacing nodes that only have a single value, store just that single value instead of the whole tensor for node replacement.
- torch.fx.Interpreter will preserve a Tensor in the env as long as it has remaining uses. That applies even to uses by the output node, but we are not going to constant fold that use. Instead of using the last use for garbage collection, use the last non-output use.
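The garbage-collection change can be sketched roughly as below. This is a minimal plain-Python illustration, not the actual PR code: the `Node` class and `last_non_output_uses` helper are simplified stand-ins that mirror torch.fx conventions (an `op` string, an `"output"` node that references returned values).

```python
# Sketch of the idea: free a value after its last *non-output* use, since the
# use by the graph's output node is never constant folded and so shouldn't
# keep the tensor alive during interpretation.

class Node:
    """Simplified stand-in for torch.fx.Node (hypothetical, for illustration)."""
    def __init__(self, name, op, inputs=()):
        self.name = name
        self.op = op            # e.g. "call_function" or "output"
        self.inputs = list(inputs)

def last_non_output_uses(nodes):
    """Map each node name to the index of its last use by a non-output node."""
    last_use = {}
    for idx, node in enumerate(nodes):
        if node.op == "output":
            continue            # uses by the output node don't count
        for inp in node.inputs:
            last_use[inp.name] = idx
    return last_use

# Tiny graph: a feeds b, and both a and b are returned by the output node.
a = Node("a", "call_function")
b = Node("b", "call_function", inputs=[a])
out = Node("output", "output", inputs=[a, b])

uses = last_non_output_uses([a, b, out])
# `a` can be freed right after `b` runs (index 1), even though the output
# node (index 2) also references it; `b` has no non-output use at all.
print(uses)  # → {'a': 1}
```

Under the old scheme (last use including the output node), both `a` and `b` would stay in the env until the very end of interpretation.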
If reviewers would prefer I ghstack this because of the code movement, let me know.
Fix for https://github.com/pytorch/pytorch/issues/108388
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108421
Approved by: https://github.com/jansel