pytorch
5dd07324 - [ZeRO] Add ctor support for multiple param groups (#72578)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72578

**Overview**

This adds `ZeroRedundancyOptimizer` constructor support for multiple parameter groups (i.e. passing an `iterable` of `dict`s instead of an `iterable` of `torch.Tensor` as the `parameters` argument) to mirror the API for non-sharded optimizers.

Fixes https://github.com/pytorch/pytorch/issues/71347 and https://github.com/pytorch/pytorch/issues/59973.

This also modifies `test_collect_shards()` to skip when running on ROCm.

**Test Plan**

I adjusted the existing constructor test, and I added a parity test covering three cases: constructing with two parameter groups up front, constructing with one parameter group and adding the second afterward (via `add_param_group()`), and a non-sharded optimizer.

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D34106940

Pulled By: awgu

fbshipit-source-id: 7e70fc0b3cec891646e0698eaedf02ff4354c128

(cherry picked from commit 40f2d45172ba3286b64000a466e42c055cca8ddc)
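To illustrate the two input shapes the `parameters` argument now accepts, here is a minimal sketch of how an optimizer constructor typically canonicalizes its input. This is not the actual `ZeroRedundancyOptimizer` implementation; `normalize_param_groups` is a hypothetical helper, and plain strings stand in for `torch.Tensor` objects so the example stays self-contained.

```python
# Hedged sketch: an iterable of tensors becomes a single implicit param
# group, while an iterable of dicts is taken as explicit param groups
# (mirroring the non-sharded torch.optim API shape).
# `normalize_param_groups` is a hypothetical name, not a PyTorch API.

def normalize_param_groups(parameters):
    """Canonicalize `parameters` into a list of param-group dicts."""
    params = list(parameters)
    if params and isinstance(params[0], dict):
        # Explicit param groups: copy each dict so later mutation
        # of the optimizer's groups does not alias the caller's input.
        return [dict(group) for group in params]
    # Bare tensors: wrap them all into one default param group.
    return [{"params": params}]

# Placeholders for torch.Tensor parameters.
tensors = ["w1", "w2", "b1"]

# Form 1: iterable of tensors -> one param group.
single = normalize_param_groups(tensors)

# Form 2: iterable of dicts -> per-group options, e.g. distinct lr.
grouped = normalize_param_groups([
    {"params": ["w1", "w2"], "lr": 1e-2},
    {"params": ["b1"], "lr": 1e-3},
])
```

Either form can then be passed as the `parameters` argument, matching how non-sharded optimizers such as `torch.optim.SGD` accept per-group options.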