pytorch
b4b1b0d5 - [FSDP] Support initialization of modules on meta device

Committed 2 years ago
[FSDP] Support initialization of modules on meta device

Adds support for FSDP to initialize modules on a "meta" device, either through the in-tree `device="meta"` + `reset_parameters()` approach or [torchdistx](https://github.com/cbalioglu/torchdistx)'s `deferred_init` + `materialize_module` APIs.

Current constraints:

1. The added argument is `param_init_fn`, a single callable that specifies how a module should be initialized. If different modules need to be initialized in different ways, the user can handle that inside the function they supply, or FSDP could be enhanced to take a `Dict[str, Callable]` to support this more natively.
2. `ignored_modules` is not cleanly supported at the moment. Concretely, if a module is not sharded at all, it may remain uninitialized, and the user will have to initialize it on their own. This will be dug into as follow-up work.

Currently, the unit tests only cover the `device="meta"` + `reset_parameters()` approach because torchdistx is not available in CI. We will work with the CI team to enable unit testing with torchdistx.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75880
Approved by: https://github.com/zhaojuanmao