pytorch
2652da29 - Avoid CPU Sync in SyncBatchNorm When Capturing CUDA Graphs (#78810)

Commit
2 years ago
We recently updated `SyncBatchNorm` to support empty input batches. The new code removes stats from ranks with empty inputs. However, this change breaks CUDA graph capture because it forces a CPU sync.

This commit uses `is_current_stream_capturing()` to guard the new code path, so the new code runs only when not capturing CUDA graphs. To support empty inputs with CUDA graph capturing, we might need to update the CUDA kernels for `batch_norm_backward_elemt` and `batch_norm_gather_stats_with_counts`. See #78656.

Fixes #78549

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78666
Approved by: https://github.com/albanD
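The guard described in the commit message can be illustrated with a minimal sketch. This is not the code from the commit: plain Python tuples stand in for the gathered per-rank tensors, and `drop_empty_ranks`, `RankStats`, and the injected `is_capturing` callable are hypothetical names chosen for illustration. In PyTorch the predicate would presumably be `torch.cuda.is_current_stream_capturing`. The point it shows is that dropping empty ranks makes the result size depend on the data, which is the kind of operation that requires a host sync on a GPU, so it is skipped during capture.

```python
from typing import Callable, List, Tuple

# One (mean, invstd, count) entry per rank; plain numbers stand in for tensors.
RankStats = Tuple[float, float, int]


def drop_empty_ranks(
    stats: List[RankStats],
    is_capturing: Callable[[], bool],
) -> List[RankStats]:
    """Remove stats from ranks with empty inputs unless a graph is being captured."""
    if is_capturing():
        # Filtering makes the output size depend on the data, which on a GPU
        # would need a device-to-host sync, so it is skipped during capture.
        return list(stats)
    # Eager path: keep only ranks that contributed at least one element.
    return [entry for entry in stats if entry[2] >= 1]


# The second rank received an empty batch (count == 0).
gathered = [(0.5, 1.2, 8), (0.0, 0.0, 0), (0.4, 1.1, 8)]
eager = drop_empty_ranks(gathered, is_capturing=lambda: False)
captured = drop_empty_ranks(gathered, is_capturing=lambda: True)
```

With the guard in place, the eager call drops the empty rank while the capturing call leaves the gathered stats untouched, which matches the trade-off stated above: empty inputs are not handled while capturing.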