pytorch
85e393ba - Fix for RNN/LSTM/GRU modules to work with stateless.functional_call() (#91111)

Commit
2 years ago
Fix for RNN/LSTM/GRU modules to work with stateless.functional_call() (#91111) Fixes #90500 The change here checks for parameter changes at the beginning of each `forward()` call; if the parameters are found to be different tensors than last time, `self._flat_weights` is re-initialized with the new values. Thus, swapping parameter values using `stateless.functional_call()` will re-initialize `self._flat_weights` during the `forward()` call, and the provided parameters will be used for module computation as expected. NB: There are still some changes needed for symbolic shapes to work with `nn.GRU` (will address in a follow-up PR). Pull Request resolved: https://github.com/pytorch/pytorch/pull/91111 Approved by: https://github.com/ezyang, https://github.com/albanD
Author
Committer
Parents
Loading