pytorch
6848e0da - Fix RNN modules with inputs shapes containing-0 in CUDA (#71696)

Commit
2 years ago
Summary:
We found a discrepancy between CPU and CUDA when using RNN modules: input shapes containing 0s would cause an "invalid configuration argument" error in CUDA (kernel grid size is 0), while returning a valid tensor in CPU cases.

A reproducer:

```
import torch

x = torch.zeros((5, 0, 3)).cuda()
gru = torch.nn.GRU(input_size=3, hidden_size=4).to("cuda")
gru(x)
```

Run with `CUDA_LAUNCH_BLOCKING=1` set.

cc ngimel albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71696

Reviewed By: mikaylagawarecki

Differential Revision: D33743674

Pulled By: ngimel

fbshipit-source-id: e9334175d10969fdf1f9c63985910d944bbd26e7
(cherry picked from commit 70838ba69bbfd1b39f6c208f9dbefaad3f11d47e)
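To illustrate why a shape containing a 0 can lead to a grid size of 0, here is a minimal pure-Python sketch of the arithmetic. It assumes a typical ceil-division launch configuration; the block size of 256 and the names `grid_size`, `launch`, and `guarded_launch` are hypothetical and are not taken from the PyTorch RNN kernels or from this commit's diff. The `launch` function only models the fact that CUDA rejects a zero-sized grid with "invalid configuration argument".

```python
import math

THREADS_PER_BLOCK = 256  # hypothetical block size, chosen for illustration


def grid_size(numel: int, threads: int = THREADS_PER_BLOCK) -> int:
    """Number of blocks needed to cover `numel` elements (ceil division)."""
    return (numel + threads - 1) // threads


def launch(numel: int) -> str:
    """Toy stand-in for a kernel launch that rejects a zero-sized grid."""
    blocks = grid_size(numel)
    if blocks == 0:
        raise RuntimeError("invalid configuration argument")
    return f"launched {blocks} block(s)"


def guarded_launch(numel: int) -> str:
    """Skip the launch entirely when there are no elements to process."""
    if numel == 0:
        return "skipped: empty input"
    return launch(numel)


# Shape from the reproducer: (seq_len=5, batch=0, input_size=3)
numel = math.prod((5, 0, 3))
print(numel, grid_size(numel))
print(guarded_launch(numel))
```

The guard in `guarded_launch` is one plausible way to avoid the error (return early when there is no work); whether the actual change uses this exact approach should be checked against the pull request.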
Author
Emilio Castillo