torch.nn.utils.clip_grad_norm_: remove device syncs (#61042)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/60691
### Changes
Per the discussion in the above issue, this PR makes two changes:
1. When `error_if_nonfinite=False`, the NaN/Inf checks are skipped entirely, so no device synchronization occurs.
   - Additionally, when the checks do run, the two results are combined with `torch.logical_or`, incurring only a single sync (instead of two on the happy/finite path).
2. The `clip_coef` conditional is removed in favor of a `clamp(..., max=1.0)` call followed by an unconditional multiplication (see the sketch after this list).
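
For concreteness, here is a minimal sketch of the resulting hot path. `clip_grad_norm_sketch` is an illustrative simplification of `torch.nn.utils.clip_grad_norm_`, not the upstream source; names and structure only approximate the real function:

```python
import torch

def clip_grad_norm_sketch(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False):
    # Simplified stand-in for torch.nn.utils.clip_grad_norm_ after this PR.
    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return torch.tensor(0.0)
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
    # Change 1: the NaN/Inf check only runs when requested, and the two
    # results are merged with logical_or so the `if` triggers one sync, not two.
    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients is non-finite')
    # Change 2: clamp to 1.0 and multiply unconditionally, instead of an
    # `if clip_coef < 1:` branch that would force a host-device sync.
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm
```

The extra multiply when `clip_coef >= 1` is redundant arithmetically, but it is cheap on-device work, whereas the branch it replaces requires reading a scalar back to the host.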
### Testing
- The existing unit tests for `clip_grad_norm_` pass.
- I have manually profiled the example program from https://github.com/pytorch/pytorch/issues/60691 and verified that:
  - No synchronizations occur when using `error_if_nonfinite=False`.
  - A single synchronization occurs when using `error_if_nonfinite=True` (see the sync-check sketch after this list).
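
The verification above was done with a manual profile. As an alternative way to reproduce the check, newer PyTorch builds can surface implicit synchronizations via `torch.cuda.set_sync_debug_mode`; this is a different technique than the profiling the PR used, shown here only as a sketch:

```python
import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(10, 10).cuda()
model(torch.randn(8, 10, device="cuda")).sum().backward()

# Ask PyTorch to raise whenever an op forces an implicit host-device sync,
# scoped to just the call under test.
torch.cuda.set_sync_debug_mode("error")
# Expected to complete without raising: the non-finite checks are skipped.
clip_grad_norm_(model.parameters(), max_norm=1.0, error_if_nonfinite=False)
torch.cuda.set_sync_debug_mode("default")
```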
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61042
Reviewed By: mrshenli
Differential Revision: D29764096
Pulled By: jbschlosser
fbshipit-source-id: db594b24608d16374b91bcbb9469046dfeeb152d