pytorch
ca845084 - compute dynamic tensor shapes for indexing on the host (#93872)

Commit
1 year ago
compute dynamic tensor shapes for indexing on the host (#93872) Hoists computation of some shapes used in triton kernel indexing to the host, so resulting triton code is ``` x1 = (xindex // pks0) % 64 ``` instead of ``` x1 = (xindex // (1 + (((((-1) + ks0) // 4))*((((-1) + ks0) // 4))) + (2*((((-1) + ks0) // 4))))) % 64 ``` with `pks0` arg computed on the host ``` ps0 = (1 + ((((-1) + s2) // 4)))*(1 + ((((-1) + s2) // 4))) ``` It doesn't work yet for indexing expressions that are directly in the `load` statement, e.g. ``` tmp0 = tl.load(in_ptr0 + (r1 + x0 + (x0*(((((-1) + ks0) // 32))*((((-1) + ks0) // 32)))) + (2*x0*((((-1) + ks0) // 32)))), rmask & xmask, eviction_policy='evict_last').to(tl.float32) ``` Unfortunately, `unet` which is one of the examples failing with floor does the latter: ``` tmp1 = ((-1)*(1/(((-1) + (floor(2.0*(ks0//16))))))) + ((1/(((-1) + (floor(2.0*(ks0//16))))))*(ks0 // 16)) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/93872 Approved by: https://github.com/jansel
Author
Natalia Gimelshein
Committer
Parents
Loading