pytorch
d8b971ed - Fixes for partitioner with symbolic shapes (#86425)

Commit
2 years ago
Fixes for partitioner with symbolic shapes (#86425) - supports saving symint (and symfloat..) values between fw/bwd, using sketchy logic that probably needs to be improved but seems to work so far - sets a correct weight=1 for sym nodes for cost purposes - lets user functions return symints/floats (but if the same symfloat is saved for backward, that gets duplicated annoyingly) - makes partitioning decisions based on observed trace-time sizes without guarding! (this is sketchy, but it isn't clear that it will lead to bad partitioning choices either) - improves infra for tracking symint-family of types: is_sym_node() and _py_sym_types Pull Request resolved: https://github.com/pytorch/pytorch/pull/86425 Approved by: https://github.com/ezyang
Author
Committer
Parents
Loading