`GradScaler` recomputes `optimizer_state["found_inf_per_device"]` before `optimizer.step` (#97415)
I found a discrepancy between the non-fused and fused optimizer paths in `GradScaler.step`: the former reuses the existing `optimizer_state["found_inf_per_device"]`, while the latter recomputes it right before calling `optimizer.step` (roughly as sketched after the links below).
- non-fused (reuses the recorded `found_inf_per_device`): https://github.com/pytorch/pytorch/blob/e64ddd1ab9d46cfc921c19269969ffc5cd7d6f6c/torch/cuda/amp/grad_scaler.py#L289
- fused (recomputes it via `_check_inf_per_device`): https://github.com/pytorch/pytorch/blob/e64ddd1ab9d46cfc921c19269969ffc5cd7d6f6c/torch/cuda/amp/grad_scaler.py#L353
- where `_check_inf_per_device` is defined: https://github.com/pytorch/pytorch/blob/e64ddd1ab9d46cfc921c19269969ffc5cd7d6f6c/torch/cuda/amp/grad_scaler.py#L564-L573
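For quick reference, the two branches read roughly like this (a simplified paraphrase of the linked lines, not the exact code; `_fused_step` is just my label for the inlined fused branch, and error handling plus the `grad_scaler`-kwarg case are omitted):

```python
# Non-fused (reached from around L289): trust the found_inf values that were
# recorded when the grads were unscaled.
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
    retval = None
    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
        retval = optimizer.step(*args, **kwargs)
    return retval

# Fused (around L353): sweep the grads again right before stepping.
# _check_inf_per_device (L564-L573) re-runs _unscale_grads_ with a dummy
# inv_scale of 1.0, i.e. it only checks for inf/nan and does not rescale.
def _fused_step(self, optimizer, optimizer_state, *args, **kwargs):
    found_inf = sum(
        t.to(self._scale.device, non_blocking=True)
        for t in self._check_inf_per_device(optimizer).values()  # fresh sweep, ignores the recorded values
    )
    optimizer.found_inf = found_inf  # consumed by the fused step, which skips the update when nonzero
    return optimizer.step(*args, **kwargs)
```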
The other way to align the behavior would be to use the existing `optimizer_state["found_inf_per_device"]` in the fused path at https://github.com/pytorch/pytorch/blob/e64ddd1ab9d46cfc921c19269969ffc5cd7d6f6c/torch/cuda/amp/grad_scaler.py#L353.
I'd say this PR is for the sake of "safety"; the alternative would be to keep the existing behavior.
I honestly don't know whether `GradScaler.step` is expected to double-check the sanity of the gradients.
---
What I've observed so far in the huggingface/transformers T5-base example is that non-fused optimizers lead to invalid parameters while the fused one does not.
The cause seems to be that the gradients become inf/nan after `GradScaler._unscale_grads_` (more precisely, after the call to `torch._amp_foreach_non_finite_check_and_unscale_`) but before `GradScaler.step(optimizer)` in the script of the issue linked below, i.e. the gradient clipping and/or unscaling produce inf/nan after the grad check has already run. See
https://github.com/pytorch/pytorch/blob/788300cc2aa096d8d5c1e7fbfc87e5439a338251/aten/src/ATen/native/cuda/AmpKernels.cu#L165-L174.
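To make the ordering concrete, here is a minimal sketch (not the actual transformers script; it assumes a CUDA device, and the manual `grad.fill_` stands in for whatever produced the inf/nan between unscaling and stepping in the real run):

```python
import torch

model = torch.nn.Linear(4, 4, device="cuda")
opt = torch.optim.SGD(model.parameters(), lr=0.1)   # any non-fused optimizer
scaler = torch.cuda.amp.GradScaler()

with torch.autocast("cuda", dtype=torch.float16):
    loss = model(torch.randn(8, 4, device="cuda")).sum()
scaler.scale(loss).backward()

scaler.unscale_(opt)                                # found_inf_per_device is recorded here
next(model.parameters()).grad.fill_(float("inf"))   # grads go bad *after* the check
scaler.step(opt)                                    # non-fused: trusts the stale found_inf and steps
scaler.update()

print(torch.isfinite(next(model.parameters())).all())  # tensor(False, ...): invalid parameters
```

With a fused optimizer in the same loop (e.g. `torch.optim.Adam(model.parameters(), lr=0.1, fused=True)`), the pre-step `_check_inf_per_device` sweep should catch the injected inf and leave the parameters finite, which is consistent with the T5-base observation above.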
Fixes #96755 🙏
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97415
Approved by: https://github.com/ngimel, https://github.com/janeyx99