pytorch
495e1b4d - Add device_asserts before indirect loads and stores (#98590)

This PR also adds a way to CSE statements (not only assignments).

The tests follow the pattern from https://github.com/openai/triton/pull/1143. They take a fair amount of time to run (90s on my box). If we wanted to improve this, we could avoid testing the `ndim == 3` case.

Changes like this one make me hope that we get to clean up the number of lowerings we have at some point...

Generated code for `x[y]` with `x.shape == (3, 2, 4)`, `y.ndim == 1`:

With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < 3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```

With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```

Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4)`, `y.shape == (3, 3)`:

With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`):
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(((0 <= tmp2) & (tmp2 < 2)) | (~xmask), f"index out of bounds: 0 <= tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```

With `dynamic=True`:
```python
tl.device_assert(((0 <= tmp2) & (tmp2 < min(ks2, ks1))) | (~xmask), f"index out of bounds: 0 <= tmp2 < min(ks2, ks1)")
```

The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`.

Generated code for `x[y] = z` with `x.ndim == 3`, `y.ndim == 1`, and dynamic shapes:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```

Fixes https://github.com/pytorch/pytorch/issues/93538

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98590
Approved by: https://github.com/ngimel
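
For context, here is a minimal, hypothetical repro sketch (not part of the PR) of the kind of user code whose compilation produces kernels like the ones quoted above; the function and tensor names are illustrative:

```python
# Hypothetical repro sketch; names are illustrative, not from the PR.
import torch

def gather_rows(x, y):
    return x[y]  # indirect load: y supplies the index into x's first dimension

x = torch.randn(3, 2, 4, device="cuda")
y = torch.tensor([0, 2, 1], device="cuda")

# Inductor should now emit a tl.device_assert before the indirect tl.load.
compiled = torch.compile(gather_rows, dynamic=False)
out = compiled(x, y)

# An out-of-bounds index is expected to trip the generated assert at runtime
# instead of silently reading bad memory (see pytorch/pytorch#93538):
# y_bad = torch.tensor([0, 3, 1], device="cuda")
# compiled(x, y_bad)  # expected: "index out of bounds: 0 <= tmp0 < 3"
```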
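
And a loose, illustrative sketch (again, not the PR's actual implementation) of what CSE over whole statements, as opposed to assignments, amounts to: the statement text is used as the deduplication key, so a repeated `device_assert` is only emitted once per kernel:

```python
# Illustrative only: a toy statement-level CSE cache, not Inductor's code.
class StatementCSE:
    def __init__(self):
        self._seen = set()
        self.lines = []

    def generate(self, stmt: str):
        # Unlike expression CSE there is no result variable to name; we just
        # deduplicate on the full statement text.
        if stmt not in self._seen:
            self._seen.add(stmt)
            self.lines.append(stmt)


cse = StatementCSE()
assert_stmt = (
    'tl.device_assert(((0 <= tmp0) & (tmp0 < 3)) | (~xmask), '
    '"index out of bounds: 0 <= tmp0 < 3")'
)
for _ in range(3):          # the same indirect index may be used several times
    cse.generate(assert_stmt)
assert len(cse.lines) == 1  # the assert is emitted only once
```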