pytorch
94d306fd - [inductor] Stop using `x + tl.zeros(...)` in generated triton (#100163)

Commit
1 year ago
[inductor] Stop using `x + tl.zeros(...)` in generated triton (#100163) For reductions, this changes the accumulator ```python _tmp2 = tl.zeros([XBLOCK, RBLOCK], tl.int8) + -128 ``` to ```python _tmp2 = tl.full([XBLOCK, RBLOCK], -128, tl.int32) ``` which is equivalent since addition does type promotion from `int8` to `int32` For constant indexing, this changes ```python tl.store(in_out_ptr0 + (0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp4, None) ``` to ```python tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None) ``` For variable indexing, this changes ```python tl.store(out_ptr0 + (0 + tl.zeros([XBLOCK], tl.int32)), tmp1, None) ``` to ```python tl.store(out_ptr0 + (tl.broadcast_to(x0, [XBLOCK])), tmp1, None) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/100163 Approved by: https://github.com/ngimel
Author
Committer
Parents
Loading