pytorch
e858f6ee - torch.nn.utils.clip_grad_norm_: remove device syncs (#61042)

torch.nn.utils.clip_grad_norm_: remove device syncs (#61042)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/60691

### Changes

Per the discussion in the above issue, this PR makes 2 changes:

1. When `error_if_nonfinite=False`, the NaN/Inf checks are truly skipped, and no device synchronization occurs.
   - Additionally, when performing the checks, the 2 results are combined with `torch.logical_or` to incur only a single sync (instead of 2 in the happy/finite path).
2. The `clip_coef` conditional is removed, in favor of a call to `clamp(..., max=1.0)` and an unconditional multiplication.

### Testing

- The existing unit tests for `clip_grad_norm_` pass.
- I have manually profiled the example program from https://github.com/pytorch/pytorch/issues/60691, and verified that:
  - No synchronizations occur when using `error_if_nonfinite=False`.
  - A single synchronization occurs when using `error_if_nonfinite=True`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61042
Reviewed By: mrshenli
Differential Revision: D29764096
Pulled By: jbschlosser
fbshipit-source-id: db594b24608d16374b91bcbb9469046dfeeb152d
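The two changes above can be illustrated with a minimal sketch of the sync-free clipping pattern. This is not the actual `torch.nn.utils.clip_grad_norm_` source; the function name and structure here are illustrative only.

```python
import torch

def clip_grad_norm_sketch(parameters, max_norm, norm_type=2.0,
                          error_if_nonfinite=False):
    """Minimal sketch of the sync-free gradient clipping pattern."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return torch.tensor(0.0)

    # Total norm stays on the device as a 0-dim tensor; no host read yet.
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), norm_type) for g in grads]),
        norm_type,
    )

    if error_if_nonfinite:
        # Change 1: combine both checks with logical_or so evaluating the
        # condition costs a single device-to-host sync instead of two.
        if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
            raise RuntimeError(
                f"The total norm of order {norm_type} for gradients is "
                "non-finite, so it cannot be clipped."
            )
    # With error_if_nonfinite=False the check is skipped entirely,
    # so no synchronization occurs at all.

    # Change 2: clamp the coefficient at 1.0 and multiply unconditionally,
    # instead of branching on a Python-level comparison (which would sync).
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm
```

The key design point is that every decision that previously required reading `total_norm` back to the host (the finiteness check in the default path and the `clip_coef < 1` branch) is either skipped or replaced by a device-side operation, so the whole call can stay asynchronous.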
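As a hypothetical way to reproduce the verification described under Testing, recent PyTorch builds expose `torch.cuda.set_sync_debug_mode`, which raises on implicit device-to-host synchronizations. This smoke test is an assumption about how one might check the sync-free path, not the profiling setup used in the PR.

```python
import torch
import torch.nn as nn

# Build some gradients on the GPU.
model = nn.Linear(128, 128).cuda()
model(torch.randn(32, 128, device="cuda")).sum().backward()

# In "error" mode, any implicit sync raises, so the
# error_if_nonfinite=False path should complete cleanly
# if it is truly synchronization-free.
torch.cuda.set_sync_debug_mode("error")
try:
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0,
                             error_if_nonfinite=False)
finally:
    torch.cuda.set_sync_debug_mode("default")
```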