[inductor] Enable CSE on masked loads (#98303)
Currently the `TritonKernel.mask_loads` context manager calls
`swap_buffers`, which creates a new CSE context. As a result, code
generated under different mask contexts cannot be CSE'd even when the
masks are identical. This change fixes the issue by not calling
`swap_buffers`; instead, `load` manually checks whether a `"tmp"` name
appears in the mask, which means the load must be generated in the
compute buffer.
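A minimal sketch of the idea (hypothetical names, not the actual
`TritonKernel` implementation): the load and compute buffers share a
single CSE cache, and only the buffer choice depends on the mask.
```python
# Sketch only: assumes a kernel with separate load/compute buffers
# that share one CSE cache, as described above.
class KernelSketch:
    def __init__(self):
        self.loads = []    # stands in for the load buffer
        self.compute = []  # stands in for the compute buffer
        self.cse = {}      # shared CSE cache: expression -> variable name

    def emit_load(self, index: str, mask: str) -> str:
        # A mask containing a "tmp" name was computed inside the kernel,
        # so the load must be emitted after it, in the compute buffer.
        buffer = self.compute if "tmp" in mask else self.loads
        expr = f"tl.load({index}, {mask})"
        if expr not in self.cse:  # same index + same mask => one variable
            var = f"tmp{len(self.cse)}"
            self.cse[expr] = var
            buffer.append(f"{var} = {expr}")
        return self.cse[expr]
```
Because the cache is never swapped out, a second load with the same
index and mask returns the existing variable instead of emitting a
duplicate.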
Currently, simple programs involving padding produce duplicate
masked loads, e.g. the generated code for
```python
def forward(x):
    a = torch.nn.functional.pad(x, (0, 1))
    return a + a
```
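For reference, kernels like the one below come from compiling this
function with inductor (the default `torch.compile` backend); the input
shape here is just an assumption for illustration:
```python
import torch

x = torch.randn(8, 8, device="cuda")
out = torch.compile(forward)(x)  # inductor generates the Triton kernel
```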
contains the lines
```python
tmp3 = tl.load(in_ptr0 + (x1 + tl.zeros([XBLOCK], tl.int32)), tmp2 & xmask, other=0)
tmp4 = tl.where(tmp2, tmp3, 0.0)
tmp5 = tl.load(in_ptr0 + (x1 + tl.zeros([XBLOCK], tl.int32)), tmp2 & xmask, other=0)
tmp6 = tl.where(tmp2, tmp5, 0.0)
```
With this change, the duplicates are removed.
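Roughly, the duplicated pair collapses onto the first load, so the
kernel body would instead contain something like:
```python
tmp3 = tl.load(in_ptr0 + (x1 + tl.zeros([XBLOCK], tl.int32)), tmp2 & xmask, other=0)
tmp4 = tl.where(tmp2, tmp3, 0.0)
tmp5 = tmp4 + tmp4  # tmp4 reused for both operands of a + a
```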
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98303
Approved by: https://github.com/ngimel