23aa052a - Deprecate device-specific GradScaler autocast API (#126527)

Deprecate device-specific GradScaler autocast API (#126527)

Summary:

# Motivation

## For `torch.amp.GradScaler`

- `torch.cpu.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cpu", args...)`.
- `torch.cuda.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cuda", args...)`.

So we intend to deprecate them and **strongly recommend** developers use `torch.amp.GradScaler`.

## For `custom_fwd` and `custom_bwd`

This is a good solution to make a custom function run correctly with or without autocast, even in an autocast-enabled region, and it can be shared by other backends such as CPU and XPU. So we generalize it to be device-agnostic, move it into `torch/amp/autocast_mode.py`, and re-expose it as `torch.amp.custom_fwd` and `torch.amp.custom_bwd`. Meanwhile, we deprecate `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`.

# Additional Context

Add a UT to cover the deprecation warning. No further UTs are needed to cover the functionality of `torch.amp.custom_fwd`/`custom_bwd`; the existing UTs that previously covered the functionality of `torch.cuda.amp.custom_fwd`/`custom_bwd` already cover it. To facilitate review, the code changes are split into two PRs: the first covers `torch.amp.GradScaler`, and the follow-up covers `custom_fwd` and `custom_bwd`.

X-link: https://github.com/pytorch/pytorch/pull/126527
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/janeyx99, https://github.com/EikanWang

Reviewed By: PaliC

Differential Revision: D57838085

fbshipit-source-id: 09a29e2535e66643d212276779605c573391666f
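The deprecation pattern described above can be sketched in plain Python. This is an illustrative stand-in, not the actual PyTorch implementation: the class names `GradScaler` and `CudaGradScaler` here are hypothetical placeholders, and there is no torch dependency. The idea is that the device-agnostic class takes the device as an explicit first argument, while the legacy device-specific class becomes a thin wrapper that emits a deprecation warning and forwards to it.

```python
import warnings


class GradScaler:
    """Device-agnostic scaler: the device is an explicit first argument.

    Illustrative placeholder for torch.amp.GradScaler, not the real class.
    """

    def __init__(self, device="cuda", enabled=True):
        self.device = device
        self.enabled = enabled


class CudaGradScaler(GradScaler):
    """Legacy device-specific entry point, kept as a thin deprecated wrapper.

    Illustrative placeholder for torch.cuda.amp.GradScaler.
    """

    def __init__(self, enabled=True):
        # Warn once at construction, then delegate to the generic class
        # with the device hard-coded, so behavior is fully equivalent.
        warnings.warn(
            "CudaGradScaler(args...) is deprecated; "
            'use GradScaler("cuda", args...) instead.',
            FutureWarning,
        )
        super().__init__("cuda", enabled=enabled)


# The wrapper behaves identically to GradScaler("cuda") but warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    scaler = CudaGradScaler()

print(scaler.device)       # -> cuda
print(caught[0].category)  # -> <class 'FutureWarning'>
```

Because the wrapper only warns and forwards its arguments, existing call sites keep working unchanged during the deprecation window, and migrating is a mechanical rewrite of the constructor call.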