[optim] FusedAdam/W accepts lr: Tensor without h2ds (#106916)
Starts addressing #106802
This PR also conveniently does some BE:
- Fixes a bug in adamw where we use amsgrad instead of per group amsgrad
- Brings the impls of adamw and adam closer to correctness and to each other
I couldn't fully remove the .pyi's because mypy was going to complain about the entire files which scared me and shouldn't go in this PR anyway.
Test plan:
- Add tests to ensure that lr could be passed as a Tensor
- Did some profiling of the below code (runs 1k iterations of step for Adam)
```
import torch
from torch.testing._internal.common_utils import TestCase
param = torch.rand(2, 3, dtype=torch.float, device='cuda:0', requires_grad=True)
param.grad = torch.rand_like(param)
lr = torch.tensor(.001, device='cuda:0')
opt = torch.optim.Adam([param], lr=lr, fused=True)
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
) as p:
for _ in range(1000):
opt.step()
print(p.key_averages().table(sort_by="cpu_time_total"))
```
Before my change:
<img width="1381" alt="image" src="https://github.com/pytorch/pytorch/assets/31798555/cfc5175a-0f41-4829-941f-342554f3b152">
After my change (notice there are no d2h syncs and the CPU time is lower!):
![image](https://github.com/pytorch/pytorch/assets/31798555/726d7e66-dcff-4a4f-8a75-e84329961989)
Next steps long term:
- have all capturable foreach + forloop impls in Adam(W) handle tensor LR
- have all capturable impls handle tensor LR
- have all impls handle tensor LR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106916
Approved by: https://github.com/albanD