pytorch
45bf3f62 - Optimized EMA implementation (#94820)

Commit
2 years ago
Optimized EMA implementation (#94820) This PR proposes an optimized way to do Exponential Moving Average (EMA), which is faster than the current way using `swa_utils.AveragedModel` described in https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies. This implementation is asynchronous, and is built as an optimizer wrapper so that the EMA weight update happens without any additional CPU/GPU sync, just after optimizer steps, and with limited code changes. Example usage: ``` model = Model().to(device) opt = torch.optim.Adam(model.parameters()) opt = EMAOptimizer(opt, device, 0.9999) for epoch in range(epochs): training_loop(model, opt) regular_eval_accuracy = evaluate(model) with opt.swap_ema_weights(): ema_eval_accuracy = evaluate(model) ``` Here are some benchmarks (time per iteration) on various torchvision models: |model|this PR iteration time |swa_utils.AveragedModel iteration time| iteration speedup | |-----|-----------------------------|-----------------------|---------------------------------------------| | | | | | |regnet_x_1_6gf|62.73 |67.998 |1.08 | |regnet_x_3_2gf|101.75 |109.422 |1.08 | |regnet_x_400mf|25.13 |32.005 |1.27 | |regnet_x_800mf|33.01 |37.466 |1.13 | |regnet_x_8gf|128.13 |134.868 |1.05 | |regnet_y_16gf|252.91 |261.292 |1.03 | |regnet_y_1_6gf|72.14 |84.22 |1.17 | |regnet_y_3_2gf|99.99 |109.296 |1.09 | |regnet_y_400mf|29.53 |36.506 |1.24 | |regnet_y_800mf|37.82 |43.634 |1.15 | |regnet_y_8gf|196.63 |203.317 |1.03 | |resnet101|128.80 |137.434 |1.07 | |resnet152|182.85 |196.498 |1.07 | |resnet18|29.06 |29.975 |1.03 | |resnet34|50.73 |53.443 |1.05 | |resnet50|76.88 |80.602 |1.05 | |resnext101_32x8d|277.29 |280.759 |1.01 | |resnext101_64x4d|269.56 |281.052 |1.04 | |resnext50_32x4d|100.73 |101.102 |1.00 | |shufflenet_v2_x0_5|10.56 |15.419 |1.46 | |shufflenet_v2_x1_0|13.11 |18.525 |1.41 | |shufflenet_v2_x1_5|18.05 |23.132 |1.28 | |shufflenet_v2_x2_0|25.04 |30.008 |1.20 | |squeezenet1_1|14.26 |14.325 |1.00 | |swin_b|264.52 |274.613 |1.04 | |swin_s|180.66 |188.914 |1.05 | |swin_t|108.62 |112.632 |1.04 | |swin_v2_s|220.29 |231.153 |1.05 | |swin_v2_t|127.27 |133.586 |1.05 | |vgg11|95.52 |103.714 |1.09 | |vgg11_bn|106.49 |120.711 |1.13 | |vgg13|132.94 |147.063 |1.11 | |vgg13_bn|149.73 |165.256 |1.10 | |vgg16|158.19 |172.865 |1.09 | |vgg16_bn|177.04 |192.888 |1.09 | |vgg19|184.76 |194.194 |1.05 | |vgg19_bn|203.30 |213.334 |1.05 | |vit_b_16|217.31 |219.748 |1.01 | |vit_b_32|69.47 |75.692 |1.09 | |vit_l_32|223.20 |258.487 |1.16 | |wide_resnet101_2|267.38 |279.836 |1.05 | |wide_resnet50_2|145.06 |154.918 |1.07 | You can see that in all cases it is faster than using `AveragedModel`. In fact in many cases, adding EMA does not add any overhead since the computation is hidden behind the usual iteration flow. This is a similar implementation to the one currently in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo). If the team is interested in merging this, let me know and I'll add some documentation similar to `swa_utils` and tests. Credits to @szmigacz for the implementation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94820 Approved by: https://github.com/janeyx99
Author
Committer
Parents
Loading