pytorch
bb4174d2 - [inductor] Enable CSE on masked loads (#98303)

Commit
1 year ago
[inductor] Enable CSE on masked loads (#98303) Currently the `TritonKernel.mask_loads` context manager calls `swap_buffers` which creates a new CSE context. So, code generated in different mask contexts cannot be CSE'd even if their masks are the same. This fixes the issue by not calling `swap_buffers` and instead having `load` manually check if a `"tmp"` name appears in the mask meaning the load needs to be generated in the compute buffer. Currently, simple programs involving padding will result in duplcate masked loads. e.g. the generated code for ```python def forward(): a = torch.nn.functional.pad(x, (0, 1)) return a + a ``` contains the lines ```python tmp3 = tl.load(in_ptr0 + (x1 + tl.zeros([XBLOCK], tl.int32)), tmp2 & xmask, other=0) tmp4 = tl.where(tmp2, tmp3, 0.0) tmp5 = tl.load(in_ptr0 + (x1 + tl.zeros([XBLOCK], tl.int32)), tmp2 & xmask, other=0) tmp6 = tl.where(tmp2, tmp5, 0.0) ``` With this change, the duplicates are removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/98303 Approved by: https://github.com/ngimel
Author
Committer
Parents
Loading