Add Model Parallel Support to ZeRO (#61370)
Summary:
**Overview:**
The existing `ZeroRedundancyOptimizer` implementation assumes that all model parameters are stored on the same device (due to the recent [refactor](https://github.com/pytorch/pytorch/pull/59834)). This change allows model parameters to be sharded across multiple devices, as in the DDP with Model Parallelism example [here](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
The only logic affected is the bucketing strategy used when `parameters_as_bucket_view=True`. Let `n` denote the world size and `k` denote the number of devices per process.
- Previously, `k = 1`, and `self._buckets` was a `List[torch.Tensor]`, where `self._buckets[j]` was a tensor (i.e. bucket) containing the parameters assigned to rank `j` for `j = 0, ..., n - 1`.
- Now, `self._buckets` is a `List[List[torch.Tensor]]`, where `self._buckets[i][j]` is a tensor containing the parameters stored on device `i` assigned to rank `j` for `i = 0, ..., k - 1` and `j = 0, ..., n - 1`.
This bucket construction uses an auxiliary data structure `self._device_to_per_rank_params`, which is a `Dict[torch.device, List[List[torch.Tensor]]]`. It maps:
- `dev_0` to `[rank 0's assigned parameters on dev_0, rank 1's assigned parameters on dev_0, ...]`,
- `...`
- `dev_{k-1}` to `[rank 0's assigned parameters on dev_{k-1}, rank 1's assigned parameters on dev_{k-1}, ...]`
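The shape of this mapping, and of the nested bucket list derived from it, can be sketched in plain Python. This is an illustrative sketch only, not the code in this PR: `FakeParam`, `group_by_device_and_rank()`, and `bucket_sizes()` are hypothetical names, devices are represented as strings instead of `torch.device`, and the per-rank parameter partition is taken as given input.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class FakeParam:
    """Hypothetical stand-in for a parameter tensor."""
    name: str
    device: str  # stand-in for torch.device, e.g. "cuda:0"
    numel: int   # number of elements in the parameter


def group_by_device_and_rank(
    params_per_rank: List[List[FakeParam]],
) -> Dict[str, List[List[FakeParam]]]:
    """Map each device to a list, indexed by rank, of the parameters
    stored on that device and assigned to that rank."""
    world_size = len(params_per_rank)
    mapping: Dict[str, List[List[FakeParam]]] = defaultdict(
        lambda: [[] for _ in range(world_size)]
    )
    for rank, params in enumerate(params_per_rank):
        for param in params:
            mapping[param.device][rank].append(param)
    return dict(mapping)


def bucket_sizes(mapping: Dict[str, List[List[FakeParam]]]) -> List[List[int]]:
    """Return sizes[i][j]: the total element count on device i assigned to
    rank j, i.e. the size a flat bucket for that (device, rank) pair needs."""
    return [
        [sum(p.numel for p in rank_params) for rank_params in per_rank]
        for _, per_rank in sorted(mapping.items())
    ]


# Example with n = 2 ranks and k = 2 devices (assumed partition).
params_per_rank = [
    [FakeParam("a", "cuda:0", 4), FakeParam("c", "cuda:1", 6)],  # rank 0
    [FakeParam("b", "cuda:0", 2)],                               # rank 1
]
mapping = group_by_device_and_rank(params_per_rank)
sizes = bucket_sizes(mapping)
```

In this example, rank 1 has no parameters on `cuda:1`, so the corresponding bucket size is zero; how the actual implementation treats such empty buckets is not covered by this sketch.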
I removed the invariant checker `_verify_same_param_device()` and its corresponding test, since parameters residing on the same device is no longer an invariant.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61370
Test Plan: I added a new test `test_zero_model_parallel()` that checks for parity between a DDP model with model parallelism using `ZeroRedundancyOptimizer` and a local model with the same architecture using a local optimizer. I also verified that the existing tests still pass.
Reviewed By: soulitzer
Differential Revision: D29637132
Pulled By: andwgu
fbshipit-source-id: 07112959fa4e94a3f40e67e88cbb58ce3cd1e033