cec08e70 - To add warm-up scheduler to optim (#60836)

To add warm-up scheduler to optim (#60836)

Summary:
Warm-up of the learning rate schedule was initially discussed by Priya Goyal et al. in the paper https://arxiv.org/pdf/1706.02677.pdf. In Section 2.2 they propose warming up the learning rate in order to prevent large variance / noise in the early stages of training. The idea has been discussed further in the following papers:

* Akilesh Gotmare et al.: https://arxiv.org/abs/1810.13243
* Bernstein et al.: http://proceedings.mlr.press/v80/bernstein18a/bernstein18a.pdf
* Liyuan Liu et al.: https://arxiv.org/pdf/1908.03265.pdf

There are two popular types of learning rate warm-up:

* Constant warm-up (start with a very small constant learning rate)
* Linear warm-up (start with a small learning rate and increase it gradually)

In this PR we add warm-up as a learning rate scheduler. Note that learning rate schedulers are chainable, so the warm-up scheduler can be combined with any other scheduler to build a more sophisticated schedule (a chaining sketch is included after the commit footer below).

## Linear Warmup

Linear warm-up multiplies the learning rate by a pre-defined constant, warmup_factor, in the first epoch (epoch 0), and then increases this multiplicative factor linearly until it reaches 1 after warmup_iters epochs. Hence at the i-th step the multiplicative factor is

warmup_factor + (1 - warmup_factor) * i / warmup_iters

Moreover, the ratio of this quantity at step i to the one at step i-1 is

1 + (1 - warmup_factor) / [warmup_iters * warmup_factor + (i - 1) * (1 - warmup_factor)]

which is the update used in the get_lr() method of our implementation (a standalone numeric check of this ratio is sketched after the commit footer below). Below is an example of how to use the linear warm-up scheduler and the learning rates it produces.

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=10, warmup_method="linear")

for epoch in range(15):
    print(epoch, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()
```

```
0 0.010000000000000002
1 0.019000000000000003
2 0.028000000000000008
3 0.03700000000000001
4 0.04600000000000001
5 0.055000000000000014
6 0.06400000000000002
7 0.07300000000000002
8 0.08200000000000003
9 0.09100000000000004
10 0.10000000000000005
11 0.10000000000000005
12 0.10000000000000005
13 0.10000000000000005
14 0.10000000000000005
```

## Constant Warmup

Constant warm-up is straightforward: multiply the learning rate by warmup_factor until epoch warmup_iters is reached, then leave it unchanged for the following epochs.

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="constant")

for epoch in range(10):
    print(epoch, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()
```

```
0 0.010000000000000002
1 0.010000000000000002
2 0.010000000000000002
3 0.010000000000000002
4 0.010000000000000002
5 0.10000000000000002
6 0.10000000000000002
7 0.10000000000000002
8 0.10000000000000002
9 0.10000000000000002
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60836

Reviewed By: saketh-are

Differential Revision: D29537615

Pulled By: iramazanli

fbshipit-source-id: d910946027acc52663b301f9c56ade686e62cb69
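As a quick sanity check (not part of the PR), the following standalone Python sketch recomputes the linear warm-up factors both from the closed-form expression and by repeatedly applying the ratio used in get_lr(); the helper function names are illustrative only.

```python
def warmup_factor_closed_form(i, warmup_factor, warmup_iters):
    # Closed-form multiplicative factor at epoch i, capped at 1.0 once warm-up ends.
    return min(1.0, warmup_factor + (1.0 - warmup_factor) * i / warmup_iters)


def warmup_factors_recursive(num_epochs, warmup_factor, warmup_iters):
    # Start from warmup_factor at epoch 0 and repeatedly apply the ratio from get_lr().
    factors = [warmup_factor]
    for i in range(1, num_epochs):
        if i > warmup_iters:
            # Warm-up is over: the factor stays where it is (at 1.0).
            factors.append(factors[-1])
        else:
            ratio = 1.0 + (1.0 - warmup_factor) / (
                warmup_iters * warmup_factor + (i - 1) * (1.0 - warmup_factor)
            )
            factors.append(factors[-1] * ratio)
    return factors


base_lr = 0.1
recursive = warmup_factors_recursive(15, warmup_factor=0.1, warmup_iters=10)
for i, f in enumerate(recursive):
    closed = warmup_factor_closed_form(i, warmup_factor=0.1, warmup_iters=10)
    # The two columns agree up to floating-point rounding and reproduce
    # the learning rates printed in the linear warm-up example above.
    print(i, base_lr * f, base_lr * closed)
```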
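The commit message also notes that schedulers are chainable. Below is a minimal sketch of what that could look like, assuming the WarmUpLR interface introduced in this PR; pairing it with ExponentialLR is an illustration chosen here, not something prescribed by the commit.

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR, WarmUpLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)

# Chain a linear warm-up with an exponential decay; each scheduler applies
# its own multiplicative factor to the optimizer's current learning rate.
warmup = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="linear")
decay = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(10):
    print(epoch, decay.get_last_lr()[0])
    optimizer.step()
    warmup.step()   # ramps the lr up during the first warmup_iters epochs
    decay.step()    # multiplies the lr by gamma every epoch
```

Because chainable schedulers each apply a multiplicative factor per `step()`, the warm-up ramp and the exponential decay compose into a single effective schedule.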