pytorch
592481a5 - [fx][const_fold] Refactor to use base split module to simplify, and correctly handle non-single-Tensor outputs (#65933)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65933

We use `split_module` to split the input model that we want to const fold into const and non-const subgraphs. Previously we took the non-const graph and tried to hack it back into the same signature as the input model, but this was complex and buggy. Instead, refactor to keep using the base split module, which contains both the const and non-const graphs. This means we:
- Inline the non-const graph into the split module.
- Remove the const graph from the module and replace it with a getattr, which is populated with the folded result when we `run_folding`.

Test Plan: Added test coverage for the newly supported folding, and updated other tests for the new strategy.

Reviewed By: yinghai

Differential Revision: D31293307

fbshipit-source-id: 6e283a8c7222cf07b14e30e74dffc8ae5ee8b55f
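The split-and-fold strategy above can be illustrated with a minimal, torch-free sketch. All names here (`FoldedModule`, `run_folding`, `folded_attr`) are hypothetical stand-ins for illustration, not the actual `torch.fx` implementation: the "const subgraph" depends only on fixed attributes, so it is executed once and replaced by a stored attribute (the getattr described above), while the "non-const subgraph" stays inlined on the runtime path.

```python
# Conceptual sketch only -- hypothetical names, not the torch.fx API.
class FoldedModule:
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias
        self.folded_attr = None  # set by run_folding(), like the inserted getattr

    def _const_subgraph(self):
        # Depends only on constant attributes, so it is eligible for folding.
        return [w + b for w, b in zip(self.weight, self.bias)]

    def run_folding(self):
        # Run the const subgraph once; its result replaces the subgraph.
        self.folded_attr = self._const_subgraph()

    def forward(self, x):
        # Non-const subgraph, "inlined": only this runs per call,
        # reading the pre-folded attribute instead of recomputing it.
        assert self.folded_attr is not None, "call run_folding() first"
        return [xi * f for xi, f in zip(x, self.folded_attr)]

m = FoldedModule(weight=[1.0, 2.0], bias=[0.5, 0.5])
m.run_folding()
print(m.forward([2.0, 3.0]))  # [3.0, 7.5]
```

The payoff is the same as in the real pass: the constant arithmetic (`weight + bias`) is paid once at fold time, and every subsequent forward call touches only the input-dependent work.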