Optimize split-split pass (#100983)
Summary:
Previously, we were replacing all getitems of a split - even the ones not affected by the pattern. For large split nodes, this was inefficient.
For instance, on an internal ads model - split-split pass took ~1100s. This is down to ~18s after this optimization
Test Plan:
* Compiled and tested on internal model (compilation time down by ~1100s)
* CI tests
Differential Revision: D45698034
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100983
Approved by: https://github.com/jansel