6bf2776a - [FSDP][Perf] Do not call `pad` in no-padding case (#88769)

- Calling `F.pad()` issues a pad kernel from the CPU even when no padding is needed, which can incur non-negligible overhead. This PR removes that unnecessary call in the no-padding case.
- This PR also skips zeroing the newly allocated sharded gradient tensor before the reduce-scatter when `use_orig_params=True`, since there is no need: the reduce-scatter fills the tensor anyway, and we do not care about the values in the padding. For `use_orig_params=False`, the padding is exposed to the user, so we preserve the existing semantics of zeroing it. I left a to-do to follow up, since we may optimize that case as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88769
Approved by: https://github.com/zhaojuanmao
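Below is a minimal sketch of the pattern described above, assuming a flattened 1D unsharded gradient and an initialized process group; the helper name `_reduce_scatter_grad` is hypothetical and this is not the actual FSDP internals:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def _reduce_scatter_grad(
    unsharded_grad: torch.Tensor,  # flattened (1D) gradient
    world_size: int,
    use_orig_params: bool,
) -> torch.Tensor:
    # Pad only when the numel is not divisible by the world size, so the
    # no-padding case skips the CPU-issued pad kernel entirely.
    numel_to_pad = (world_size - unsharded_grad.numel() % world_size) % world_size
    if numel_to_pad > 0:
        padded_grad = F.pad(unsharded_grad, [0, numel_to_pad])
    else:
        padded_grad = unsharded_grad

    chunk_numel = padded_grad.numel() // world_size
    if use_orig_params:
        # Reduce-scatter overwrites every element, and the padding is not
        # user-visible, so an uninitialized buffer suffices.
        new_sharded_grad = torch.empty(
            chunk_numel, dtype=padded_grad.dtype, device=padded_grad.device
        )
    else:
        # Padding is exposed to the user; preserve the zeroing semantics.
        new_sharded_grad = torch.zeros(
            chunk_numel, dtype=padded_grad.dtype, device=padded_grad.device
        )

    # Each rank receives one reduced chunk of the padded gradient
    # (`reduce_scatter_tensor` is available in recent PyTorch releases).
    dist.reduce_scatter_tensor(new_sharded_grad, padded_grad)
    return new_sharded_grad
```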