[FSDP][6/N] Check valid param freezing for `ModuleWrapPolicy` (#104427)
This PR adds improved error/warning messaging when auto wrapping with `ModuleWrapPolicy` in the presence of frozen parameters.
- For `use_orig_params=False`, FSDP requires uniform `requires_grad` for each FSDP instance. This PR adds a `ValueError` at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names.
- For `use_orig_params=True`, FSDP allows non-uniform `requires_grad` for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a `UserWarning` at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names.
- There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is okay for now unless users complain.
<details>
<summary> Why DFS via named_children() vs. Using named_modules()</summary>
```
LoraModel(
(embed_tokens): Embedding(100, 32)
(layers): ModuleList(
(0-3): 4 x LoraDecoder(
(attn): LoraAttention(
(q_proj): Linear(in_features=32, out_features=32, bias=False)
(lora_A): Linear(in_features=32, out_features=8, bias=False)
(lora_B): Linear(in_features=8, out_features=32, bias=False)
(k_proj): Linear(in_features=32, out_features=32, bias=False)
(v_proj): Linear(in_features=32, out_features=32, bias=False)
(o_proj): Linear(in_features=32, out_features=32, bias=False)
)
(mlp): LoraMLP(
(proj1): Linear(in_features=32, out_features=128, bias=False)
(proj2): Linear(in_features=128, out_features=32, bias=False)
)
(inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
(post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
)
(norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
```
Reverse topological order with stack-based DFS via `named_children()`:
```
[
'embed_tokens',
'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B', 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj', 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp', 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0',
'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B', 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj', 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp', 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1',
'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B', 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj', 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp', 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2',
'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B', 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj', 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp', 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3',
'layers', 'norm', ''
]
```
Reverse topological order with `named_modules()`:
```
[
'norm',
'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm', 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp', 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj', 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj', 'layers.3.attn', 'layers.3',
'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm', 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp', 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj', 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj', 'layers.2.attn', 'layers.2',
'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm', 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp', 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj', 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj', 'layers.1.attn', 'layers.1', 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm', 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp', 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj', 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj', 'layers.0.attn', 'layers.0',
'layers', 'embed_tokens', ''
]
```
With the stack-based DFS via `named_children()`, reversing the topological order gives us each level in the module tree in the registered order, wheres with `named_modules()`, reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the _first-registered_ module that violates the frozen/non-frozen condition.
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104427
Approved by: https://github.com/ezyang