Add API to set a module as a leaf node when recursively setting Z3 hooks (#4966)
ZeRO3 does not work with MoE models because the order in which modules are executed can change at every forward/backward pass (#4094, #4808).
This PR adds an API that stops DeepSpeed from breaking a given module down into its submodules for parameter fetching. The following example shows the usage:
```python
import torch
import deepspeed
import deepspeed.comm as dist
from transformers.deepspeed import HfDeepSpeedConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
model_id = "mistralai/Mixtral-8x7B-v0.1"
ds_config = {
    "bf16": {
        "enabled": True,
    },
    "zero_optimization": {
        "stage": 3,
    },
    "train_micro_batch_size_per_gpu": 1,
}
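# Keep this object alive: instantiating HfDeepSpeedConfig before
# from_pretrained makes transformers load the model directly into
# ZeRO-3 partitioned form.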
hfdsc = HfDeepSpeedConfig(ds_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
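# Flag every MixtralSparseMoeBlock as a ZeRO-3 leaf module so that its
# experts are not hooked and fetched individually.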
deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
model.eval()
ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()
model = ds_engine.module
inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to("cuda")
outputs = model.generate(inputs, max_new_tokens=200)
output_str = tokenizer.decode(outputs[0])
if dist.get_rank() == 0:
print(f"output: {output_str}")
```
By passing module classes to `set_z3_leaf_modules`, the DeepSpeed engine
stops breaking down any module that is an instance of one of those classes.
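As a quick sanity check, the sketch below counts how many modules a class list matches. It assumes the `model` and imports from the example above; only `set_z3_leaf_modules` is part of the new API, while the counting is plain PyTorch shown purely for illustration.
```python
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import deepspeed

leaf_classes = [MixtralSparseMoeBlock]
deepspeed.utils.set_z3_leaf_modules(model, leaf_classes)

# Every module that is an instance of one of the given classes is flagged;
# for Mixtral this is one sparse MoE block per decoder layer.
matched = [m for m in model.modules() if isinstance(m, tuple(leaf_classes))]
print(f"{len(matched)} modules will be treated as ZeRO-3 leaf nodes")
```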
In the example above, `MixtralSparseMoeBlock` has multiple experts as
submodules. With `set_z3_leaf_modules`, the DeepSpeed engine fetches the
parameters of all of those submodules at once when it prefetches the
parameters of a `MixtralSparseMoeBlock`, instead of hooking each expert
individually.
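For intuition, here is a minimal sketch of how such a leaf flag can cut off the recursive hook setup. This is not DeepSpeed's actual implementation; `_is_leaf` and `install_fetch_hook` are hypothetical names used only for illustration.
```python
import torch.nn as nn

def setup_hooks(module: nn.Module, install_fetch_hook) -> None:
    # Register the fetch/partition hook for this module's own parameters.
    install_fetch_hook(module)
    if getattr(module, "_is_leaf", False):
        # Leaf module: stop recursing. Parameters of everything below this
        # point (e.g. all experts of a MoE block) are fetched together when
        # this module runs, so the fetch order no longer depends on which
        # experts a particular forward pass selects.
        return
    for child in module.children():
        setup_hooks(child, install_fetch_hook)
```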