pytorch
97889fa1 - simplify indexing expression before trying to determine strides (#98783)

Commit
1 year ago
simplify indexing expression before trying to determine strides (#98783) This fixes a few failing cases where we fail to compute stride_hint for an indexing expression with ModularIndexing When can size_hint error out? It shouldn't happen when we are getting regular size hints for expressions where free vars are in ShapeEnv. But this is not the case when we try to recover strides from indexing expressions (which is what stride_hint is for). Suppose you have an indexing expression that looks like ``` 289*d0 + ModularIndexing(7399*d1 + d2, 1, 17) + 17*ModularIndexing(7399*d1 + d2, 17, 17) + 46240*ModularIndexing(7399*d1 + d2, 289, 128) ``` and want to understand its stride wrt to variable `d1`. Let's ignore for a moment that stride for ModularIndexing is not well defined, it'll become negative around modulo divisor value, but even without that, the way we usually compute stride is we substitute `0` and `1` for `d1` and compute difference in indexing expression with those substitutions - this is our stride. But for the expression above, the difference would result in an expression that still has free variable `d2` that we don't have a substitution for. The fix that this PR makes is it expands stride computation to substitute not only `0` and `1` for the variable we are computing a stride for, but also `0` for other variables in the indexing expression (`support_vars`). Note that computing strides in `stride_hints` is a performance optimization that we use to reorder dimensions or make split decisions for split reduction. If it fails, it's not a hard error - we may incorrectly apply reordering by it won't affect correctness. Pull Request resolved: https://github.com/pytorch/pytorch/pull/98783 Approved by: https://github.com/ezyang, https://github.com/voznesenskym
Author
Natalia Gimelshein
Committer
Parents
Loading