pytorch
aab55d6d - [Quant] Remove all the dequant nodes when the ref module has multi input args (#90157)

Commit
1 year ago
[Quant] Remove all the dequant nodes when the ref module has multi input args (#90157) **Summary**: When converting a ref module into a quant module, `_lower_static_weighted_ref_module` pass assumes the `ref_node` only has 1 input node, and only remove the first `dequant` node. We add a check in this PR to ensure this is the case for `_lower_static_weighted_ref_module` pass. **Test Plan**: We only add a check in this PR, there is no new added test case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90157 Approved by: https://github.com/Xia-Weiwen, https://github.com/jgong5, https://github.com/jerryzh168
Committer
Parents
Loading