[inductor] hoist symbolic padding expressions (#97099)
Towards fixing pnasnet5large, see #96709. The generated kernel now looks much better:
```
@pointwise(size_hints=[1048576], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 6), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = (xindex // ks0) % ks0
x0 = xindex % ks0
x2 = (xindex // ks3)
x4 = xindex
tmp0 = x1 + ((-1)*ks1)
tmp1 = 0
tmp2 = tmp0 >= tmp1
tmp3 = ks2
tmp4 = tmp0 < tmp3
tmp5 = x0 + ((-1)*ks1)
tmp6 = tmp5 >= tmp1
tmp7 = tmp5 < tmp3
tmp8 = tmp2 & tmp4
tmp9 = tmp8 & tmp6
tmp10 = tmp9 & tmp7
tmp11 = tl.load(in_ptr0 + (x0 + ((-1)*ks1) + (ks2*x1) + (x2*(ks2*ks2)) + ((-1)*ks1*ks2) + tl.zeros([XBLOCK], tl.int32)), tmp10 & xmask, other=0)
tmp12 = tl.where(tmp10, tmp11, 0.0)
tl.store(out_ptr0 + (x4 + tl.zeros([XBLOCK], tl.int32)), tmp12, xmask)
```
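For context, "hoisting" here means evaluating the symbolic size expressions on the host and passing the results as plain integer kernel arguments (the `ks0`..`ks3` scalars above), rather than printing sympy operations like `ceiling` into the kernel body where Triton cannot evaluate them on integer indices. A minimal sketch of the idea, not inductor's actual code (`hoist` and `pad_expr` are hypothetical names):
```
import sympy

s = sympy.Symbol("s", positive=True, integer=True)
pad_expr = sympy.ceiling(s / 2)  # a hypothetical symbolic padding-size expression

def hoist(expr, size):
    # Host side: substitute the concrete size and fold to a plain Python int,
    # so the kernel only ever sees an integer argument like ks0..ks3 above.
    return int(expr.subs(s, size))

ks = hoist(pad_expr, 11)  # ceiling(11/2) == 6
```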
Interestingly, removing `expand` in the index `simplify` function makes the `load` expression a little better, but the `store` then fails to simplify to a flat store, so I'm leaving `expand` in.
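To illustrate what `expand` buys us, here is a small sympy sketch (variable names chosen to mirror the kernel above) showing how a nested padding index distributes into the flat sum of monomials seen in the `load` expression:
```
import sympy

x0, x1, ks1, ks2 = sympy.symbols("x0 x1 ks1 ks2", integer=True)
idx = x0 - ks1 + ks2 * (x1 - ks1)

# Without expand, the ks2*(x1 - ks1) factor stays nested; expand distributes
# it into a flat sum of monomials (x0 - ks1 + ks2*x1 - ks1*ks2), which is
# easier to match against a contiguous (flat) indexing pattern.
print(sympy.expand(idx))
```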
The full pnasnet model still chokes on `ceiling` in batch_norm kernels. Additionally, it looks like shape propagation in inductor goofs and generates overly complicated expressions; we should switch to the metadata from the fx graph.
I'm still not adding a `ceil` print to Triton, because we should be able to hoist all indexing expressions (and just printing `ceil` without converting to int64 doesn't work).
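As a rough illustration of why a bare `ceil` printer would be insufficient: any such printer override would have to cast the result back to int64 before it could participate in an indexing expression. A hypothetical sketch (the `tl.math.ceil(...)` spelling is an assumption for illustration, not the API this PR adds):
```
import sympy
from sympy.printing.str import StrPrinter

class TritonPrinterSketch(StrPrinter):
    def _print_ceiling(self, expr):
        # Emitting just "ceil(...)" yields a float value, which cannot be
        # used directly in a pointer/index expression; the cast back to
        # int64 is the missing piece.
        return f"tl.math.ceil({self._print(expr.args[0])}).to(tl.int64)"

# Prints: tl.math.ceil(x0/2).to(tl.int64)
print(TritonPrinterSketch().doprint(sympy.ceiling(sympy.Symbol("x0") / 2)))
```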
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97099
Approved by: https://github.com/jansel
Author: Natalia Gimelshein