pytorch
1884d7fb - Avoid CPU Sync in SyncBatchNorm When Capturing CUDA Graphs

Commit
3 years ago
We recently updated `SyncBatchNorm` to support empty input batches. The new code removes stats from ranks with empty inputs. However, this change breaks CUDA graph capture because it forces a CPU sync. This commit uses `is_current_stream_capturing()` to guard the new code path, so that it runs only when not capturing CUDA graphs.

To support empty inputs during CUDA graph capture, we might need to update the CUDA kernels for `batch_norm_backward_elemt` and `batch_norm_gather_stats_with_counts`. See #78656.

Fixes #78549

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78666
Approved by: https://github.com/albanD
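
The guard described above can be sketched roughly as follows. This is a minimal illustration, not the code from the commit: the helper name `drop_empty_rank_stats` and the tensor layout (one row of stats per rank, with a trailing count column) are assumptions made for the example. Boolean-mask indexing produces an output whose shape depends on the data, which is why it needs a host-side sync on CUDA and cannot run during graph capture. The `torch.cuda.is_available()` check is added only so the sketch also runs on machines without CUDA.

```python
import torch


def drop_empty_rank_stats(mean_all, invstd_all, count_all):
    """Hypothetical helper: drop stats gathered from ranks with empty inputs.

    Assumed shapes: mean_all and invstd_all are [world_size, C],
    count_all is [world_size, 1].
    """
    # Only query capture state when CUDA is usable (assumption for portability).
    capturing = torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()

    if not capturing:
        # Boolean-mask indexing has a data-dependent output shape, so on CUDA
        # it synchronizes with the host. Skip it while capturing a CUDA graph.
        mask = count_all.squeeze(-1) >= 1
        mean_all = mean_all[mask]
        invstd_all = invstd_all[mask]
        count_all = count_all[mask]

    return mean_all, invstd_all, count_all
```

When capture is in progress the inputs are returned unfiltered, so in this sketch empty batches remain unsupported under CUDA graph capture, consistent with the limitation noted in the commit message.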