481a334b - [FSDP][3/N] Refactor `summon_full_params` unit tests (#92298)

**Overview**

- This PR refactors the `summon_full_params()` unit tests to prepare for `unshard_params()` by consolidating redundant tests and improving others; the `summon_full_params()` surface under test is sketched below.
- This PR enables `CPUOffload(offload_params=True)` + `NO_SHARD` + `writeback=True`.
- This PR provides an improved error message when calling `summon_full_params()` from an invalid context (i.e. from forward, from backward, or from inside `summon_full_params()` itself).
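As context for the refactor, here is a minimal usage sketch of the `summon_full_params()` context manager and the arguments the tests below exercise (`recurse`, `writeback`, `rank0_only`, `offload_to_cpu`, `with_grads`). It is not taken from this PR; the single-rank bootstrap and toy model are illustrative assumptions, and a real run would normally be launched via `torchrun` with more ranks.

```python
# Illustrative sketch only (not from this PR): single-rank bootstrap and toy
# model are assumptions made for a self-contained example.
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main() -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    model = FSDP(torch.nn.Linear(8, 8).cuda())

    # Unshard so each rank sees the full, original-shaped parameters.
    # `writeback=True` persists in-place edits back to the local shard on exit;
    # `rank0_only`, `offload_to_cpu`, and `with_grads` are the other knobs the
    # refactored tests cover.
    with FSDP.summon_full_params(
        model,
        recurse=True,
        writeback=True,
        rank0_only=False,
        offload_to_cpu=False,
        with_grads=False,
    ):
        for name, param in model.named_parameters():
            print(name, tuple(param.shape))  # clean names, unflattened shapes

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```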
**Details**

<details>
<summary>Existing Unit Tests</summary>

`test_summon_full_param_writeback()` with `world_size=1`
`test_summon_full_param_writeback()` with `world_size=2`
- Tests that `writeback=True` persists writes and that `writeback=False` does not when modifying a root FSDP instance's `flat_param` (`modify_outer=True`) or a non-root FSDP instance's `flat_param` (`modify_outer=False`); additionally configures `mixed_precision` and `use_orig_params`
- `CPUOffload(offload_params=True)` + `world_size=1` is not tested because it is not supported.
- The write inside `summon_full_params()` is on the `flat_param` itself, which is not the expected usage.

`test_summon_full_param_shard_value()`
- Tests that reconstructing the `flat_param` (by re-flattening and chunking parameters) inside `summon_full_params()` gives the same result as the originally constructed `flat_param` when using a single FSDP instance
- This test seems to exercise the FSDP sharding algorithm, not the specification of `summon_full_params()`. The only relevant part implicitly tested is that the `model.parameters()` order is preserved.
- This test assumes the current FSDP sharding algorithm.

`test_summon_full_param_recursive()`
- Tests that `recurse=True` recursively applies to all FSDP instances and that `recurse=False` does not
- This test assumes the current FSDP sharding algorithm.

`test_cannot_summon_full_params_from_forward()`
`test_cannot_summon_full_params_from_backward()`
- Tests that calling `summon_full_params()` from inside the forward or backward pass raises an error
- The error message leaks `FlatParamHandle` to the user. I provided a better error message in this PR.

`test_summon_full_params_respects_reshard_after_forward()`
- Tests that calling `summon_full_params()` after forward preserves whether the padded unsharded `flat_param` data is freed or not (like `reshard_after_forward`)
- This test depends on FSDP internals (`flat_param._full_param_padded.storage().size()`).

`test_summon_single_param()`
- Tests that writing to padding with `writeback=True` does not persist those writes (by using a singleton `(1, 1)` parameter that gets flattened and padded to `(2,)`)
- The test name is misleading.

`test_summon_full_params_equivalence()`
- Tests `writeback`, `rank0_only`, and `offload_to_cpu` with `writeback=not rank0_only`, using `CPUOffload(offload_params=True)` and including a `torch.cuda._sleep(int(1e6))` *after* the write in `summon_full_params()`
- The PR introducing this test said that the `torch.cuda._sleep(int(1e6))` exercised the stream synchronization in `summon_full_params()`, namely that the current stream waits for the all-gather stream after all-gathering the parameters. I did not follow conceptually how that works since the `torch.cuda._sleep()` call happens after both the all-gather and the write and runs in the default stream, which seems to be after the relevant ops. If we clarify this, I can re-incorporate it into the unit tests. Doing so is not a high priority since `summon_full_params()` now unshards in the default stream and does not require stream synchronization.
- This unit test overlaps with `test_summon_full_param_writeback()` and can be coalesced with it.

`test_summon_from_non_fsdp()`
- Tests that calling `summon_full_params()` with default args on a non-FSDP root module exposes the original parameters correctly
- This test actually covers much of the specification since checking for original parameter equivalence includes checking shape, value, device, etc.

`test_reshard_outside_forward_backward_iteration()`
- Tests that calling `summon_full_params()` after forward preserves whether the padded unsharded `flat_param` data is freed or not (like `reshard_after_forward`) and that calling `summon_full_params()` after backward preserves that the padded unsharded `flat_param` data is freed; additionally configures `mixed_precision`
- This test strictly dominates `test_summon_full_params_respects_reshard_after_forward()` since it also includes the check after backward.

`test_params_are_unflattenned()`
- Tests that original parameters are exposed with the unflattened shape, factoring in `rank0_only` (e.g. including that nonzero ranks reshard early when `rank0_only=True`), and that with `offload_to_cpu=True`, the `flat_param`s are moved back to GPU after exiting the context; additionally configures `mixed_precision`

`test_params_count_and_value()`
- Tests that all original parameters are exposed and with the correct values, factoring in `rank0_only` (e.g. including that nonzero ranks do not expose the original parameters when `rank0_only=True`), and that with `offload_to_cpu=True`, the `flat_param`s are moved back to GPU after exiting the context; additionally configures `mixed_precision`

`test_raises_rank0_with_writeback()`
- Tests that `rank0_only` + `writeback=True` raises an error (an error-case sketch follows the summary of changes below)

`test_named_parameters_buffers()`
- Tests that `named_parameters()` and `named_buffers()` return clean names (without FSDP prefixes) inside `summon_full_params()`

`test_with_grads_core()`
- Tests `with_grads=True` by comparing against DDP

`test_with_grads_none_grads()`
- Tests `with_grads=True` when ranks' `FlatParameter`s have `None` gradient

</details>

<details>
<summary>New Unit Tests</summary>

`test_unshard_params_writeback_no_shard()` (with `world_size=1`)
`test_unshard_params_writeback()` (with `world_size=2`)
- Tests the `writeback` argument (using the default value for all others); a writeback sketch follows these lists

`test_unshard_params_param_data_no_shard()` (with `world_size=1`)
`test_unshard_params_param_data()` (with `world_size=2`)
- Tests that parameters are exposed correctly for `recurse=True` and all other argument configs for a non-FSDP root module

`test_unshard_singleton_param_writeback()`
- Tests `writeback=True` for a singleton parameter, which includes testing that writing to padding does not persist

`test_unshard_params_respects_reshard()`
- Tests that unsharding parameters respects the expected reshard behavior between forward and backward as well as after backward

`test_unshard_params_recurse()`
- Tests the `recurse` argument (using the default value for all others)

`test_offload_to_cpu_no_shard_raises()`
- Tests that `offload_to_cpu=True` with `NO_SHARD` raises an error

</details>
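To make the `writeback` semantics that these tests check concrete, here is a hedged sketch. It is a simplified illustration under assumptions (a sharded world size greater than one, `use_orig_params=False`, and an already FSDP-wrapped `model` as in the sketch above), not code from the tests themselves.

```python
# Simplified illustration (assumption: not the tests' actual code): an in-place
# edit made while unsharded survives resharding only when writeback=True.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def edit_persists(model: FSDP, writeback: bool) -> bool:
    with FSDP.summon_full_params(model, writeback=writeback):
        with torch.no_grad():
            for param in model.parameters():
                param.zero_()  # write to the full, unsharded parameters
    # After exiting, model.parameters() yields the sharded flat parameters again;
    # the zeros persist there only if writeback=True (assuming world_size > 1,
    # since with a single rank the unsharded and sharded data can alias).
    return all(torch.count_nonzero(p).item() == 0 for p in model.parameters())

# Expected per the writeback tests, with a freshly initialized model for each call:
# edit_persists(model, writeback=True) is True and
# edit_persists(model, writeback=False) is False (for nonzero initial weights).
```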
<details>
<summary>Summary of Unit Test Changes</summary>

- `test_summon_full_param_writeback()` -> `test_unshard_params_writeback()`
- `test_summon_full_params_equivalence()`, `test_params_are_unflattenned()`, `test_params_count_and_value()` -> `test_unshard_params_param_data()`
- `test_summon_full_params_respects_reshard_after_forward()`, `test_reshard_outside_forward_backward_iteration()` -> `test_unshard_params_respects_reshard()`
- `test_summon_full_param_recursive()` -> `test_unshard_params_recurse()`
- `test_named_parameters_and_buffers()` unchanged
- `test_with_grads_core()` unchanged
- `test_with_grads_none_grads()` unchanged
- `test_cannot_summon_full_params_from_forward()`, `test_cannot_summon_full_params_from_backward()` -> `test_unshard_params_from_forward_raises()`, `test_unshard_params_from_backward_raises()`
- `test_raises_rank0_with_writeback()` -> `test_rank0_only_with_writeback_raises()`
- `test_offload_to_cpu_no_shard_raises()` new
- `test_summon_full_param_shard_value()` removed

</details>
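For completeness, here is a hedged sketch of the invalid-usage behavior that `test_raises_rank0_with_writeback()` / `test_rank0_only_with_writeback_raises()` cover. The exact exception type is an assumption, and `model` is an FSDP-wrapped module as in the earlier sketches.

```python
# Hedged sketch (assumption: not the tests' code; exact exception type may differ).
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def check_invalid_combo(model: FSDP) -> None:
    try:
        # rank0_only=True with writeback=True is rejected because nonzero ranks
        # do not hold the full parameters to write back from.
        with FSDP.summon_full_params(model, rank0_only=True, writeback=True):
            pass
    except (ValueError, RuntimeError) as e:
        print("rejected as expected:", e)
    else:
        raise AssertionError("expected rank0_only=True + writeback=True to raise")
    # Per test_offload_to_cpu_no_shard_raises(), offload_to_cpu=True with the
    # NO_SHARD strategy is similarly rejected.
```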
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92298
Approved by: https://github.com/rohan-varma

Author: Andrew Gu