[inductor] Avoid re-computing mean in lowering for aten.var_mean (#94139)
The current lowering computes the mean twice. In the following generated Triton
kernel, both `tmp1` and `tmp8` hold the sum of `in_ptr0`, so the same reduction
runs in both loops:
```python
def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # ...
    _tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last')
        _tmp1 = tl.where(rmask, _tmp1 + tmp0, _tmp1)
    tmp1 = tl.sum(_tmp1, 1)[:, None]
    _tmp7 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    _tmp8 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp2 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last')
        tmp3 = 100.0
        tmp4 = tmp1 / tmp3
        tmp5 = tmp2 - tmp4
        tmp6 = tmp5 * tmp5
        _tmp7 = tl.where(rmask, _tmp7 + tmp6, _tmp7)
        _tmp8 = tl.where(rmask, _tmp8 + tmp2, _tmp8)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    tmp8 = tl.sum(_tmp8, 1)[:, None]
    # ...
```
After this change, the mean is computed only once:
```python
for roffset in range(0, rnumel, RBLOCK):
    rindex = roffset + rbase
    rmask = rindex < rnumel
    r0 = rindex
    tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last')
    _tmp1 = tl.where(rmask, _tmp1 + tmp0, _tmp1)
tmp1 = tl.sum(_tmp1, 1)[:, None]
tmp2 = 100.0
tmp3 = tmp1 / tmp2
tl.store(in_out_ptr0 + (0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp3, None)
_tmp7 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
for roffset in range(0, rnumel, RBLOCK):
    rindex = roffset + rbase
    rmask = rindex < rnumel
    r0 = rindex
    tmp4 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_last')
    tmp5 = tmp4 - tmp3
    tmp6 = tmp5 * tmp5
    _tmp7 = tl.where(rmask, _tmp7 + tmp6, _tmp7)
tmp7 = tl.sum(_tmp7, 1)[:, None]
```
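For readers less familiar with Triton, the arithmetic of the kernel above can be
sketched in plain Python. This is only an illustration of the two-pass scheme,
not the inductor lowering itself: the function name `var_mean_two_pass` is
hypothetical, and the final division by `n - correction` is an assumption
modeled on the `correction` argument of `torch.var_mean`, since the snippet
above stops at the sum of squared deviations.
```python
def var_mean_two_pass(values, correction=1):
    """Return (variance, mean) for a non-empty sequence of numbers.

    Hypothetical helper for illustration only; it mirrors the two loops of
    the generated kernel but is not part of PyTorch.
    """
    n = len(values)
    if n == 0:
        raise ValueError("values must be non-empty")

    # First pass: accumulate the sum once and derive the mean from it
    # (corresponds to tmp1 and tmp3 in the kernel).
    total = 0.0
    for v in values:
        total += v
    mean = total / n

    # Second pass: reuse the stored mean; the input sum is not accumulated
    # again (corresponds to tmp5, tmp6 and _tmp7 in the kernel).
    sq_dev = 0.0
    for v in values:
        diff = v - mean
        sq_dev += diff * diff

    # Assumption: normalize by n - correction, as torch.var_mean does.
    denom = n - correction
    variance = sq_dev / denom if denom > 0 else float("nan")
    return variance, mean


if __name__ == "__main__":
    var, mean = var_mean_two_pass([1.0, 2.0, 3.0, 4.0], correction=0)
    print(var, mean)  # 1.25 2.5
```
The point of the sketch is the second loop: it only accumulates squared
deviations, whereas the previous lowering also re-accumulated the input sum
there.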
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94139
Approved by: https://github.com/lezcano, https://github.com/jansel