pytorch
d6b58d69 - [FSDP()][23/N] Refactor handle attr initialization (#87938)

**`_init_param_attributes()` -> `init_flat_param_attributes()`**

We move `_init_param_attributes()` to `FlatParamHandle.init_flat_param_attributes()` (as already marked as a to-do during previous refactoring).

**`_reset_lazy_init()`**

We no longer delete `_local_shard` from each `FlatParameter` in `_reset_lazy_init()`.

**Analysis**

Thus, the two semantic differences are that we remove the initial `if hasattr(p, "_local_shard")` early return in `_init_param_attributes()` and the `delattr(p, "_local_shard")` in `_reset_lazy_init()`. This is safe because:
- If we never call `_reset_lazy_init()`, then `init_flat_param_attributes()` is only called once. There is no opportunity for an early return.
- If we call `_reset_lazy_init()`, then `init_flat_param_attributes()` will be called again in the next `_lazy_init()`. However, since we removed the early return, all of the attributes initialized in `init_flat_param_attributes()` simply get re-initialized and override any existing attributes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87938
Approved by: https://github.com/mrshenli
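The safety argument above can be illustrated with a minimal, hypothetical sketch (these simplified classes are stand-ins, not the actual FSDP implementation): without the `hasattr` early return, re-running attribute initialization after a reset simply overwrites the stale attributes, so no `delattr` is needed.

```python
# Hypothetical, simplified stand-ins for FSDP's FlatParameter and
# FlatParamHandle, illustrating why removing the early return is safe.

class FlatParam:
    def __init__(self, data):
        self.data = data

class FlatParamHandle:
    def __init__(self, flat_param):
        self.flat_param = flat_param

    def init_flat_param_attributes(self):
        # Previously, _init_param_attributes() returned early here:
        #   if hasattr(self.flat_param, "_local_shard"):
        #       return
        # Without the early return, each call (re)assigns the attribute,
        # overriding any value left over from before a reset.
        self.flat_param._local_shard = self.flat_param.data

p = FlatParam(data=[1.0, 2.0])
h = FlatParamHandle(p)

# First _lazy_init(): attributes are initialized once.
h.init_flat_param_attributes()
assert p._local_shard == [1.0, 2.0]

# Simulate _reset_lazy_init() followed by the next _lazy_init():
# no delattr(p, "_local_shard") is required, because re-initialization
# overwrites the existing attribute.
p.data = [3.0, 4.0]
h.init_flat_param_attributes()
assert p._local_shard == [3.0, 4.0]
```

The second call demonstrates the key point of the analysis: re-initialization is idempotent-by-overwrite, so deleting `_local_shard` in `_reset_lazy_init()` was redundant.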