Factor input deduplication into a separate function (#89701)
It turns out that instead of having a giant blobby aot_dispatch_autograd
function, we can factor it into a series of wrapper functions, each
of which successively guarantees more invariants on the inner
compilation function until the final inner function is quite trivial.
How exactly you have to wrap the input user functions and the output
compiled functions can be expressed concisely in Haskell, so I've
included the Haskell formulation in code comments.
This PR shows how to do this for input deduplication. Dealing with the
rest of the view handling is left to future work.
This PR should also be a slight performance improvement as deduplicating
is skipped entirely when there are no duplicate inputs.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89701
Approved by: https://github.com/bdhirsh