[RELAND] Added optimizers based on multi tensor apply (#45408)
Summary:
Original PR: https://github.com/pytorch/pytorch/pull/45299. This PR fixes the minor bugs that caused the revert.
This adds a new namespace, `torch.optim._multi_tensor`, with updated versions of the existing optimizers. These optimizers use the `_foreach` APIs, which significantly improve performance by replacing per-parameter Python loops with fused multi-tensor kernels.
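For intuition, here is a minimal sketch (illustrative, not the PR's implementation) of what the `_foreach` APIs buy: a single horizontally fused call replaces a per-parameter Python loop, and the new namespace is a drop-in swap for `torch.optim`.
```
import torch
import torch.optim as optim

params = [torch.randn(64, 64, requires_grad=True) for _ in range(10)]
grads = [torch.randn(64, 64) for _ in range(10)]
lr = 1e-3

with torch.no_grad():
    # Per-tensor update: one kernel launch per parameter.
    for p, g in zip(params, grads):
        p.add_(g, alpha=-lr)

    # Multi-tensor update: one fused call over the whole list.
    torch._foreach_add_(params, grads, alpha=-lr)

# The new optimizers are constructed exactly like the existing ones:
opt = optim._multi_tensor.SGD(params, lr=lr)
```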
### Tests
- updated the existing optimizer tests to exercise both the `torch.optim` and `torch.optim._multi_tensor` variants
- added a `test_multi_tensor_optimizers` test to verify that the new optimizers match the existing ones; the rough shape of that check is sketched below
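An illustrative sketch of such a parity check (not the actual test code): step both implementations from identical parameters and assert they stay in sync.
```
import torch

# Step both implementations on identical parameters and check they agree.
param_ref = torch.randn(20, 20, requires_grad=True)
param_new = param_ref.clone().detach().requires_grad_(True)
opt_ref = torch.optim.Adam([param_ref], lr=1e-3)
opt_new = torch.optim._multi_tensor.Adam([param_new], lr=1e-3)

for _ in range(5):
    for p, opt in ((param_ref, opt_ref), (param_new, opt_new)):
        opt.zero_grad()
        p.pow(2).sum().backward()
        opt.step()

torch.testing.assert_allclose(param_ref, param_new)
```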
### Perf results
| Optimizer | timeit (before --> after) | autorange (before --> after) |
|-----------|---------------------------|------------------------------|
| Adam      | 42.69 ms --> 10.16 ms     | 41.96 ms --> 10.28 ms        |
| AdamW     | 51.38 ms --> 15.63 ms     | 50.82 ms --> 16.07 ms        |
| SGD       | 6.28 ms --> 4.40 ms       | 6.13 ms --> 4.73 ms          |
| RMSprop   | 28.63 ms --> 5.89 ms      | 28.27 ms --> 5.76 ms         |
| Rprop     | 213.30 ms --> 178.42 ms   | 212.03 ms --> 178.03 ms      |
| ASGD      | 21.67 ms --> 9.33 ms      | 21.64 ms --> 9.27 ms         |
| Adamax    | 55.60 ms --> 48.29 ms     | 55.22 ms --> 49.13 ms        |
**Perf script used**
```
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.utils._benchmark as benchmark_utils

device = "cuda"
model = torchvision.models.resnet.resnet101(pretrained=True).to(device)
criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=1e-3)  # <-- optimizer under test;
# swap optim.SGD for optim._multi_tensor.SGD to compare the two implementations

# One forward/backward pass so optimizer.step() has gradients to consume.
inputs = torch.rand(128, 3, 100, 100, device=device, requires_grad=True)
target = torch.empty(128, dtype=torch.long, device=device).random_(5)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, target)
loss.backward()
optimizer.step()

def main():
    timer = benchmark_utils.Timer(
        stmt="optimizer.step()",
        globals=globals(),
        label=str(optimizer),
    )
    for i in range(1):
        print(f"Run: {i}\n{'-' * 40}")
        print(f"timeit:\n{timer.timeit(1000)}\n")
        print(f"autorange:\n{timer.blocked_autorange()}\n\n")

if __name__ == "__main__":
    main()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45408
Reviewed By: gchanan
Differential Revision: D23956680
Pulled By: izdeby
fbshipit-source-id: c5eab7bf5fce14a287c15cead1cdc26e42cfed94