pytorch
6d86dc53 - dbr quant: store auto_quant_state on the top level model (#72934)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72934

Before this PR, DBR quantization had a limitation when handling user code that iterates over all module children. For example, imagine a forward function such as

```
def forward(self, x):
    for module in self:
        x = module(x)
    return x
```

Before this PR, this code would break with DBR quantization, because we attach `AutoQuantizationState` objects to each child; those objects live in the child's module hierarchy and appear in these kinds of iterations, changing the meaning of the user program.

This PR reduces the scope of this problem to just the top level module. Instead of attaching `AutoQuantizationState` objects to each child, we register them in a map on the parent. Here is a before and after:

```
// toy model
model
|--> child1

// toy model with AutoQuantizationState objects, before this PR
model
|--> child1
|     |--> _auto_quant_state
|--> _auto_quant_state

// toy model with AutoQuantizationState objects, after this PR
model
|--> child1
|--> _fqn_to_auto_quant_state_map
      |--> (      ) --> _auto_quant_state  // of `model`
      |--> (child1) --> _auto_quant_state  // of `model.child1`
```

Note: `child1._auto_quant_state` works as before for convenience, but the `child1` object now stores a soft link to its `_auto_quant_state` instead of properly registering it in its module hierarchy. This is somewhat hacky. If we need to improve this in the future, we could remove this soft link and refactor the code to call the FQN map instead.

Note: if the top level module iterates over its children, things will still be broken. This is less likely, and we will recommend that the user work around it by wrapping their model, or by checking for the `AutoQuantizationStateModuleDict` type in their iteration loop.

The impact of this change should be improved coverage of user models. In fact, we expect this to drive our coverage of torchbenchmark models from 89% to 100%.

Test Plan:
```
// previously disabled test cases with user code iterating
// over module children are now enabled, with wrappers
python test/test_quantization.py -k test_module_calls_items
python test/test_quantization.py -k test_vovnet_sequential
```

Reviewed By: dzdang

Differential Revision: D34281074

Pulled By: vkuzo

fbshipit-source-id: 0e25fc1ec529c47f72478a1875fe43219feac6b1

(cherry picked from commit 4008f899671f643f5a54c311254af3f68ae22e2e)
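For illustration, here is a minimal sketch of the second workaround mentioned above (skipping the `AutoQuantizationStateModuleDict` container when the top level module iterates over its own children). The `MyTopLevelModel` class is hypothetical, and the check is done by class name to avoid assuming an import path for `AutoQuantizationStateModuleDict`, which this commit message does not specify.

```
import torch.nn as nn

class MyTopLevelModel(nn.Sequential):
    # Hypothetical top level model whose forward iterates over its children.
    def forward(self, x):
        for module in self:
            # After this PR, DBR quantization registers its FQN-to-state map as a
            # child of the top level module; skip it so it does not change the
            # meaning of the loop. Checking by class name is an assumption made
            # here for illustration only.
            if type(module).__name__ == "AutoQuantizationStateModuleDict":
                continue
            x = module(x)
        return x
```

Alternatively, wrapping the model in an outer module (so the iteration no longer happens at the top level) avoids the issue without any type check, as noted above.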