Avoid saving pointwise intermediate to global memory if followed by a reduction (#93810)
Should fix https://github.com/pytorch/pytorch/issues/91880 and possibly https://github.com/pytorch/pytorch/issues/91799.
For this code:
```
@torch.compile
def f(a, b):
    return (a-b).sum(dim=-1).amax(dim=-1)
N = 2**14
K = 5
A = torch.randn(N, 1, K, device='cuda')
B = torch.randn(1, N, K, device='cuda')
bench(lambda: f(A, B), name=f"K={K}")
print(f"peak Mem: {torch.cuda.max_memory_allocated()/1e9}GB")
```
Before this change, we generated the following kernel (simplified):
```
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
...
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
...
tmp18 = tmp14 + tmp17
tl.store(out_ptr0 + (r1 + (16384*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp18, rmask & xmask)
_tmp20 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp19 = tl.load(out_ptr0 + (r1 + (16384*x0)), rmask & xmask, eviction_policy='evict_last')
_tmp20 = tl.where(rmask & xmask & (_tmp20 < tmp19), tmp19, _tmp20)
tmp20 = tl.max(_tmp20, 1)[:, None]
tl.store(out_ptr1 + x0, tmp20, xmask)
```
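and after this change, the pointwise result is fed straight into the reduction, so the intermediate `out_ptr0` buffer and the second loop disappear: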
```
def triton_(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
...
_tmp19 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
...
tmp18 = tmp14 + tmp17
_tmp19 = tl.where(rmask & xmask & (_tmp19 < tmp18), tmp18, _tmp19)
tmp19 = tl.max(_tmp19, 1)[:, None]
tl.store(out_ptr1 + x0, tmp19, xmask)
```
<details>
<summary>Full kernels</summary>
Before:
```
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 16384
rnumel = 16384
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (5*x0), xmask)
tmp3 = tl.load(in_ptr0 + (1 + (5*x0)), xmask)
tmp7 = tl.load(in_ptr0 + (2 + (5*x0)), xmask)
tmp11 = tl.load(in_ptr0 + (3 + (5*x0)), xmask)
tmp15 = tl.load(in_ptr0 + (4 + (5*x0)), xmask)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
tmp4 = tl.load(in_ptr1 + (1 + (5*r1)), rmask, eviction_policy='evict_last')
tmp8 = tl.load(in_ptr1 + (2 + (5*r1)), rmask, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr1 + (3 + (5*r1)), rmask, eviction_policy='evict_last')
tmp16 = tl.load(in_ptr1 + (4 + (5*r1)), rmask, eviction_policy='evict_last')
tmp2 = tmp0 - tmp1
tmp5 = tmp3 - tmp4
tmp6 = tmp2 + tmp5
tmp9 = tmp7 - tmp8
tmp10 = tmp6 + tmp9
tmp13 = tmp11 - tmp12
tmp14 = tmp10 + tmp13
tmp17 = tmp15 - tmp16
tmp18 = tmp14 + tmp17
tl.store(out_ptr0 + (r1 + (16384*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp18, rmask & xmask)
_tmp20 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp19 = tl.load(out_ptr0 + (r1 + (16384*x0)), rmask & xmask, eviction_policy='evict_last')
_tmp20 = tl.where(rmask & xmask & (_tmp20 < tmp19), tmp19, _tmp20)
tmp20 = tl.max(_tmp20, 1)[:, None]
tl.store(out_ptr1 + x0, tmp20, xmask)
```
After:
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 16384
rnumel = 16384
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp0 = tl.load(in_ptr0 + (5*x0), xmask)
tmp3 = tl.load(in_ptr0 + (1 + (5*x0)), xmask)
tmp7 = tl.load(in_ptr0 + (2 + (5*x0)), xmask)
tmp11 = tl.load(in_ptr0 + (3 + (5*x0)), xmask)
tmp15 = tl.load(in_ptr0 + (4 + (5*x0)), xmask)
_tmp19 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
tmp4 = tl.load(in_ptr1 + (1 + (5*r1)), rmask, eviction_policy='evict_last')
tmp8 = tl.load(in_ptr1 + (2 + (5*r1)), rmask, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr1 + (3 + (5*r1)), rmask, eviction_policy='evict_last')
tmp16 = tl.load(in_ptr1 + (4 + (5*r1)), rmask, eviction_policy='evict_last')
tmp2 = tmp0 - tmp1
tmp5 = tmp3 - tmp4
tmp6 = tmp2 + tmp5
tmp9 = tmp7 - tmp8
tmp10 = tmp6 + tmp9
tmp13 = tmp11 - tmp12
tmp14 = tmp10 + tmp13
tmp17 = tmp15 - tmp16
tmp18 = tmp14 + tmp17
_tmp19 = tl.where(rmask & xmask & (_tmp19 < tmp18), tmp18, _tmp19)
tmp19 = tl.max(_tmp19, 1)[:, None]
tl.store(out_ptr1 + x0, tmp19, xmask)
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93810
Approved by: https://github.com/ngimel, https://github.com/jansel