Add get and set APIs for the ZeRO-3 partitioned parameters (#4681)
DeepSpeed already provides a set of debugging APIs to
[get](https://deepspeed.readthedocs.io/en/latest/zero3.html#debugging)
and
[set](https://deepspeed.readthedocs.io/en/latest/zero3.html#modifying-partitioned-states)
the **full** model states (parameters, gradients, and optimizer states).
However, some scenarios need only the **local states**, for example,
pruning model layers based on a local criterion: after calling
`model_engine.step()`, each process must apply its local mask to the
partitioned parameters it owns. This PR therefore introduces new APIs to
`get` and `set` the ZeRO-3 partitioned (local) model states.
### APIs intro
```python
def safe_get_local_fp32_param(param):
    """Get the fp32 partitioned parameter."""

def safe_get_local_grad(param):
    """Get the fp32 gradient of a partitioned parameter."""

def safe_get_local_optimizer_state(param, optim_state_key):
    """Get the fp32 optimizer state of a partitioned parameter."""

def safe_set_local_fp32_param(param, value):
    """Update the partitioned fp32 parameter."""

def safe_set_local_optimizer_state(param, value, optim_state_key):
    """Update the fp32 optimizer state of a partitioned parameter."""
```
### Usage
```python
# local API
from deepspeed.utils import (
    safe_get_local_fp32_param,
    safe_get_local_grad,
    safe_get_local_optimizer_state,
    safe_set_local_fp32_param,
    safe_set_local_optimizer_state,
)
```
### TODO
- [x] Add local APIs
- [x] Add UTs
- [x] Update Docs
@tjruwase
---------
Signed-off-by: yliu <test@do_not_reply@neuralstudio.intel.com>
Co-authored-by: yliu <test@do_not_reply@neuralstudio.intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>