pytorch
034f2b4d - [Quant][fx] Enable FX static quantization for LSTM (#85068)

Commit
3 years ago
[Quant][fx] Enable FX static quantization for LSTM (#85068) **Summary:** This commit enables the custom module LSTM path for FX graph mode static quantization. This has the same flow as eager mode, which was already previously supported: ``` torch.nn.LSTM | (prepare_fx) v torch.ao.nn.quantizable.LSTM | (convert_fx) v torch.ao.nn.quantized.LSTM ``` The main reason why custom module LSTM is not supported in FX graph mode quantization today is because its inputs and outputs are nested tuples, and existing constructs such as observers, "quantize" nodes, and "dequantize" nodes do not understand how to handle complex structures. Note that the approach taken in this commit is only intended to be a short-term solution highly tailored to the input and output formats of custom module LSTM. In the future, for the longer-term solution, we should design a more general QConfig that allows users to specify complex input and output formats, and enable FX graph mode quantization to understand arbitrary nested structures and automatically infer how to transform the graph accordingly. **Context:** Today, in FX graph mode static quantization, custom modules are assumed to have quantized inputs and quantized outputs, with the exact dtypes derived from the associated QConfig (default quint8). Since custom modules are currently not handled through the reference model flow, their observer replacement logic are a little different from normal operators: ``` # (1) Original model input -> custom_module -> output # (2) Observed model (after prepare) input -> obs0 -> custom_module -> obs1 -> output # (3) Quantized model (after convert) input -> quant -> quantized_custom_module -> dequant -> output ``` In the last step, input observers are replaced with "quantize" and output observers are replaced with "dequantize", in contrast to other non-custom-module patterns where observers are replaced with "quantize-dequantize" pairs instead. Note that, conceptually, the output observer `obs1` is really just a DeQuantStub, since no observation is actually needed. **Custom module LSTM:** The reason why custom module LSTM cannot be handled in the same way is because, unlike other custom modules, its inputs and outputs are nested tuples instead of single tensors. This is how the existing custom module code would try to handle LSTMs: ``` # (1) Original model # input format: (input, (hidden0, hidden1)) # output format: (output, (hidden0, hidden1)) input -> lstm -> output hidden0 -/ \-> hidden0 hidden1 -/ \-> hidden1 # (2) Observed model (after prepare) input -> obs0 -> lstm -> obs1 # fails hidden0 -/ # missing observer hidden1 -/ # missing observer ``` However, this fails today because 1) we assume there is only one input to the custom module, and so we never end up quantizing `hidden0` and `hidden1`, and 2) the output observer `obs1` is fed a tuple, which it does not understand how to handle. **Short-term fix:** This commit addresses the above by specifically handling the input and output structures used by custom module LSTM. For the inputs, we manually insert observers for `hidden0` and `hidden1` to ensure all input tensors are quantized. For the outputs, we split the tuple into its internal nodes, attach a DeQuantStub to each node, and recombine these DeQuantStubs according to the original structure. Finally, we must also reroute consumers of the original LSTM tuple (and its internal nodes, e.g. `lstm[0]`) to these DeQuantStubs: ``` # (1) Original model input -> lstm -> output -> linear0 hidden0 -/ \-> hidden0 -> linear1 hidden1 -/ \-> hidden1 -> linear2 # (2) Observed model (after prepare) input -> obs0 -> lstm -> output -> dqstub -> linear0 -> obs3 hidden0 -> obs1 -/ \-> hidden0 -> dqstub -> linear1 -> obs4 hidden1 -> obs2 -/ \-> hidden1 -> dqstub -> linear2 -> obs5 # (3) Reference model (after convert) input -> quant -> qlstm -> output -> dequant -> linear0 -> quant -> dequant hidden0 -> quant -/ \-> hidden0 -> dequant -> linear1 -> quant -> dequant hidden1 -> quant -/ \-> hidden1 -> dequant -> linear2 -> quant -> dequant # (4) Quantized model (after lowering) input -> quant -> qlstm -> output -> quantized_linear0 -> dequant hidden0 -> quant -/ \-> hidden0 -> quantized_linear1 -> dequant hidden1 -> quant -/ \-> hidden1 -> quantized_linear2 -> dequant ``` Note that we choose to insert DeQuantStubs here instead of observers because these will ultimately be replaced by "dequantize" nodes. This matches the general custom module behavior, where output observers are replaced only with "dequantize" nodes (as opposed to the normal "quantize-dequantize" pair), since custom module outputs are assumed to already be quantized. Using DeQuantStubs instead of observers also simplifies the "dequantize" insertion logic. In the future, we should use DeQuantStubs in place of output observers for custom modules in general. **Test plan:** python test/test_quantization.py TestQuantizeFx.test_static_lstm python test/test_quantization.py TestQuantizeFx.test_static_lstm_consume_tuple **Reviewers:** jerryzh168, vkuzo **Subscribers:** jerryzh168, vkuzo Pull Request resolved: https://github.com/pytorch/pytorch/pull/85068 Approved by: https://github.com/jerryzh168
Author
Committer
Parents
Loading