pytorch
2b21e776 - Added optimizers based on multi tensor apply (#45299)

Commit
4 years ago
Added optimizers based on multi tensor apply (#45299) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45299 Adding a new namespace `torch.optim._multi_tensor` with a bunch of updated optimizers. Those optimizers are using _foreach APIs which improve performance significantly. ### Tests - updated existing tests to use both optimizers - added `test_multi_tensor_optimizers` test to verify correctness. ### Perf results **Adam** timeit: 42.69 ms --> 10.16 ms autorange: 41.96 ms --> 10.28 ms **AdamW** timeit: 51.38 ms --> 15.63 ms autorange: 50.82 ms --> 16.07 ms **SGD** timeit: 6.28 ms --> 4.40 ms autorange: 6.13 ms --> 4.73 ms **RMSprop** timeit: 28.63 ms --> 5.89 ms autorange: 28.27 ms --> 5.76 ms **Rprop** timeit: 213.30 --> 178.42 autorange: 212.03 --> 178.03 **ASGD** timeit: 21.67 --> 9.33 autorange: 21.64 --> 9.27 **Adamax** timeit: 55.60 --> 48.29 autorange: 55.22 -> 49.13 **Rerf Script used** ``` import torch import time import torch.optim as optim from torch.autograd import Variable from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau, StepLR import torch.nn as nn import time import torchvision import torch.utils._benchmark as benchmark_utils device = "cuda" model = torchvision.models.resnet.resnet101(pretrained=True).to(device) targets = torch.randint(0, 1000, (100, 100), device=device) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=1e-3) # <----------------------- optimizer. # would compare optim.SGD vs optim._multi_tensor.SGD running_loss = 0.0 target = torch.empty(128, dtype=torch.long, device=device).random_(5) optimizer.zero_grad() inputs = torch.rand(128, 3, 100, 100, device=device , requires_grad=True) outputs = model(inputs) loss = criterion(outputs, target) loss.backward() optimizer.step() running_loss += loss.item() def main(): timer = benchmark_utils.Timer( stmt="optimizer.step()", globals=globals(), label="str(optimizer)", ) for i in range(1): print(f"Run: {i}\n{'-' * 40}") print(f"timeit:\n{timer.timeit(1000)}\n") print(f"autorange:\n{timer.blocked_autorange()}\n\n") if __name__ == "__main__": main() ``` Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D23931987 Pulled By: izdeby fbshipit-source-id: 582134ef2d402909d27d89a45c5b588fb7130ea1
Author
Iurii Zdebskyi
Parents
Loading