0a98d943 - [FSDP] Auto-pad for no `pad()` in post-bwd hook (`use_orig_params=True`) (#99054)

This avoids the post-backward `F.pad()` call before reduce-scatter for `use_orig_params=True`. It is pretty cool that we built up all of the necessary infra in past PRs so that this change is simple. We simply append one more padding tensor to pad out the `FlatParameter` numel to be divisible by the world size. This causes the flat gradient to be computed directly with the padded size, removing the need for the explicit `F.pad()` call.

Because the infra is currently built out only for `use_orig_params=True`, we add this auto-pad logic only for that path. We can add it for `use_orig_params=False` in follow-up work if needed.

I confirmed in local tests that this removes the pad call.

Before (with `aten::pad`):

![Screenshot 2023-04-13 at 1 38 21 PM](https://user-images.githubusercontent.com/31054793/231840432-e0875972-6546-4cf1-aaaa-bc3949050519.png)

After (no `aten::pad`):

![Screenshot 2023-04-13 at 1 38 29 PM](https://user-images.githubusercontent.com/31054793/231840422-8dd6f5ab-0a7a-4393-a835-42009948eb62.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99054
Approved by: https://github.com/fegin, https://github.com/zhaojuanmao
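As a rough illustration of the padding math this change relies on, here is a minimal sketch. The names `flat_numel`, `world_size`, and `padded_numel` are illustrative stand-ins, not FSDP internals; it only shows why rounding the `FlatParameter` size up to a multiple of the world size lets the flat gradient skip the explicit `F.pad()` before reduce-scatter.

```python
# Hypothetical sketch of the padding arithmetic; not the actual FSDP code path.
import torch
import torch.nn.functional as F

def padded_numel(flat_numel: int, world_size: int) -> int:
    # Round the flat parameter size up to the next multiple of world_size so
    # reduce-scatter shards evenly without any runtime padding.
    return (flat_numel + world_size - 1) // world_size * world_size

world_size = 8
flat_numel = 1001  # original (unpadded) FlatParameter numel

# Before this PR (use_orig_params=True): the flat gradient had the unpadded
# size, so the post-backward hook padded it right before reduce-scatter.
flat_grad = torch.randn(flat_numel)
pad_len = padded_numel(flat_numel, world_size) - flat_numel
flat_grad_old = F.pad(flat_grad, (0, pad_len))  # explicit aten::pad call

# After this PR: one extra padding tensor is appended when the FlatParameter
# is built, so the flat gradient already has the padded size and the
# F.pad() call disappears from the post-backward hook.
flat_grad_new = torch.randn(padded_numel(flat_numel, world_size))

assert flat_grad_old.numel() == flat_grad_new.numel() == 1008
```

Either way the tensor handed to reduce-scatter has the same padded size; the difference is whether that size is reached by an extra kernel in the post-backward hook or baked in when the `FlatParameter` is constructed.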