Handle transitive replacements in Triton kernel mutation analysis (#121867)
Summary: Previously, the MLIR walk-based function info mining in the Triton kernel mutation analysis pass did not handle transitive replacements. As a result, for the TTIR below:
```
tt.func private @cumsum__fp32S1_16S__1cconstexpr_1__2cconstexpr_False_(%arg0: tensor<1x16xf32> loc("...":296:0)) -> tensor<1x16xf32> attributes {noinline = false} {
%0 = "tt.scan"(%arg0) <{axis = 1 : i32, reverse = false}> ({
^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
%1 = tt.call @_sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc16)
tt.scan.return %1 : f32 loc(#loc16)
}) : (tensor<1x16xf32>) -> tensor<1x16xf32> loc(#loc16)
tt.return %0 : tensor<1x16xf32> loc(#loc18)
} loc(#loc15)
```
the mined function dict looked like this:
```
{Intermediate(idx=25): [Op(name='tt.call',
fn_call_name='_sum_combine__fp32_fp32__',
args=[Intermediate(idx=26),
Intermediate(idx=26)])],
Intermediate(idx=27): [Op(name='tt.scan.return',
fn_call_name=None,
args=[Intermediate(idx=25)])],
Intermediate(idx=-4): [Op(name='tt.return',
fn_call_name=None,
args=[Intermediate(idx=27)])]}
```
whereas it should look like this (note the `Param(idx=0)` arguments of the `tt.call`):
```
{Intermediate(idx=25): [Op(name='tt.call',
fn_call_name='_sum_combine__fp32_fp32__',
args=[Param(idx=0),
Param(idx=0)])],
Intermediate(idx=27): [Op(name='tt.scan.return',
fn_call_name=None,
args=[Intermediate(idx=25)])],
Intermediate(idx=-4): [Op(name='tt.return',
fn_call_name=None,
args=[Intermediate(idx=27)])]}
```
This PR fixes that by handling transitive replacements, so the `tt.call` arguments resolve to `Param(idx=0)`.
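As a rough illustration of what "transitive" means here (a sketch, not the code from this PR): assume the miner records a mapping from a value to the value that replaces it. If a replacement target is itself replaced, a single lookup stops at an intermediate, so the chain has to be followed to its end. The `Param`/`Intermediate` classes, the `resolve` helper, and index 24 below are hypothetical stand-ins chosen for the example.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass(frozen=True)
class Param:
    idx: int


@dataclass(frozen=True)
class Intermediate:
    idx: int


Operand = Union[Param, Intermediate]


def resolve(operand: Operand, replacements: Dict[Operand, Operand]) -> Operand:
    """Follow the replacement chain until an operand with no replacement is reached."""
    seen = set()  # guards against an accidental cycle in the mapping
    while operand in replacements and operand not in seen:
        seen.add(operand)
        operand = replacements[operand]
    return operand


# Hypothetical chain: a region block argument is replaced by another
# intermediate, which is in turn replaced by the function parameter.
replacements: Dict[Operand, Operand] = {
    Intermediate(26): Intermediate(24),
    Intermediate(24): Param(0),
}

one_hop = replacements[Intermediate(26)]             # single, non-transitive lookup
resolved = resolve(Intermediate(26), replacements)   # transitive lookup
print(one_hop, resolved)
```

With this hypothetical mapping, the single lookup stops at `Intermediate(idx=24)`, while following the chain reaches `Param(idx=0)`, which is the form the `tt.call` arguments are expected to take.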
Test Plan:
```
$ python test/inductor/test_triton_kernels.py -k test_cumsum
.
----------------------------------------------------------------------
Ran 1 test in 1.771s
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121867
Approved by: https://github.com/oulgen