pytorch
05b679ce - [inductor] don't match indirect indexing in fusion (#96273)

Commit
1 year ago
[inductor] don't match indirect indexing in fusion (#96273) Fixes #96064 When deciding whether to fuse nodes, we match indexing like `c0 + 5 * tmp0`, but `tmp0` in the different nodes can refer to totally different values. Even when `tmp0` is the same (like in the added test) inductor still generates wrongly ordered loads and stores (loads come before stores), so better just disable this fusion altogether. We should fix wrong order also: ``` @pointwise(size_hints=[8], filename=__file__, meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}) @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr): xnumel = 5 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0_load = tl.load(in_ptr0 + (0)) tmp0 = tl.broadcast_to(tmp0_load, [XBLOCK]) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tl.load(out_ptr0 + (x0 + (5*tmp0)), xmask) tl.store(out_ptr0 + (x0 + (5*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask) tl.store(out_ptr1 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask) ``` Note: we are loading from `out_ptr0` here (that shouldn't happen), we are loading from it before storing to it. After this PR, the kernel above is split in 2. Pull Request resolved: https://github.com/pytorch/pytorch/pull/96273 Approved by: https://github.com/jansel
Author
Natalia Gimelshein
Committer
Parents
Loading