Support LSTM with FP16 weight (#23291)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23291
This diff implements LSTM with FP16 weights based on FBGEMM.
At a high level, here are the steps:
1. Quantize the weights to FP16 and pack them for every layer of the LSTM.
2. Pass the packed weights from step 1 to the ATen `quantized_lstm` function, which performs the matrix multiplication (MM) with FP16 weights. The dtype of each variable in the MM is shown below:
    Y      =  X      *  W      +  B
    (fp32)    (fp32)    (fp16)    (fp32)
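
For illustration only, the dtype layout above can be mimicked in NumPy. This is a rough sketch, not the FBGEMM or ATen code path: the helper name `fp16_weight_linear`, the tensor shapes, and the choice to up-convert the FP16 weights to FP32 before the multiply are assumptions made for the example, and the packing step is not modeled at all.

```python
import numpy as np

def fp16_weight_linear(x, w, b):
    """Hypothetical helper: Y = X * W + B with W stored in fp16."""
    # Step 1 (stand-in for "quantize and pack"): round the fp32 weights to fp16.
    w_fp16 = w.astype(np.float16)
    # Step 2: assumed here to up-convert the weights and compute in fp32,
    # so X, B, and Y all stay fp32.
    return x.astype(np.float32) @ w_fp16.astype(np.float32) + b.astype(np.float32)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)     # input activations (fp32)
w = rng.uniform(-1.0, 1.0, (8, 16)).astype(np.float32)  # weights before rounding
b = rng.standard_normal(16).astype(np.float32)          # bias (fp32)

y = fp16_weight_linear(x, w, b)
y_ref = x @ w + b  # full-fp32 reference for comparison
```

Comparing `y` against `y_ref` gives a feel for the precision lost by storing only the weights in FP16 while keeping activations, bias, and output in FP32.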
Reviewed By: jianyuh
Differential Revision: D16389595
fbshipit-source-id: c26ae4e153c667a941f4af64e9d07fc251403cee