pytorch
0485bf53 - Avoid saving pointwise intermediate to global memory if followed by a reduction (#93810)

Commit
1 year ago
Avoid saving pointwise intermediate to global memory if followed by a reduction (#93810)

Should fix https://github.com/pytorch/pytorch/issues/91880 and maybe https://github.com/pytorch/pytorch/issues/91799

For this code:
```
@torch.compile
def f(a, b):
    return (a-b).sum(dim=-1).amax(dim=-1)

N = 2**14
K = 5
A = torch.randn(N, 1, K, device='cuda')
B = torch.randn(1, N, K, device='cuda')
bench(lambda: f(A, B), name=f"K={K}")
print(f"peak Mem: {torch.cuda.max_memory_allocated()/1e9}GB")
```

Before my change, we generated (simplified versions):
```
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    ...
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
        ...
        tmp18 = tmp14 + tmp17
        tl.store(out_ptr0 + (r1 + (16384*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp18, rmask & xmask)
    _tmp20 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp19 = tl.load(out_ptr0 + (r1 + (16384*x0)), rmask & xmask, eviction_policy='evict_last')
        _tmp20 = tl.where(rmask & xmask & (_tmp20 < tmp19), tmp19, _tmp20)
    tmp20 = tl.max(_tmp20, 1)[:, None]
    tl.store(out_ptr1 + x0, tmp20, xmask)
```

and after:
```
def triton_(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    ...
    _tmp19 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
        ...
        tmp18 = tmp14 + tmp17
        _tmp19 = tl.where(rmask & xmask & (_tmp19 < tmp18), tmp18, _tmp19)
    tmp19 = tl.max(_tmp19, 1)[:, None]
    tl.store(out_ptr1 + x0, tmp19, xmask)
```

<details>
<summary>full kernels here </summary>

Before:
```
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16384
    rnumel = 16384
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (5*x0), xmask)
    tmp3 = tl.load(in_ptr0 + (1 + (5*x0)), xmask)
    tmp7 = tl.load(in_ptr0 + (2 + (5*x0)), xmask)
    tmp11 = tl.load(in_ptr0 + (3 + (5*x0)), xmask)
    tmp15 = tl.load(in_ptr0 + (4 + (5*x0)), xmask)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
        tmp4 = tl.load(in_ptr1 + (1 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp8 = tl.load(in_ptr1 + (2 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp12 = tl.load(in_ptr1 + (3 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp16 = tl.load(in_ptr1 + (4 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp2 = tmp0 - tmp1
        tmp5 = tmp3 - tmp4
        tmp6 = tmp2 + tmp5
        tmp9 = tmp7 - tmp8
        tmp10 = tmp6 + tmp9
        tmp13 = tmp11 - tmp12
        tmp14 = tmp10 + tmp13
        tmp17 = tmp15 - tmp16
        tmp18 = tmp14 + tmp17
        tl.store(out_ptr0 + (r1 + (16384*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp18, rmask & xmask)
    _tmp20 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp19 = tl.load(out_ptr0 + (r1 + (16384*x0)), rmask & xmask, eviction_policy='evict_last')
        _tmp20 = tl.where(rmask & xmask & (_tmp20 < tmp19), tmp19, _tmp20)
    tmp20 = tl.max(_tmp20, 1)[:, None]
    tl.store(out_ptr1 + x0, tmp20, xmask)
```

After:
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16384
    rnumel = 16384
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (5*x0), xmask)
    tmp3 = tl.load(in_ptr0 + (1 + (5*x0)), xmask)
    tmp7 = tl.load(in_ptr0 + (2 + (5*x0)), xmask)
    tmp11 = tl.load(in_ptr0 + (3 + (5*x0)), xmask)
    tmp15 = tl.load(in_ptr0 + (4 + (5*x0)), xmask)
    _tmp19 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp1 = tl.load(in_ptr1 + (5*r1), rmask, eviction_policy='evict_last')
        tmp4 = tl.load(in_ptr1 + (1 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp8 = tl.load(in_ptr1 + (2 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp12 = tl.load(in_ptr1 + (3 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp16 = tl.load(in_ptr1 + (4 + (5*r1)), rmask, eviction_policy='evict_last')
        tmp2 = tmp0 - tmp1
        tmp5 = tmp3 - tmp4
        tmp6 = tmp2 + tmp5
        tmp9 = tmp7 - tmp8
        tmp10 = tmp6 + tmp9
        tmp13 = tmp11 - tmp12
        tmp14 = tmp10 + tmp13
        tmp17 = tmp15 - tmp16
        tmp18 = tmp14 + tmp17
        _tmp19 = tl.where(rmask & xmask & (_tmp19 < tmp18), tmp18, _tmp19)
    tmp19 = tl.max(_tmp19, 1)[:, None]
    tl.store(out_ptr1 + x0, tmp19, xmask)
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93810
Approved by: https://github.com/ngimel, https://github.com/jansel
Author
Committer
Parents
Loading