pytorch commit 6748fbd3: Remove `MultiheadAttention` weights from constants (#39768)

Summary: The weights of `MultiheadAttention` were incorrectly listed as constants, which produced warnings when converting the module to TorchScript:

```py
import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(256, 4)
torch.jit.script(multihead_attn)
```

Warnings:

```
/home/michael/.local/lib/python3.8/site-packages/torch/jit/_recursive.py:151: UserWarning: 'q_proj_weight' was found in ScriptModule constants, but it is a non-constant parameter. Consider removing it.
  warnings.warn("'{}' was found in ScriptModule constants, "
/home/michael/.local/lib/python3.8/site-packages/torch/jit/_recursive.py:151: UserWarning: 'k_proj_weight' was found in ScriptModule constants, but it is a non-constant parameter. Consider removing it.
  warnings.warn("'{}' was found in ScriptModule constants, "
/home/michael/.local/lib/python3.8/site-packages/torch/jit/_recursive.py:151: UserWarning: 'v_proj_weight' was found in ScriptModule constants, but it is a non-constant parameter. Consider removing it.
  warnings.warn("'{}' was found in ScriptModule constants, "
/home/michael/.local/lib/python3.8/site-packages/torch/jit/_recursive.py:151: UserWarning: 'in_proj_weight' was found in ScriptModule constants, but it is a non-constant parameter. Consider removing it.
  warnings.warn("'{}' was found in ScriptModule constants, "
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/39768
Reviewed By: zhangguanheng66
Differential Revision: D21977032
Pulled By: ngimel
fbshipit-source-id: c2c3d0605a51324a9541f5a2caca7ab7a518dc00
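
For context, here is a minimal sketch of the pattern this commit fixes: putting an `nn.Parameter`'s name in a module's `__constants__` list makes `torch.jit.script` emit exactly this kind of warning, and the fix is to drop the name from the list. `ScaledModule`, `FixedModule`, and the `scale` parameter are hypothetical illustrations, not code from the PR.

```py
import torch
import torch.nn as nn

class ScaledModule(nn.Module):
    # Hypothetical repro: 'scale' is a learnable parameter, so listing it
    # in __constants__ is wrong and triggers the UserWarning shown above.
    __constants__ = ['scale']

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x * self.scale

torch.jit.script(ScaledModule())  # warns: 'scale' was found in ScriptModule constants ...

class FixedModule(nn.Module):
    # Same module with the parameter name removed from __constants__,
    # mirroring what the commit does for MultiheadAttention's projection weights.
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x * self.scale

torch.jit.script(FixedModule())  # scripts cleanly, no warning
```

`__constants__` is meant for true compile-time constants (ints, bools, strings, and the like); parameters and buffers are registered with the module automatically, so their names never belong there.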