perf(inductor): use for loop with shortcut in `Optimizer`s to speed up over list comprehensions (e.g. complex conversion) (#110613)
Fully fixes: https://github.com/pytorch/pytorch/issues/110506
Depends: https://github.com/pytorch/pytorch/pull/110607
Potential merge conflicts:
- https://github.com/pytorch/pytorch/pull/110339
- https://github.com/pytorch/pytorch/pull/110345
- https://github.com/pytorch/pytorch/pull/110454
Related:
- https://github.com/pytorch/pytorch/issues/110606 (we can apply the improvements here orthogonally to the complex support)
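For context, a minimal sketch of the pattern this PR targets (hypothetical helper name, not the exact code in `torch/optim`): instead of building new lists with per-tensor list comprehensions such as `[torch.view_as_real(p) if torch.is_complex(p) else p for p in params]`, iterate once over the lists and only rewrite entries that are actually complex, so the common all-real case short-circuits to almost no extra work.
```python
import torch
from typing import List

def _convert_complex_params(params: List[torch.Tensor], *state_lists: List[torch.Tensor]) -> None:
    """Hypothetical helper: rewrite complex entries as real views, in place, in one pass."""
    for i, p in enumerate(params):
        # Shortcut: real-valued entries are left untouched, so the all-real case
        # does no tensor conversions and no new list allocations.
        if torch.is_complex(p):
            params[i] = torch.view_as_real(p)
            for state in state_lists:
                state[i] = torch.view_as_real(state[i])
```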
### Results
Benchmark: 100 params.
Breakdowns (float32, dynamo):
```
Adagrad: this PR: 4.4s, main: 8.8s
Adam: this PR: 2.1s, main: 9.8s
AdamW: this PR: 2.5s, main: 8.2s
ASGD: this PR: 3.1s, main: 8.5s
RMSProp: this PR: 1.3s, main: 4.2s
RProp: this PR: 6.7s, main: 14.9s
```
Notes:
1. Adagrad is still slow due to the `_get_value` list comprehension. This can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by utilizing the capturable path.
2. Adamax is not actually compiled (it is currently disabled).
3. Inductor compile time is quite variable. We calculate the dynamo-only time by subtracting the `call_user_compiler` timing from the `compile_inner` timing (see the sketch below).
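To make note 3 concrete, the dynamo figures in the breakdown are derived by subtraction; a rough sketch with placeholder numbers (not measurements from this PR):
```python
# Placeholder values for illustration only; the real numbers come from dynamo's internal timings.
compile_inner_s = 12.0       # total time recorded for compile_inner (dynamo + backend)
call_user_compiler_s = 7.8   # time spent inside call_user_compiler (the inductor backend)
dynamo_only_s = compile_inner_s - call_user_compiler_s  # 4.2s would be reported as "dynamo"
```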
<details>
This PR:
```
Adagrad (torch.float32): 28.47496461868286s
Adagrad (torch.complex64): 29.379547357559204s
Adam (torch.float32): 17.334211587905884s
Adam (torch.complex64): 29.637500524520874s
Adamax (torch.float32): 2.4749321937561035s
Adamax (torch.complex64): 3.1997995376586914s
AdamW (torch.float32): 18.06532859802246s
AdamW (torch.complex64): 28.25661015510559s
ASGD (torch.float32): 23.70255398750305s
ASGD (torch.complex64): 25.33756995201111s
RMSprop (torch.float32): 7.964028596878052s
RMSprop (torch.complex64): 12.909599781036377s
Rprop (torch.float32): 30.512362003326416s
Rprop (torch.complex64): 44.74405765533447s
```
Main:
```
Adagrad (torch.float32): 26.919506072998047s
Adagrad (torch.complex64): 35.190622091293335s
Adam (torch.float32): 25.715000867843628s
Adam (torch.complex64): 24.17716670036316s
Adamax (torch.float32): 2.4404726028442383s
Adamax (torch.complex64): 3.3538928031921387s
AdamW (torch.float32): 25.2022807598114s
AdamW (torch.complex64): 28.915700912475586s
ASGD (torch.float32): 24.108731985092163s
ASGD (torch.complex64): 26.589075088500977s
RMSprop (torch.float32): 10.781344175338745s
RMSprop (torch.complex64): 15.136352777481079s
Rprop (torch.float32): 42.46482181549072s
Rprop (torch.complex64): 48.28277635574341s
```
It seems this doesn't help the complex case by much (but that's not the majority case). The torch.float32 results are generally positive; where they do not show a drastic improvement or regress slightly, the difference is due to inductor variance (confirmed by manually inspecting the logs).
</details>
### Benchmark Script
```python
import torch
import time
from torch.optim import Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop

OPTIMS = [Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop]
DTYPES = [torch.float, torch.cfloat]

NUM_PARAMS = 100
kwargs = {"lr": 0.01, "foreach": True}
summary = []

for optim_cls in OPTIMS:
    for dtype in DTYPES:
        torch._dynamo.reset()
        # torch._inductor.metrics.reset()
        input = torch.ones([10, 10], dtype=dtype, device="cuda:0")
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, dtype=dtype, device="cuda:0") for _ in range(NUM_PARAMS)]
        )
        model(input).sum().abs().backward()

        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = torch.compile(opt_compiled.step)

        with torch.set_grad_enabled(False):
            start_time = time.time()
            compiled_step()
            summary.append(f"{optim_cls.__name__} ({dtype}): {time.time() - start_time}s")

        print(optim_cls, kwargs, dtype, torch._dynamo.utils.compile_times())

for s in summary:
    print(s)
```
CC: @janeyx99 @mlazos
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110613
Approved by: https://github.com/janeyx99