[inductor] Add cat + split_with_sizes elimination pass (#107956)
Summary:
When the `cat` inputs' sizes and the `split_sizes` of the downstream `split_with_sizes` match, the `cat` + `split_with_sizes` constellation can be eliminated. E.g. here:
```
@torch.compile
def fn(a, b, c):
cat = torch.ops.aten.cat.default([a, b, c], 1)
split_with_sizes = torch.ops.aten.split_with_sizes.default(cat, [2, 3, 5], 1)
return [s ** 2 for s in split_with_sizes]
inputs = [
torch.randn(2, 2, device="cuda"),
torch.randn(2, 3, device="cuda"),
torch.randn(2, 5, device="cuda"),
]
output = fn(*inputs)
```
This PR adds a new fx pass for such elimination. The new pass is similar to the existing [`splitwithsizes_cat_replace`](https://github.com/pytorch/pytorch/blob/b18e1b684a7673daa3a51128aae4e75ed7aa7cbc/torch/_inductor/fx_passes/post_grad.py#L508), but considers the ops in the opposite order.
Test Plan:
```
$ python test/inductor/test_pattern_matcher.py
...
----------------------------------------------------------------------
Ran 21 tests in 46.450s
OK
```
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107956
Approved by: https://github.com/jansel