[JIT] Adding a concat optimization pass (#55474)
Summary:
This PR adds a new pass in JIT that optimizes `aten::cat` ops.
Specifically, here are optimizations performed:
* Eliminate redundant in `cat` inputs by performing cse on the list of inputs.
- This includes eliminating fully redundant `cat` ops when all the inputs are the same as well the case when "all but one" of the inputs have already been concatenated.
* Expand `cat` into multiple copies and eliminate redundancies.
- This also includes eliminating redundancies in the underlying buffers used for `cat`.
These optimizations are not enabled in any compilation flow at this point.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55474
Reviewed By: albanD
Differential Revision: D27624511
Pulled By: navahgar
fbshipit-source-id: d509289fafc23e73b02f64a90219148896817339