pytorch
48056b16 - [FSDP] Reshard frozen params in backward (#101982)

[FSDP] Reshard frozen params in backward (#101982)

This PR makes a first attempt at improving FSDP's fine-tuning support by adding hooks to reshard frozen parameters in the backward pass.

- Without this, frozen parameters involved in gradient computation are kept unsharded through the entire backward pass.
- The approach is to register a multi-grad ~~post~~-hook on the _input_ activations to the FSDP module, where the hook performs the resharding after all gradients for the FSDP module must have been computed (meaning that it is safe to reshard). A sketch of this mechanism is included after the commit message.

~~This PR relies on adding a "multi-grad post-hook" that differs from the existing "multi-grad hook" from `register_multi_grad_hook()`. I find that with `register_multi_grad_hook()`, the unit test counting the number of times `_post_backward_reshard()` is called sometimes fails (because it is not called).~~ This was resolved in https://github.com/pytorch/pytorch/pull/102859.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101982
Approved by: https://github.com/rohan-varma
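The following is a minimal, hypothetical sketch of the multi-grad hook mechanism described above, not the actual FSDP implementation. It uses the public `torch.autograd.graph.register_multi_grad_hook()` API; the helper names `install_post_backward_reshard` and `reshard_frozen_params` are made up for illustration, and the callback simply records that it ran rather than resharding anything.

```python
# Minimal sketch: fire a callback once gradients for all of a module's tensor
# inputs have been computed, i.e. once the backward pass is past this module.
import torch
import torch.nn as nn
from torch.autograd.graph import register_multi_grad_hook


def reshard_frozen_params(module: nn.Module) -> None:
    # Hypothetical stand-in for FSDP's post-backward resharding of frozen
    # (requires_grad=False) parameters; here it only records that it ran.
    module._resharded_in_backward = True


def install_post_backward_reshard(module: nn.Module):
    # On every forward, attach a multi-grad hook to the tensor inputs that
    # require grad. The hook fires only after gradients w.r.t. *all* of those
    # inputs exist, so every gradient this module contributes to has already
    # been computed and it is safe to reshard.
    def forward_pre_hook(mod, args):
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            register_multi_grad_hook(tensors, lambda grads: reshard_frozen_params(mod))

    return module.register_forward_pre_hook(forward_pre_hook)


if __name__ == "__main__":
    lin = nn.Linear(4, 4)
    lin.weight.requires_grad_(False)  # frozen parameters
    lin.bias.requires_grad_(False)
    install_post_backward_reshard(lin)

    x = torch.randn(2, 4, requires_grad=True)  # input activation needs grad
    lin(x).sum().backward()
    print(getattr(lin, "_resharded_in_backward", False))  # True
```

In the real code the resharding also has to coordinate with FSDP's own pre- and post-backward hooks and per-iteration training state; this sketch only shows why hooking the input activations gives a safe "all gradients for this module are done" signal.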