Update base for Update on "RFC: Delete ProxyTensor wrapper subclass"
I was working on https://github.com/pytorch/torchdynamo/issues/80 and my
working hypothesis for what was causing the error was that proxy tensor
was not advertising the correct dispatch keys, causing AMP to behave
differently under tracing. I could have fixed this directly by
replicating fake tensor's fix for setting dispatch keys to also apply to
proxy tensor, but I was like, "Why must I repeat myself?"
This PR is the result. It completely deletes the ProxyTensor wrapper
subclass, so that when we are tracing, the tensors flowing through the
program are the *original* real or fake tensors, depending on what the
user requested in the top-level API. There is no more wrapping. To
store the Proxy objects necessary for actually doing tracing, I maintain
a weak map from tensors to their Proxy objects; this is equivalent to
storing the properties on the tensors themselves. (Note: at the moment
I never clean up old entries from the map; this is an easy fix
following how the meta converter does it.)
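As a sketch of the idea (plain Python, with a simple class standing in for tensors; `ProxyStore` and its method names are hypothetical, not the actual API in this PR):

```python
import weakref

class ProxyStore:
    """Maps live tensors to their Proxy objects without keeping the
    tensors alive. Entries are dropped automatically when a tensor is
    garbage collected, which is the cleanup the note above alludes to."""

    def __init__(self):
        self._proxies = {}  # id(tensor) -> proxy

    def set(self, tensor, proxy):
        key = id(tensor)
        self._proxies[key] = proxy
        # When the tensor dies, expire its entry (mirroring how the
        # meta converter expires stale entries).
        weakref.finalize(tensor, self._proxies.pop, key, None)

    def get(self, tensor):
        return self._proxies.get(id(tensor))

class FakeTensor:  # stand-in for a real/fake torch.Tensor
    pass

store = ProxyStore()
t = FakeTensor()
store.set(t, "proxy-for-t")
assert store.get(t) == "proxy-for-t"
del t  # the finalizer removes the entry once the tensor is collected
```

Because the map holds no strong reference to the tensor, tracing leaves the tensors' lifetimes untouched, exactly as if the proxy were a property stored on the tensor itself.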
Benefits of doing this:
* No more tip-toeing around no_dispatch() creation of new ProxyTensors;
we never create new tensors (except when we call the underlying func),
so you don't have to worry about accidentally tracing them.
* No more syncing up metadata from in-place operators. In particular,
  https://github.com/pytorch/pytorch/issues/81526 is mooted.
* This fixes https://github.com/pytorch/torchdynamo/issues/519 as we no longer need to teach proxy tensor to support sparse tensor.
* No more schlepping symbolic integers from the inner fake tensor to the
outer proxy tensor. If you can make a fake tensor with symbolic ints,
you're done, nothing else to do.
To avoid having to rewrite all of the guts, when I get to the actual
proxy tensor handler, I first "fetch" the stored ProxyTensor data from
the weak map via a tree_map, and then operate on the resulting data as
before. A more optimized implementation is possible.
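Roughly, the fetch step looks like this (a minimal sketch with an inline `tree_map` standing in for pytree's; `fetch_proxy`, `proxy_handler`, and the string-based trace are illustrative, not the PR's actual code):

```python
# Minimal stand-in for pytree's tree_map: apply fn to every leaf of a
# (possibly nested) structure of lists/tuples/dicts.
def tree_map(fn, tree):
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, x) for x in tree)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)

class Tensor:  # stand-in for a real/fake tensor
    pass

proxy_map = {}  # the weak map described above, simplified to a plain dict

def fetch_proxy(arg):
    # Swap each tensor leaf for its stored proxy; leave other leaves alone.
    if isinstance(arg, Tensor):
        return proxy_map[id(arg)]
    return arg

def proxy_handler(func, args):
    # 1. "Fetch" the proxies for every tensor argument via tree_map ...
    proxy_args = tree_map(fetch_proxy, args)
    # 2. ... then record the op on the proxies as before (here reduced
    #    to building a string, instead of emitting an FX node).
    return f"{func}({', '.join(map(str, proxy_args))})"

t = Tensor()
proxy_map[id(t)] = "p0"
print(proxy_handler("aten.relu", [t, 2]))  # -> aten.relu(p0, 2)
```

The point is that the pre-existing handler logic only ever sees proxy data, so it can stay as-is; the tree_map is just a translation layer at the boundary, which is why a more optimized implementation remains possible.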
This PR isn't finished yet; I still need to:
* Fix symbolic tracing of SymInts. I don't want to create SymIntProxies
anymore, so I have to make it so that SymInts support mode-style
tracing and then run the same playbook on them.
* Fix the decomposition interpreter, which is still doing old-style
  ProxyTensor (and anyone else in the same boat)
* Rebase on top of https://github.com/pytorch/pytorch/pull/83215/ as
this trick only works if trace_factory_functions is always True
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
[ghstack-poisoned]