[inductor] Persistent reductions (#92267)
This one may need to wait for the new MLIR-based Triton to land, since it currently triggers some crashes in Triton.
Before:
```
$ pytest test/inductor/test_torchinductor.py -vsk test_softmax_one_kernel_loop_cuda
...
@reduction(
size_hints=[16, 32],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}
)
@triton.jit
def triton_(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 16
rnumel = 32
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
_tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask, eviction_policy='evict_last')
_tmp1 = tl.where(xmask & rmask & (_tmp1 < tmp0), tmp0, _tmp1)
tmp1 = tl.max(_tmp1, 1)[:, None]
_tmp5 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp2 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask, eviction_policy='evict_last')
tmp3 = tmp2 - tmp1
tmp4 = tl.exp(tmp3)
_tmp5 = tl.where(xmask & rmask, _tmp5 + tmp4, _tmp5)
tmp5 = tl.sum(_tmp5, 1)[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp6 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask, eviction_policy='evict_last')
tmp7 = tmp6 - tmp1
tmp8 = tl.exp(tmp7)
tmp9 = tmp8 / tmp5
tl.store(out_ptr2 + (r1 + (32*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp9, rmask & xmask)
```
After:
```
$ pytest test/inductor/test_torchinductor.py -vsk test_softmax_one_kernel_persist_cuda
...
@persistent_reduction(
size_hints=[16, 32],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}
)
@triton.jit
def triton_(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 16
rnumel = 32
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask)
tmp2 = tl.where(xmask & rmask, tmp0, float("-inf"))
tmp3 = tl.max(tmp2, 1)[:, None]
tmp4 = tmp0 - tmp3
tmp5 = tl.exp(tmp4)
tmp7 = tl.where(xmask & rmask, tmp5, 0)
tmp8 = tl.sum(tmp7, 1)[:, None]
tmp9 = tmp5 / tmp8
tl.store(out_ptr2 + (r1 + (32*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp9, rmask & xmask)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92267
Approved by: https://github.com/Chillee