[Quant] Allow setting fixed qparams for inner LSTM ops (#88456)
Summary: In both eager and FX graph mode quantization,
`torch.ao.nn.quantizable.LSTM` is used as an observed custom module,
which is responsible for inserting its own observers. By default,
the user specifies a single QConfig for the custom module (either
through QConfigMapping or by setting the "qconfig" attribute),
and all inner ops will [inherit this
QConfig](https://github.com/pytorch/pytorch/blob/dc00bb51b8d370bf3891f0edb2c6e0c2914e329a/torch/ao/nn/quantizable/modules/rnn.py#L366-L378)
and use the same observer/fake_quantize constructors.
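For reference, the default single-QConfig path looks roughly like this (a minimal sketch; `get_default_qconfig("fbgemm")` stands in for whatever QConfig the user actually picks):
```
import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig

# One QConfig for the whole custom module; every inner op of the observed
# torch.ao.nn.quantizable.LSTM (Linear, sigmoid, tanh, ...) inherits it
# and uses the same observer/fake_quantize constructors.
qconfig_mapping = QConfigMapping().set_object_type(
    torch.nn.LSTM, get_default_qconfig("fbgemm"))
```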
Today, users who wish to override this behavior must extend
`torch.ao.nn.quantizable.LSTM` and write custom code to manually
assign QConfigs to the inner ops. This commit removes that burden
by providing a helper function that assigns QConfigs with custom
observers to each inner op. One example use case is providing a
reference implementation for a backend kernel that hardcodes
qparams for efficiency.
Example usage:
```
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.observer import FixedQParamsObserver
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
from torch.ao.quantization.fx.custom_config import (
    PrepareCustomConfig,
    ConvertCustomConfig,
)

class MyModel(torch.nn.Module):
    ...

class UserLSTM(torch.ao.nn.quantizable.LSTM):
    @classmethod
    def from_float(cls, other):
        assert isinstance(other, cls._FLOAT_MODULE)
        # Hardcoded qparams for each inner op, e.g. to match a backend kernel
        linear_output_obs_ctr = FixedQParamsObserver.with_args(
            scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32)
        sigmoid_obs_ctr = FixedQParamsObserver.with_args(
            scale=2 ** -16, zero_point=0, dtype=torch.qint32)
        tanh_obs_ctr = FixedQParamsObserver.with_args(
            scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32)
        cell_state_obs_ctr = FixedQParamsObserver.with_args(
            scale=2 ** -11, zero_point=0, dtype=torch.qint32)
        hidden_state_obs_ctr = FixedQParamsObserver.with_args(
            scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8)
        # New helper: returns an observed LSTM whose inner ops use the
        # observers specified above instead of the inherited QConfig
        return torch.ao.quantization.utils._get_lstm_with_individually_observed_parts(
            float_lstm=other,
            linear_output_obs_ctr=linear_output_obs_ctr,
            sigmoid_obs_ctr=sigmoid_obs_ctr,
            tanh_obs_ctr=tanh_obs_ctr,
            cell_state_obs_ctr=cell_state_obs_ctr,
            hidden_state_obs_ctr=hidden_state_obs_ctr,
        )

qconfig_mapping = get_default_qconfig_mapping()
example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))
prepare_custom_config = PrepareCustomConfig() \
    .set_float_to_observed_mapping(torch.nn.LSTM, UserLSTM)
convert_custom_config = ConvertCustomConfig() \
    .set_observed_to_quantized_mapping(UserLSTM, torch.ao.nn.quantized.LSTM)
model = MyModel()
model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config)
model(*example_inputs)  # calibrate
model = convert_fx(model, convert_custom_config=convert_custom_config)
model(*example_inputs)
```
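The same `UserLSTM` should plug into eager mode quantization as well. A minimal sketch, assuming the eager mode custom module config dicts (the exact keys may vary across versions; `MyModel` and `example_inputs` as above):
```
from torch.ao.quantization import convert, get_default_qconfig, prepare

model = MyModel()
model.qconfig = get_default_qconfig("fbgemm")
# Swap nn.LSTM for the observed UserLSTM during prepare
model = prepare(model, prepare_custom_config_dict={
    "float_to_observed_custom_module_class": {torch.nn.LSTM: UserLSTM}})
model(*example_inputs)  # calibrate
# Swap UserLSTM for the quantized LSTM during convert
model = convert(model, convert_custom_config_dict={
    "observed_to_quantized_custom_module_class": {UserLSTM: torch.ao.nn.quantized.LSTM}})
```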
Test Plan:
python test/test_quantization.py TestQuantizeFx.test_static_lstm_with_custom_fixed_qparams
Reviewers: jerryzh168, vkuzo
Subscribers: jerryzh168, vkuzo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88456
Approved by: https://github.com/jerryzh168, https://github.com/vkuzo