pytorch
d01bf1d1 - [FSDP] Introduce `ModuleWrapPolicy` for simplicity (#88450)

Commit
1 year ago
[FSDP] Introduce `ModuleWrapPolicy` for simplicity (#88450)

**BC Breaking Change**

This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is the unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves.

This only breaks code that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except one internal benchmark file, which does not actually depend on our `pytorch` code).

In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`.

**Overview**

This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is:
```
module_classes: Set[Type[nn.Module]] = ...
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls=module_classes,
)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
Now, users can instead write:
```
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...)
```
This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`).

`ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construction of such `FSDPPolicy` classes from their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective.
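
To illustrate the relationship between a policy class and its underlying callable, here is a minimal sketch that runs without PyTorch. It is not the code from this PR: `Module`, `Block`, `FSDPPolicySketch`, `ModuleWrapPolicySketch`, and `_module_wrap_policy` are hypothetical stand-ins, and the callable's behavior (always recurse, wrap only instances of the given classes) is an assumption about how such a policy typically works.

```python
import functools
from abc import ABC, abstractmethod
from typing import Callable, Set, Type


class Module:
    """Hypothetical stand-in for nn.Module so the sketch runs without PyTorch."""


class Block(Module):
    """Hypothetical layer class that the user wants wrapped."""


def _module_wrap_policy(
    module: Module,
    recurse: bool,
    nonwrapped_numel: int,
    module_classes: Set[Type[Module]],
) -> bool:
    # Assumed convention: when `recurse` is True the question is whether to
    # keep descending into children; otherwise it is whether to wrap `module`.
    if recurse:
        return True
    # `nonwrapped_numel` is deliberately unused by a class-based policy.
    return isinstance(module, tuple(module_classes))


class FSDPPolicySketch(ABC):
    """Hypothetical base class: subclasses only need to expose `policy`."""

    @property
    @abstractmethod
    def policy(self) -> Callable:
        ...


class ModuleWrapPolicySketch(FSDPPolicySketch):
    """Hypothetical class-based policy that hides the callable's extra arguments."""

    def __init__(self, module_classes: Set[Type[Module]]):
        # Bind the user's classes once; the caller never sees
        # `recurse` or `nonwrapped_numel`.
        self._policy = functools.partial(
            _module_wrap_policy, module_classes=set(module_classes)
        )

    @property
    def policy(self) -> Callable:
        return self._policy


wrap_policy = ModuleWrapPolicySketch({Block})
should_wrap = wrap_policy.policy(Block(), recurse=False, nonwrapped_numel=0)
```

The point of the sketch is the split in responsibilities: the constructor takes only what the user cares about (the module classes), while the `policy` property returns a callable with the full three-argument signature that a recursive wrapping routine would invoke.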

I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`.

This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450
Approved by: https://github.com/zhaojuanmao