pytorch
3eb322ff - Handle transitive replacements in Triton kernel mutation analysis (#121867)

Commit
1 year ago
Summary: Previously, we didn't handle transitive replacements in MLIR walk-based function info mining in the Triton kernel mutation analysis pass. As a result, for the TTIR below:

```
tt.func private @cumsum__fp32S1_16S__1cconstexpr_1__2cconstexpr_False_(%arg0: tensor<1x16xf32> loc("...":296:0)) -> tensor<1x16xf32> attributes {noinline = false} {
  %0 = "tt.scan"(%arg0) <{axis = 1 : i32, reverse = false}> ({
  ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
    %1 = tt.call @_sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc16)
    tt.scan.return %1 : f32 loc(#loc16)
  }) : (tensor<1x16xf32>) -> tensor<1x16xf32> loc(#loc16)
  tt.return %0 : tensor<1x16xf32> loc(#loc18)
} loc(#loc15)
```

the mined function dict looked like this:

```
{Intermediate(idx=25): [Op(name='tt.call',
                           fn_call_name='_sum_combine__fp32_fp32__',
                           args=[Intermediate(idx=26),
                                 Intermediate(idx=26)])],
 Intermediate(idx=27): [Op(name='tt.scan.return',
                           fn_call_name=None,
                           args=[Intermediate(idx=25)])],
 Intermediate(idx=-4): [Op(name='tt.return',
                           fn_call_name=None,
                           args=[Intermediate(idx=27)])]}
```

whereas it should look like this (note the `Param(idx=0)` arguments of the `tt.call`):

```
{Intermediate(idx=25): [Op(name='tt.call',
                           fn_call_name='_sum_combine__fp32_fp32__',
                           args=[Param(idx=0),
                                 Param(idx=0)])],
 Intermediate(idx=27): [Op(name='tt.scan.return',
                           fn_call_name=None,
                           args=[Intermediate(idx=25)])],
 Intermediate(idx=-4): [Op(name='tt.return',
                           fn_call_name=None,
                           args=[Intermediate(idx=27)])]}
```

This is fixed in the PR.

Test Plan:

```
$ python test/inductor/test_triton_kernels.py -k test_cumsum
.
----------------------------------------------------------------------
Ran 1 test in 1.771s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121867
Approved by: https://github.com/oulgen
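For readers unfamiliar with the term, the idea of a transitive replacement can be sketched in a few lines of Python. This is an illustration only, not the code from the PR: the `Param`, `Intermediate`, and `Op` classes below are simplified stand-ins modeled on the printed dict above, the `resolve` and `apply_replacements` helpers are hypothetical names, and the replacement chain (including the made-up index 28 for the scan block argument) is an assumption chosen to reproduce the before/after shape shown in the summary.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass(frozen=True)
class Param:
    idx: int


@dataclass(frozen=True)
class Intermediate:
    idx: int


Value = Union[Param, Intermediate]


@dataclass
class Op:
    name: str
    fn_call_name: Optional[str]
    args: List[Value]


def resolve(value: Value, replacements: Dict[Value, Value]) -> Value:
    """Follow the replacement chain until a value with no replacement is reached."""
    seen = set()
    while value in replacements:
        if value in seen:
            raise ValueError(f"cyclic replacement involving {value}")
        seen.add(value)
        value = replacements[value]
    return value


def apply_replacements(
    ops: Dict[Value, List[Op]], replacements: Dict[Value, Value]
) -> Dict[Value, List[Op]]:
    """Return a copy of `ops` with every argument fully resolved."""
    return {
        result: [
            Op(op.name, op.fn_call_name, [resolve(a, replacements) for a in op.args])
            for op in op_list
        ]
        for result, op_list in ops.items()
    }


# Hypothetical chain: block argument (28) -> scan operand (26) -> Param(0).
replacements = {
    Intermediate(28): Intermediate(26),
    Intermediate(26): Param(0),
}
mined = {
    Intermediate(25): [
        Op("tt.call", "_sum_combine__fp32_fp32__", [Intermediate(28), Intermediate(28)])
    ],
    Intermediate(27): [Op("tt.scan.return", None, [Intermediate(25)])],
}
fixed = apply_replacements(mined, replacements)
print(fixed[Intermediate(25)][0].args)
```

A single lookup in `replacements` would stop at `Intermediate(idx=26)`, which matches the shape of the incorrect dict above; following the chain to its end yields `Param(idx=0)`. The cycle check is a defensive choice of this sketch, not something the commit message describes.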