a8fdfb4b - [inductor] Persistent reductions (#92267)

This one may need to wait for the new MLIR Triton to land, as it triggers some Triton crashes.

Before:

```
$ pytest test/inductor/test_torchinductor.py -vsk test_softmax_one_kernel_loop_cuda
...
@reduction(
    size_hints=[16, 32],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}
)
@triton.jit
def triton_(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16
    rnumel = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask, eviction_policy='evict_last')
        _tmp1 = tl.where(xmask & rmask & (_tmp1 < tmp0), tmp0, _tmp1)
    tmp1 = tl.max(_tmp1, 1)[:, None]
    _tmp5 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp2 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask, eviction_policy='evict_last')
        tmp3 = tmp2 - tmp1
        tmp4 = tl.exp(tmp3)
        _tmp5 = tl.where(xmask & rmask, _tmp5 + tmp4, _tmp5)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp6 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask, eviction_policy='evict_last')
        tmp7 = tmp6 - tmp1
        tmp8 = tl.exp(tmp7)
        tmp9 = tmp8 / tmp5
        tl.store(out_ptr2 + (r1 + (32*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp9, rmask & xmask)
```

After:

```
$ pytest test/inductor/test_torchinductor.py -vsk test_softmax_one_kernel_persist_cuda
...
@persistent_reduction(
    size_hints=[16, 32],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}
)
@triton.jit
def triton_(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16
    rnumel = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask)
    tmp2 = tl.where(xmask & rmask, tmp0, float("-inf"))
    tmp3 = tl.max(tmp2, 1)[:, None]
    tmp4 = tmp0 - tmp3
    tmp5 = tl.exp(tmp4)
    tmp7 = tl.where(xmask & rmask, tmp5, 0)
    tmp8 = tl.sum(tmp7, 1)[:, None]
    tmp9 = tmp5 / tmp8
    tl.store(out_ptr2 + (r1 + (32*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp9, rmask & xmask)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92267
Approved by: https://github.com/Chillee
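In the persistent variant, RBLOCK covers the whole reduction dimension (rnumel = 32), so each program instance keeps its entire row in registers: the three reduction loops collapse and the input is loaded once instead of three times. For context, here is a minimal sketch (not from the commit; it assumes a CUDA build of PyTorch with the inductor backend available) of the kind of eager code that lowers to kernels like the ones above, a row-wise softmax over a 16x32 tensor matching size_hints=[16, 32]:

```python
import torch

# Row-wise softmax, the pattern exercised by the
# test_softmax_one_kernel_*_cuda tests: a reduction over the last dim.
def fn(x):
    return torch.softmax(x, dim=-1)

# Hypothetical repro; assumes a CUDA device and the inductor backend.
compiled = torch.compile(fn, backend="inductor")
x = torch.randn(16, 32, device="cuda")  # matches size_hints=[16, 32]
out = compiled(x)

# Sanity-check the compiled kernel against eager execution.
torch.testing.assert_close(out, fn(x))
```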