pytorch
75cb99e5 - [optim] Widen the cases for defaulting to foreach (#95820)

Big OOP correction continued. Also added a test this time to verify that the defaulting behaves as expected.

The key here is realizing that the grouping for foreach already assumes that the non-param tensorlists follow suit in dtype and device, so it is too narrow to check that _all_ tensors are on CUDA. The main leeway this allowed was state_steps, which are sometimes CPU tensors. Since foreach _can_ handle CPU tensors, this should not introduce breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95820
Approved by: https://github.com/albanD
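To make the widened heuristic concrete, here is a minimal before/after sketch of the defaulting check described above. The function names are hypothetical illustrations, not PyTorch's actual internal API; the real logic lives in the optimizer's foreach defaulting path.

```python
import torch
from typing import List

# Hypothetical names for illustration only; not PyTorch internals.

def old_default_to_foreach(all_tensors: List[torch.Tensor]) -> bool:
    # Too narrow: state_steps are sometimes CPU tensors, so requiring
    # *every* tensor to be CUDA disabled foreach more often than needed.
    return all(t.is_cuda for t in all_tensors)

def new_default_to_foreach(params: List[torch.Tensor]) -> bool:
    # The foreach grouping already assumes the non-param tensorlists
    # (grads, exp_avgs, state_steps, ...) match the params in dtype and
    # device, so checking the params alone suffices. CPU state_steps are
    # fine because the foreach path can handle CPU tensors.
    return all(p.is_cuda for p in params)


if torch.cuda.is_available():
    params = [torch.zeros(2, device="cuda")]
    state_steps = [torch.zeros(())]  # step counters often live on CPU
    # Old check: False (CPU state_steps block foreach).
    print(old_default_to_foreach(params + state_steps))
    # New check: True (params are CUDA; CPU state_steps are allowed).
    print(new_default_to_foreach(params))
```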