pytorch
a595d06c - [inductor] Avoid re-computing mean in lowering for aten.var_mean (#94139)

Commit
1 year ago
[inductor] Avoid re-computing mean in lowering for aten.var_mean (#94139) The current lowering results in the mean being computed twice. In the following snippet, both `tmp1` and `tmp8` are the sum of `in_ptr0`: ```python def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr): # ... _tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0 for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex < rnumel r0 = rindex tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last') _tmp1 = tl.where(rmask, _tmp1 + tmp0, _tmp1) tmp1 = tl.sum(_tmp1, 1)[:, None] _tmp7 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0 _tmp8 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0 for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex < rnumel r0 = rindex tmp2 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last') tmp3 = 100.0 tmp4 = tmp1 / tmp3 tmp5 = tmp2 - tmp4 tmp6 = tmp5 * tmp5 _tmp7 = tl.where(rmask, _tmp7 + tmp6, _tmp7) _tmp8 = tl.where(rmask, _tmp8 + tmp2, _tmp8) tmp7 = tl.sum(_tmp7, 1)[:, None] tmp8 = tl.sum(_tmp8, 1)[:, None] # ... ``` After this change, the mean is computed only once: ```python for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex < rnumel r0 = rindex tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last') _tmp1 = tl.where(rmask, _tmp1 + tmp0, _tmp1) tmp1 = tl.sum(_tmp1, 1)[:, None] tmp2 = 100.0 tmp3 = tmp1 / tmp2 tl.store(in_out_ptr0 + (0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp3, None) _tmp7 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0 for roffset in range(0, rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex < rnumel r0 = rindex tmp4 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last') tmp5 = tmp4 - tmp3 tmp6 = tmp5 * tmp5 _tmp7 = tl.where(rmask, _tmp7 + tmp6, _tmp7) tmp7 = tl.sum(_tmp7, 1)[:, None] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/94139 Approved by: https://github.com/lezcano, https://github.com/jansel
Author
Committer
Parents
Loading