pytorch
87ab665b - Fix SyncBatchNorm for empty inputs (#74944)

Commit
2 years ago
Fix SyncBatchNorm for empty inputs (#74944) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74944 fixes #36530 Prior to this commit, SyncBatchNorm crashes with the following error message. ``` File "..../torch/nn/modules/_functions.py", line 17, in forward mean, invstd = torch.batch_norm_stats(input, eps) RuntimeError: cannot reshape tensor of 0 elements into shape [0, 3, -1] because the unspecified dimension size -1 can be any value and is ambiguous ``` This PR adds a dedicated branch to handle empty inputs. When a process recieves empty inputs, it will set its local `mean`, `invstd`, and `count` to zero, and participate in the `all_gather` collective communications in the forward pass. Then `mean` and `invstd` with zero count will be filtered out before computing global mean and invstd. In the backward pass, it also participate in the `all_reduce` communication with zero tensors to unblock its peers. Differential Revision: D35273409 D35273409 Test Plan: Imported from OSS Reviewed By: datumbox Pulled By: mrshenli fbshipit-source-id: 1cee51eea866773c329b3fbf5da2be8a5fee6f0f (cherry picked from commit f8e2a2357240ebe7b7a058047d376a5300bdeda9)
Author
Committer
Parents
Loading