Fix redundant kernel generation (#102104)
## Issue description
The PR https://github.com/pytorch/pytorch/pull/100064 introduced a new way of processing RNG operations, under which every `randint` loads its own random seed by default. TorchInductor allocates a buffer holding all of the required seeds and bakes each seed's offset into the consuming compute buffers as a constant. Consequently, in the `ir_pre_fusion` output some buffers differ in only one line: the load of the random seed at the corresponding offset. Codegen then turns each such buffer into its own Triton kernel, so in `output_code.py` several kernels differ by a single line; in other words, redundant kernels are generated.
## Solution
This PR captures the seed offset and adds it to the existing `self.sizevars` structure, generating variable names as placeholders so that the code wrapper can pass the offset to the kernels as an argument. I've also modified the `divisible_by_16` check to exclude this new argument (see the sketch after the comparison below).
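As a rough illustration of the mechanism (a minimal sketch, not the PR's actual code; the class and method names here are hypothetical), the codegen side can be thought of as interning each seed offset under a placeholder name that later becomes a kernel argument:
```
import sympy

# Hypothetical sketch of the idea behind reusing `self.sizevars`: intern each
# seed offset expression under a stable placeholder name instead of inlining
# the constant, so two kernels that differ only in that constant become
# textually identical and can be deduplicated.
class SeedOffsetTable:
    def __init__(self):
        self.sizevars = {}  # sympy expr -> placeholder variable name

    def lookup(self, offset_expr: sympy.Expr) -> str:
        # Reuse one placeholder per distinct offset; the wrapper later passes
        # the concrete value (0, 1, ...) when calling the kernel.
        if offset_expr not in self.sizevars:
            self.sizevars[offset_expr] = f"load_seed_offset{len(self.sizevars)}"
        return self.sizevars[offset_expr]

table = SeedOffsetTable()
print(table.lookup(sympy.Integer(0)))  # load_seed_offset0
print(table.lookup(sympy.Integer(1)))  # load_seed_offset1
print(table.lookup(sympy.Integer(0)))  # load_seed_offset0 again (reused)
```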
This PR reduces the number of generated kernels for the BertForMaskedLM forward pass from 50 to 17.
In my environment, the compilation time of `attention_is_all_you_need_pytorch` drops from 94s to 66s, while the speedup stays essentially unchanged at 1.37x.
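The kernel-count effect can be checked on a toy graph with several independent `randint` calls. The sketch below is my own reproduction setup, not from the PR; it assumes a CUDA build that includes this change and reads `torch._inductor.metrics.generated_kernel_count`, the counter Inductor's own tests use:
```
import torch
import torch._inductor.metrics as metrics

def f(x):
    # Three independent randint calls: before this PR, each one could end up
    # in its own Triton kernel differing only in the hard-coded seed offset.
    a = torch.randint(0, 10, (1024,), device=x.device)
    b = torch.randint(0, 10, (1024,), device=x.device)
    c = torch.randint(0, 10, (1024,), device=x.device)
    return x + a + b + c

metrics.reset()
compiled = torch.compile(f)
compiled(torch.zeros(1024, device="cuda"))
print(metrics.generated_kernel_count)
```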
Below is a before/after comparison on a simple example.
Before:
```
triton_poi_fused_0 = async_compile.triton('triton_', '''
...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
...
tmp0 = tl.load(in_ptr0 + 0)
tmp1 = x0
tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 10)
...''')
triton_poi_fused_1 = async_compile.triton('triton_', '''
...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
...
tmp0 = tl.load(in_ptr0 + 1)
tmp1 = x0
tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 10)
...''')
def call(args):
triton_poi_fused_0.run(buf0, buf1, 1024, grid=grid(1024), stream=stream0)
triton_poi_fused_1.run(buf0, buf2, 1024, grid=grid(1024), stream=stream0)
```
After:
```
triton_poi_fused_0 = async_compile.triton('triton_', '''
...
def triton_(in_ptr0, out_ptr0, load_seed_offset, xnumel, XBLOCK : tl.constexpr):
...
tmp0 = tl.load(in_ptr0 + load_seed_offset)
tmp1 = x0
tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 10)
...''')
def call(args):
triton_poi_fused_0.run(buf0, buf1, 0, 1024, grid=grid(1024), stream=stream0)
triton_poi_fused_0.run(buf0, buf2, 1, 1024, grid=grid(1024), stream=stream0)
```
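On the `divisible_by_16` change: Triton specializes kernels on which integer arguments are divisible by 16, and the seed offset takes small arbitrary values (0, 1, 2, ...), so keeping it in the check would respecialize per offset and undo the deduplication. Below is an illustrative-only sketch of the filtering logic; the helper name and signature are mine, not the PR's actual code:
```
# Compute which integer kernel arguments may be marked divisible-by-16 for
# Triton's specialization, skipping the seed offset so that calls with
# offset 0, 1, 2, ... all hit the same compiled kernel.
def divisible_by_16_indices(arg_names, arg_values, excluded=("load_seed_offset",)):
    return tuple(
        i
        for i, (name, value) in enumerate(zip(arg_names, arg_values))
        if name not in excluded
        and isinstance(value, int)
        and value % 16 == 0
    )

# Example: the offset argument (value 0) would otherwise look divisible by 16.
print(divisible_by_16_indices(
    ["in_ptr0", "out_ptr0", "load_seed_offset", "xnumel"],
    [0x7F0000000000, 0x7F0000004000, 0, 1024],
))  # (0, 1, 3)
```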
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102104
Approved by: https://github.com/jansel, https://github.com/ngimel