pytorch
40246fa6 - Gradient scaling API (#26512)

Gradient scaling API (#26512)

Summary: This PR implements the gradient scaling API that mruberry, jjsjann123, ngimel, zdevito, gchanan and I have been discussing. Relevant issue/RFC: https://github.com/pytorch/pytorch/issues/25081.

Volume-wise, this PR is mostly documentation and tests. The Python API (found entirely in `torch/cuda/amp/amp_scaler.py`) is lightweight. The exposed functions are intended to make the implementation and control flow of gradient scaling convenient, intuitive, and performant.

The API is probably easiest to digest by looking at the documentation and examples. `docs/source/amp.rst` is the homepage for the Automatic Mixed Precision package. `docs/source/notes/amp_examples.rst` includes several examples demonstrating common but not-immediately-obvious use cases. The examples are backed by tests in `test_cuda.py` (and thankfully the tests pass :P). Two small utility kernels have been added in `native/cuda/AmpKernels.cu` to improve performance and avoid host-device synchronizations wherever possible.

Existing optimizers, both in the wild and in PyTorch core, do not need to change to use the scaling API. However, the API was also designed to establish a contract between user scripts and optimizers, so that writers of _new_ custom optimizers have the control points they need to implement fast, optionally sync-free updates. User scripts that obey the scaling API can drop such custom optimizers in and reap performance benefits without having to change anything aside from the optimizer constructor itself. [I know what the contract with custom optimizers should be](https://github.com/pytorch/pytorch/blob/35829f24ef2a71eacedcd6307b2fe9325e4a6a94/torch/cuda/amp/amp_scaler.py#L179-L184), but I'm waiting for review on the rest of the API before I document it (it will be given a dedicated section in `docs/source/notes/amp_examples.rst`).

Currently, the gradient scaling examples do not include the auto-casting API discussed in https://github.com/pytorch/pytorch/issues/25081. The gradient scaling API is intended to be orthogonal/modular relative to autocasting. Without auto-casting the gradient scaling API is fully use-_able_, but not terribly use-_ful_, so it's up to you guys whether you want to wait until auto-casting is ready before merging the scaling API as well.

### Todo
- [ ] How do I get c10 registered status for my two custom kernels? They're very simple.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/26512
Differential Revision: D19859905
Pulled By: mruberry
fbshipit-source-id: bb8ae6966214718dfee11345db824389e4286923
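For orientation, below is a minimal sketch of the training-loop pattern this commit message describes, written against the `torch.cuda.amp.GradScaler` interface that this line of work eventually shipped as (the class name inside this PR's `amp_scaler.py` may differ). The model, optimizer, and data are placeholders, not part of this commit.

```python
# Minimal sketch of a gradient-scaled training step (assumes a CUDA device).
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

batches = [(torch.randn(4, 10, device="cuda"),
            torch.randn(4, 10, device="cuda"))] * 3

for inputs, targets in batches:
    optimizer.zero_grad()
    loss = F.mse_loss(model(inputs), targets)

    # scale() multiplies the loss so small gradients survive fp16 underflow.
    scaler.scale(loss).backward()

    # step() unscales the gradients and skips the parameter update
    # if any gradient is inf/NaN.
    scaler.step(optimizer)

    # update() adjusts the scale factor for the next iteration.
    scaler.update()
```

Use cases that need to inspect or clip gradients call `scaler.unscale_(optimizer)` before `scaler.step(optimizer)`, which is the kind of "not-immediately-obvious" pattern the `amp_examples.rst` page added here is meant to cover.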