a3f3ec8f - [FSDP+dynamo]: forward treats parameter-views as params (#88781)

[FSDP+dynamo]: forward treats parameter-views as params (#88781)

Dynamo + aot_autograd needs a way to wrap all tensors (whether inputs or params/buffers) in FakeTensor wrappers, and FSDP's mangling of parameters hides them from this wrapping. This PR unblocks running hf_Bert and hf_T5 with FSDP under dynamo, whether using recursive wrapping around transformer layers or applying FSDP only around the whole model. Perf/memory validation, and possibly optimization, is the next step.

`python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager --fsdp_wrap`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager`
`python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager --fsdp_wrap`

The problem: Dynamo (actually aot_autograd) trips up with FSDP because it must wrap all input tensors in FakeTensor wrappers, and it only knows to wrap graph inputs or named_parameters/named_buffers. FSDP's pre-forward hook sets views into the flat parameter (which are not nn.Parameters) as attributes on the module under the same names as the original params, but those views do not show up in named_parameters.
- In use_orig_params mode, FSDP still de-registers params during the pre-forward hook, then re-registers them post-forward.
- During forward (between the hooks), the params are setattr'd on the module as regular view tensors, not nn.Parameters.
- Note: use_orig_params=True is the recommended way to use FSDP, and use_orig_params=False is being deprecated, so only use_orig_params=True is considered for this enablement.

The solution:
- Adding the views to named_buffers is not possible because it interferes with how FSDP's `_apply` works.
- Since the views are not actual nn.Parameters, register_parameter complains when asked to register them.
- Simply setting `module._parameters[name] = view` is a viable workaround (see the minimal sketch below); it is hacky, but FSDP code already modifies `_parameters` directly.

Note: Manual checkpointing still isn't working with FSDP+dynamo, so that will have to be addressed in a follow-up.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88781
Approved by: https://github.com/ezyang, https://github.com/awgu
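A minimal sketch of the `_parameters[name] = view` workaround described above, not FSDP's actual code: the flat parameter here is a toy stand-in, and the module and variable names are illustrative only. It shows why `register_parameter` rejects a plain view tensor while direct assignment into `_parameters` makes the view visible to `named_parameters()`, which is what dynamo/aot_autograd consults when fake-wrapping tensors.

```python
import torch
import torch.nn as nn

module = nn.Linear(4, 4, bias=False)

# Simulate FSDP's pre-forward hook: de-register the original nn.Parameter
# and replace it with a view into a (toy) flat parameter.
flat_param = nn.Parameter(module.weight.detach().clone().reshape(-1))
del module._parameters["weight"]        # de-register the original param
weight_view = flat_param.view(4, 4)     # plain tensor view, not an nn.Parameter

# module.register_parameter("weight", weight_view) would raise TypeError
# because weight_view is not an nn.Parameter; assigning into _parameters
# directly is the workaround this PR relies on.
module._parameters["weight"] = weight_view

# The view now shows up in named_parameters(), so aot_autograd can wrap it
# in a FakeTensor like any other parameter, and forward still works.
print([name for name, _ in module.named_parameters()])  # ['weight']
out = module(torch.randn(2, 4))
```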